当前位置: 首页 > news >正文

【笔记 Pytorch 08】深度学习模板 (未完)

文章目录

  • 一、声明
  • 二、工程结构
  • 三、文件内容
      • main.py
      • model.py
      • dataset.py
      • utils.py
  • 四、问题汇总

一、声明

非常感谢这些资料的作者:
【参考1】、【PyTorch速成教程 (by Sung Kim)】

二、工程结构

├── main.py:实现训练 (train) 、验证(validation)和测试(test)
│ ├── model.py:实现的模型
│ ├── dataset.py:加载的数据
│ ├── utils.py:常用功能

三、文件内容

main.py

from torch.utils.data import Dataset, DataLoader
from torch import from_numpy, tensor
from torch.autograd import Variable
import numpy as np
import model
import utils# load data
dataset = MyDataset()
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)# model
model=Model()# define loss and optimizer
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1)# train
for epoch in range(2):for i, data in enumerate(train_loader, 0):# get the inputsinputs, labels = data# wrap them in Variableinputs, labels = Variable(inputs), Variable(labels)# Forward passy_pred=model(inputs)# Compute and print lossloss=criterion(y_pred,labels)accuracy= ultis.accuracy(y_pred,labels)print("[{:05d}/{:05d}] train_loss:{:.4f} accuracy: {:.4f}]".format(i,epoch,loss.data[0],accuracy))# updateoptimizer.zero_grad()	# zero gradientsloss.backward()			# perform a backward passoptimizer.step() 		# update weight or parameters

model.py

import torch
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.l1=torch.nn.Linear(8,6)self.l2=torch.nn.Linear(6,4)self.l3=torch.nn.Linear(4,1)self.sigmoid=torch.nn.Sigmoid()# 数据流def forward(self,x):out1=self.sigmoid(self.l1(x))out2=self.sigmoid(self.l2(out1))y_pred=self.sigmoid(self.l3(out2))return y_pred

dataset.py

要点:
(1)必须重载 __getitem____len__
(2)

import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):def __init__(self):  # Initialize your data, download, etc.xy = np.loadtxt('./data/diabetes.csv.gz', delimiter=',', dtype=np.float32)self.len = xy.shape[0]self.x_data = torch.from_numpy(xy[:, 0:-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.len

utils.py

import numpy as np
import scipy.sparse as sp
import torch
import osdef encode_onehot(labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)def list_all_files(rootdir):_files = []#列出文件夹下所有的目录与文件list_file = os.listdir(rootdir)for i in range(0,len(list_file)):# 构造路径path = os.path.join(rootdir,list_file[i])# 判断路径是否是一个文件目录或者文件# 如果是文件目录,继续递归        if os.path.isdir(path):_files.extend(list_all_files(path))if os.path.isfile(path):_files.append(path)return _filesdef mkdir(path):# 去除首位空格path=path.strip()# 去除尾部 \ 符号path=path.rstrip("\\")# 判断路径是否存在# 存在     True# 不存在   FalseisExists=os.path.exists(path)# 判断结果if not isExists:# 如果不存在则创建目录# 创建目录操作函数os.makedirs(path) print(path+' create sucess')return Trueelse:# 如果目录存在则不创建,并提示目录已存在print(path+' path exist !')return False

四、问题汇总

:dataset.py中__getitem__返回的是一个元素,还是一个batch数据?

http://www.lryc.cn/news/242391.html

相关文章:

  • 【如何学习Python自动化测试】—— Cookie 处理
  • IOS+Appium+Python自动化全实战教程
  • 华硕灵耀XPro(UX7602ZM)原装Win11系统恢复安装教程方法
  • SpringBoot整合Redis,redis连接池和RedisTemplate序列化
  • 学习课题:逐步构建开发播放器【QT5 + FFmpeg6 + SDL2】
  • Linux 6.7全面改进x86 CPU微码加载方式
  • 【Python】Fastapi swagger-ui.css 、swagger-ui-bundle.js 无法加载,docs无法加载,redocs无法使用
  • 算法-中等-链表-两数相加
  • STC单片机选择外部晶振烧录程序无法切换回内部晶振导致单片机不能使用
  • 使用STM32+SPI Flash模拟U盘
  • 【自主探索】基于 frontier_exploration 的单个机器人自主探索建图
  • 模板初阶(1):函数模板,类模板
  • AIGC: 关于ChatGPT中生成输出表格/表情/图片/图表这些非文本的方式
  • gen_arrow_contour_xld
  • 智能时代的智能工具(gpt)国产化助手
  • 量子计算 | 解密著名量子算法Shor算法和Grover算法
  • 缓存组件状态,提升用户体验:探索 keep-alive 的神奇世界
  • 万字长文 - Python 日志记录器logging 百科全书 - 高级配置之 日志文件配置
  • ​LeetCode解法汇总1410. HTML 实体解析器
  • OpenGL 绘制旋转球(Qt)
  • 解决:javax.websocket.server.ServerContainer not available 报错问题
  • 81基于matlab GUI的图像处理
  • 虚拟机系列:vmware和Oracle VM VirtualBox虚拟机的区别,简述哪一个更适合我?以及相互转换
  • Go lumberjack 日志轮换和管理
  • git常用命令(git github ssh)
  • 完美解决:Nginx访问PHP出现File not found.
  • 音视频5、libavformat-2
  • python opencv -模板匹配
  • 大数据技能大赛(高职组)答案
  • C++动态规划算法:最多可以参加的会议数目