当前位置: 首页 > news >正文

杂记(3):在Pytorch中如何操作将数据集分为训练集和测试集?

在Pytorch中如何操作将数据集分为训练集和测试集?

  • 0. 前言
  • 1. 手动切分
  • 2. train_test_split方法
  • 3. Pytorch自带方法
  • 4. 总结

0. 前言

数据集需要分为训练集和测试集! 其中,训练集单纯用来训练,优化模型参数;测试集单纯用来测试,评价模型效果。然而,如何将数据集分为训练集和测试集这个简单的问题网上的回答也是五花八门,明明有简单的方法,当然不想用麻烦的方法啦!因此,这里做一下简单记录!

1. 手动切分

这里所言的手动切分指的是:将数据集前面一部分分为训练集,后面一部分分为测试集。具体代码而言如下:

# 假设所有数据极为数组a 标签为b
train_X = a[:int(0.8*len(a))]
test_X = a[int(0.8*len(a)):]train_Y = b[:int(0.8*len(a))]
test_Y = b[int(0.8*len(a)):]train_dataset= Data.TensorDataset(torch.FloatTensor(train_X ), torch.FloatTensor(train_Y ))
test_dataset= Data.TensorDataset(torch.FloatTensor(test_X), torch.FloatTensor(test_Y))trainLoader = DataLoader(dataset = train_dataset,batch_size = 18,num_workers = 0,shuffle = True)
testLoader = DataLoader(dataset = test_dataset,batch_size = 18,num_workers = 0,shuffle = True)

2. train_test_split方法

使用机器学习中的 train_test_split 方法!在机器学习中切分数据集一般都用该方法,但是在Pytorch中还是需要进行转换后方可输入模型。

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(a, b, test_size=0.33, random_state=42)train_dataset= Data.TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
test_dataset= Data.TensorDataset(torch.FloatTensor(X_test), torch.FloatTensor(y_test ))trainLoader = DataLoader(dataset = train_dataset,batch_size = 18,num_workers = 0,shuffle = True)
testLoader = DataLoader(dataset = test_dataset,batch_size = 18,num_workers = 0,shuffle = True)

3. Pytorch自带方法

Pytorch中自带的有将数据集随机切分的方法 ( torch.utils.data.random_split ),不需要额外的操作!!!!具体代码如下:

import torch.utils.data as Datadataset = Data.TensorDataset(torch.FloatTensor(a), torch.FloatTensor(b))
batch_size = 16
# 将数据集分为训练集和测试集
trainLoader, testLodaer = Data.random_split(dataset,lengths=[int(0.9 * len(dataset)),len(dataset) - int(0.9 * len(dataset))],generator=torch.Generator().manual_seed(0))

4. 总结

到此,使用 在Pytorch中如何操作将数据集分为训练集和测试集已经介绍完毕了!!! 如果有什么问题欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。如果存在没有提及的方法也可以在评论区提出,后续会对其进行添加!!!!

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

http://www.lryc.cn/news/222962.html

相关文章:

  • 【MySQL篇】数据库角色
  • c++ 信奥赛编程 2050:【例5.20】字串包含
  • 用dbeaver创建一个enum类型,并讲述一部分,mysql的enum类型的知识
  • Paste v4.1.2(Mac剪切板)
  • 事件绑定-回调函数
  • Makefile 总述
  • 写给新用户-Mac软件指南篇:让你的Mac更好用
  • 03运算符综合
  • LeetCode刷题--思路总结记录
  • Nodejs
  • 【面经】spring,springboot,springcloud有什么区别和联系
  • SpringBoot Kafka消费者 多kafka配置
  • git 标签相关命令
  • 我在Vscode学OpenCV 图像运算(权重、逻辑运算、掩码、位分解、数字水印)
  • 【 Docker: 数据卷挂载】
  • windows上的静态链接和动态链接的区别与作用(笔记)
  • MySQL和Postgresql数据库备份和恢复
  • 使用MCU上的I2C总线进行传感器应用
  • 汽车标定技术(七)--基于模型开发如何生成完整的A2L文件(2)
  • ZZ308 物联网应用与服务赛题第E套
  • web相关框架
  • 安装dubbo-admin报错node版本和test错误
  • HTML使用canvas绘制海报(网络图片)
  • 20道高频JavaScript面试题快问快答
  • 【STM32】HAL库UART含校验位的串口通信配置BUG避坑
  • Python实用技巧:将 Excel转为PDF
  • 【面经】讲一下你对jvm和jmm的了解
  • 《网络协议》03. 传输层(TCP UDP)
  • ZooKeeper调优
  • 改进YOLOv5:结合ICCV2023|动态蛇形卷积,构建不规则目标识别网络