当前位置: 首页 > news >正文

01 Pytorch 基础

paddle不需要放数据到gpu!

区别:1.batch_norlization 不同

            2. 

1.数据处理

1.取一个数据,以及计算大小

        (剩下的工作,取batch,pytorch会自动做好了)

2.模型相关 

如何得到结果

3.模型训练/模型验证: 

代码剖析 

1.配置文件yaml (字典)
#参数配置config = {"train_path":'/kaggle/input/deepshare-playground/train_behaviour.csv',"test_path":'/kaggle/input/deepshare-playground/test_behaviour.csv',"debug_mode" : False,"epoch" : 20,"batch" : 2048,"lr" : 0.001,"device" : 0,
}

使用: config[ '名称' ]

train_df = pd.read_csv(config['train_path'])
if config['debug_mode']:train_df = train_df[:1000]
test_df = pd.read_csv(config['test_path'])
 2.处理数据:定义DataSet

关键:len + getitem(获取单独的一个)

#Dataset构造
class BaseDataset(Dataset):def __init__(self,df):self.df = dfself.feature_name = ['user_id','item_id']#数据编码self.enc_data()def enc_data(self):#使用enc_dict对数据进行编码self.enc_df = copy.deepcopy(self.df)for col in self.feature_name:self.enc_df[col] = torch.Tensor(np.array(self.df[col])).long()def __getitem__(self, index):data = dict()for col in self.feature_name:data[col] = torch.Tensor([self.enc_df[col].iloc[index]]).long().squeeze(-1)if 'label' in self.enc_df.columns:data['label'] = torch.Tensor([self.enc_df['label'].iloc[index]]).squeeze(-1)return datadef __len__(self):return len(self.df)
3.模型定义
4.训练与验证

完成Train Pipeline/Valid Pipeline

 4.1 拷贝数据->gpu

4.2前向传输

4.3

4.4 指标计算

http://www.lryc.cn/news/373927.html

相关文章:

  • STL——set、map、multiset、multimap的介绍及使用
  • 使用C语言,写一个类似Linux中执行cat命令的类似功能
  • 【Android】Android系统性学习——Android系统架构
  • 鸿蒙应用开发
  • 索引失效有效的11种情况
  • 字符数组基础知识及题目
  • 一个简单的玩具机器人代码
  • 设计模式-装饰器模式Decorator(结构型)
  • RK3588开发板中使用Qt对zip文件进行解压
  • 三、网络服务协议
  • C++初学者指南第一步---1. C++开发环境设置
  • 二维数组与指针【C语言】
  • 解决linux下安装apex库报错:ModuleNotFoundError: No module named ‘packaging‘
  • React基础教程(07):条件渲染
  • 回归预测 | Matlab实现NGO-HKELM北方苍鹰算法优化混合核极限学习机多变量回归预测
  • 操作系统——信号
  • 力扣1482.制作m束花所需的最少时间
  • 解决 Linux 和 Java 1.8 中上传中文名称图片报错问题
  • cocos开发的时候 wx.onShow在vscode里面显示红色
  • 使用 PNPM 从零搭建 Monorepo,测试组件并发布
  • Oracle 19C 数据库表被误删除的模拟恢复
  • 【CICID】GitHub-Actions语法
  • Ionic 创建 APP
  • 【数学代码】幂
  • os.system() 函数
  • Spring Boot中的RESTful API详细介绍及使用
  • nlp学习笔记
  • 使用python获取内存信息
  • 外包公司泛滥,这些常识你应该提前知道?
  • Linux下的抓包工具使用介绍