当前位置: 首页 > news >正文

深度学习——划分自定义数据集

深度学习——划分自定义数据集

以人脸表情数据集raf_db为例,初始目录如下:
在这里插入图片描述
需要经过处理后返回

train_images, train_label, val_images, val_label

定义 read_split_data(root: str, val_rate: float = 0.2) 方法来解决,代码如下:

# root:数据集所在路径
# val_rate:划分测试集的比例def read_split_data(root: str, val_rate: float = 0.2):random.seed(0)  # 保证随机结果可复现assert os.path.exists(root), "dataset root: {} does not exist.".format(root)# 遍历文件夹,一个文件夹对应一个类别file_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]# 排序,保证各平台顺序一致file_class.sort()# 生成类别名称以及对应的数字索引class_indices = dict((k, v) for v, k in enumerate(file_class))json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)train_images = []  # 存储训练集的所有图片路径train_label = []  # 存储训练集图片对应索引信息val_images = []  # 存储验证集的所有图片路径val_label = []  # 存储验证集图片对应索引信息every_class_num = []  # 存储每个类别的样本总数supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型# 遍历每个文件夹下的文件for cla in file_class:cla_path = os.path.join(root, cla)# 遍历获取supported支持的所有文件路径images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)if os.path.splitext(i)[-1] in supported]# 排序,保证各平台顺序一致images.sort()# 获取该类别对应的索引image_class = class_indices[cla]# 记录该类别的样本数量every_class_num.append(len(images))# 按比例随机采样验证样本val_path = random.sample(images, k=int(len(images) * val_rate))for img_path in images:if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集val_images.append(img_path)val_label.append(image_class)else:  # 否则存入训练集train_images.append(img_path)train_label.append(image_class)print("{} images were found in the dataset.".format(sum(every_class_num)))print("{} images for training.".format(len(train_images)))print("{} images for validation.".format(len(val_images)))assert len(train_images) > 0, "number of training images must greater than 0."assert len(val_images) > 0, "number of validation images must greater than 0."return train_images, train_label, val_images, val_label

此时可通过以下代码获得训练集和测试集数据:

train_images, train_label, val_images, val_label = read_split_data(data_path)

完结撒花。

http://www.lryc.cn/news/106238.html

相关文章:

  • Jmeter性能测试之正则表达式提取器
  • 浅谈Kubernetes中Service网络实现(服务发现)
  • 【重造轮子】golang实现可重入锁
  • torch显存分析——对生成模型清除显存
  • electron+vue+ts窗口间通信
  • 基于Fringe-Projection环形投影技术的人脸三维形状提取算法matlab仿真
  • 如何使用Webman框架实现多语言支持和国际化功能?
  • 接受平庸,特别是程序员
  • HTML兼容性
  • Java日期和时间处理入门指南
  • anndata k折交叉
  • 深入解析项目管理中的用户流程图
  • Vue使用QrcodeVue生成二维码并下载
  • “用户登录”测试用例总结
  • 适应于Linux系统的三种安装包格式 .tar.gz、.deb、rpm
  • Linux lvs负载均衡
  • Tomcat 创建https
  • 超导电性的基本现象和相关理论
  • 在 PHP 中单引号(‘ ‘)和双引号(“ “)用法的区别
  • SpringCloudAlibaba:服务网关之Gateway的cors跨域问题
  • react中的高阶组件理解与使用
  • “从零开始学习Spring Boot:构建高效的Java应用程序“
  • 容器部署jenkins定时构建于本地时间不一致
  • 生成指定网段的IP字典自动化脚本
  • Java版工程行业管理系统源码-专业的工程管理软件- 工程项目各模块及其功能点清单 em
  • 《向量数据库指南》——大模型时代,为什么向量数据库成为标配?
  • Pytorch个人学习记录总结 10
  • 18款奔驰S320升级后排座椅加热功能,提升后排乘坐舒适性
  • Vue中的插值表达式
  • 背包问题(模板)