当前位置: 首页 > news >正文

.npy格式图像如何进行深度学习模型训练处理,亲测可行

import torchimport torch.nn as nnimport torch.nn.functional as Fimport numpy as npfrom torch.utils.data import DataLoader, Datasetfrom torchvision import transformsfrom PIL import Imageimport json# 加载训练集和测试集数据train_images = np.load('../dataset/train_image.npy')train_labels = np.load('../dataset/train_label_3.npy')test_images = np.load('../dataset/test_image.npy')test_labels = np.load('../dataset/test_label_3.npy')# 将one-hot编码的标签转换为整数索引train_labels = np.argmax(train_labels, axis=1)test_labels = np.argmax(test_labels, axis=1)# 确保图像数据是 uint8 类型train_images = (train_images * 255).astype(np.uint8)test_images = (test_images * 255).astype(np.uint8)# 定义数据集类class NumpyToPIL(object):def __call__(self, sample):return Image.fromarray(sample)class CustomImageDataset(Dataset):def __init__(self, images, labels, transform=None):self.images = imagesself.labels = labelsself.transform = transformdef __len__(self):return len(self.images)def __getitem__(self, idx):image = self.images[idx]label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 数据预处理和增强transform_train = transforms.Compose([NumpyToPIL(),transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])transform_test = transforms.Compose([NumpyToPIL(),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])# 创建数据集和数据加载器#BATCH_SIZE = 32dataset_train = CustomImageDataset(train_images, train_labels, transform=transform_train)dataset_test = CustomImageDataset(test_images, test_labels, transform=transform_test)train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, drop_last=True)test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)# 检查标签格式train_labels = train_labels.ravel()test_labels = test_labels.ravel()# 检查标签的唯一值,生成类别索引映射train_class_to_idx = {str(i): i for i in set(train_labels.tolist())}test_class_to_idx = {str(i): i for i in set(test_labels.tolist())}with open('train_class.txt', 'w') as file:file.write(str(train_class_to_idx))with open('train_class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(train_class_to_idx))with open('test_class.txt', 'w') as file:file.write(str(test_class_to_idx))with open('test_class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(test_class_to_idx))
http://www.lryc.cn/news/390554.html

相关文章:

  • XFeat快速图像特征匹配算法
  • 普元EOS学习笔记-低开实现图书的增删改查
  • 动态住宅代理IP详细解析
  • 等保2.0 实施方案之信息软件验证要求
  • 【LeetCode的使用方法】
  • 【SGX系列教程】(二)第一个 SGX 程序: HelloWorld,linux下运行
  • 网页报错dns_probe_possible 怎么办?——错误代码有效修复
  • Vue.js 中属性绑定的详细解析:冒号 `:` 和非冒号的区别
  • 使用Java实现智能物流管理系统
  • 深圳技术大学oj C : 生成r子集
  • 不同操作系统下的换行符
  • Transformation(转换)开发-switch/case组件
  • Android Gradle 开发与应用 (二): Android 项目结构与构建配置
  • 02:vim的使用和权限管控
  • GNeRF代码复现
  • EXCEL返回未使用数组元素(未使用值)
  • 系统调用简单介绍
  • Mac可以读取NTFS吗 Mac NTFS软件哪个好 mac ntfs读写工具免费
  • AI是否能够做决定
  • 【Excel操作】Python Pandas判断Excel单元格中数值是否为空
  • C# Opacity 不透明度
  • 推荐三款常用接口测试工具!
  • 【Qt】Qt多线程编程指南:提升应用性能与用户体验
  • PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解
  • C++Primer Plus 第十四章代码重用:编程练习,第4题
  • 01 Docker 概述
  • c++的const
  • Git不想跟踪某个文件
  • DB-GPT 文档切分报错
  • #如何使用 Qt 5.6 在 Android 上启用 NFC