当前位置: 首页 > news >正文

模型训练时CPU和GPU大幅度波动——可能是数据的读入拖后腿

模型训练时CPU和GPU大幅度波动——可能是数据的加载拖后腿

问题

在进行猫狗大战分类任务时,发现模型训练时CPU和GPU大幅度波动,且模型训练速度很慢。

原因

​ 初步分析可能是数据加载(包括数据的transform,我用了Resize,ToTensor,Normalize这三个操作)的的速度太慢,于是通过计算一个epoch数据加载的时间来判断,最后发现数据加载的数据和一个epoch训练的时间差不太多(因为用的模型较小,是ResNet18,如果模型比较大,训练时间比数据加载时间大得多的时候,这种情况CPU和GPU的波动频率和幅度会好很多,情况最好的是,在训练一个完batch的数据前,下一个batch的数据已经准备好了)。测量加载时间代码如下:

import time
from torch.utils.data import DataLoaderdata_loader = DataLoader(dataset, batch_size=64)
start_time = time.time()# 遍历数据加载器中的所有批次
for i, data in enumerate(data_loader):passend_time = time.time()
# 计算并打印整个数据读取的时间
total_time = end_time - start_time
print(f"Total data loading time: {total_time:.4f} seconds")

然后再计算训练一个epoch的时间,若没有比加载数据的时间大很多的话,大概率就是数据加载拖后腿了。

解决方法

我使用的是方法是将所有数据一次性读入内存中,避免频繁进行磁盘IO,这样集中把所有数据读出来的时间要比一边训练一边读要快的多(使用较小的模型一般数据量不大,全部读入内存应该没什么问题,如果数据量较大呢?这时候用的模型一般也会较大,训练的时间占据主导,这时候就基本不会出现gpu等待数据的情况了)。以猫狗大战这个任务来说,自定义的Dataset如下,关键代码后用!!!..表示:

class CatDogDataset(Dataset):def __init__(self, root_dir, transform=None, test=False):self.root_dir = root_dirself.transform = transformself.image_paths = []self.image_data = []		# !!!!!!!!!!!!!!!!!!! self.labels = []self.test = testfor filename in os.listdir(root_dir):if filename.endswith('.jpg'):image_path = os.path.join(root_dir, filename)image = Image.open(image_path).convert('RGB')  # 转换为RGB格式if self.transform:image = self.transform(image)self.image_paths.append(image_path)		self.image_data.append(image)		# !!!!!!!!!!!!!!!!!!!!	将所有图片读到内存进来if not test:if 'cat' in filename:self.labels.append(0)  # cat 类别标记为 0elif 'dog' in filename:self.labels.append(1)  # dog 类别标记为 1def __len__(self):return len(self.image_data)def __getitem__(self, idx):if self.test:return self.image_data[idx], self.image_paths[idx]  # 测试集返回图像及其路径else:return self.image_data[idx], self.labels[idx]
http://www.lryc.cn/news/441484.html

相关文章:

  • keep-alive的应用场景
  • 【C++ Primer Plus习题】16.9
  • Java入门:09.Java中三大特性(封装、继承、多态)02
  • AI为云游戏带来的革新及解决方案:深度技术剖析与未来展望
  • 集合是什么
  • JavaDS —— 图
  • 魅思-视频管理系统 getOrderStatus SQL注入漏洞复现
  • SOME/IP通信协议在汽车业务具体示例
  • jupyter notebook添加环境/添加内核
  • 建模杂谈系列256 规则函数化改造
  • python实现冒泡排序的算法
  • 爱玩游戏的弟弟,被人投资了100万
  • Pandas_数据结构详解
  • Leetcode 3287. Find the Maximum Sequence Value of Array
  • python 山峦图
  • Open3D:3D数据处理与可视化的强大工具
  • YOLOv8改进系列,YOLOv8的Neck替换成AFPN(CVPR 2023)
  • BitLocker硬盘加密的详细教程分享
  • YOLOv8的GPU环境搭建方法
  • JZ2440下载后设置NAND启动文件系统
  • AI绘画与摄影新纪元:ChatGPT+Midjourney+文心一格 共绘梦幻世界
  • 金手指设计
  • Chainlit集成LlamaIndex并使用通义千问模型实现AI知识库检索网页对话应用增强版
  • 详解c++菱形继承和多态---下
  • python学习笔记目录
  • 非结构化数据中台架构设计最佳实践
  • 鹏鼎控股社招校招入职SHL综合能力测评:高分攻略及真题题库解析答疑
  • 【测向定位】差频MUSIC算法DOA估计【附MATLAB代码】
  • 智能车镜头组入门(四)元素识别
  • Java键盘输入语句