当前位置: 首页 > news >正文

创建自定义Dataset类与多分类问题实战

codes

文章目录

  • 🌟 6 多分类问题与卷积模型的优化
    • 🧩 6.1 创建自定义Dataset类
      • ⚠️ 数据集特点:
      • 🔑 关键实现步骤:
      • 🛠️ 自定义Dataset类实现
      • 📊 数据集划分与可视化
    • 🧠 6.2 基础卷积模型
      • 📐 网络结构设计
      • ⚙️ 训练配置
      • 🔁 训练与测试函数
      • 📈 模型训练与评估
      • 📉 结果可视化
    • 🚨 关键问题:过拟合现象

🌟 6 多分类问题与卷积模型的优化

数据集:Multi-class Weather Dataset
注:与第5章使用的数据集不同,本数据集为多分类任务且所有图片存储在同一文件夹

# 导入基础库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import glob  # 文件路径操作
from torchvision import transforms  # 数据预处理
from PIL import Image  # 图像处理
from torch.utils import data  # 数据集构建

🧩 6.1 创建自定义Dataset类

⚠️ 数据集特点:

  1. 多分类标签(非二分类)
  2. 所有图片存储在单一文件夹
  3. 未划分训练集/测试集

不能使用 torchvision.datasets.ImageFolder 加载该数据集
✅ 需通过继承 torch.utils.data.Dataset 实现自定义数据集类

🔑 关键实现步骤:

# 获取所有图片路径
imgs = glob.glob(r'D:/my_all_learning/dataset2/dataset2/*.jpg') 
print(imgs[:3])  # 查看前3个路径# 定义类别映射
species = ['cloudy','rain','shine','sunrise']  # 4个类别
species_to_idx = dict((c,i) for i,c in enumerate(species))  # 类别→索引
idx_to_species = dict((i,c) for i,c in enumerate(species))  # 索引→类别
print(species_to_idx)
print(idx_to_species)# 生成标签列表
labels = []
for img in imgs:for i,c in enumerate(species):if c in img:  # 根据路径名判断类别labels.append(i)
print(labels[:3])# 定义图像预处理
transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

🛠️ 自定义Dataset类实现

class WT_Dataset(data.Dataset):def __init__(self, imgs_path, labels):self.imgs_path = imgs_pathself.labels = labelsdef __len__(self):return len(self.imgs_path)def __getitem__(self, index):img_path = self.imgs_path[index]label = self.labels[index]pil_img = Image.open(img_path)pil_img = pil_img.convert('RGB')  # 确保RGB格式pil_img = transform(pil_img)  # 应用预处理return pil_img, label

📊 数据集划分与可视化

# 创建数据集实例
dataset = WT_Dataset(imgs, labels)
print(f"数据集总量: {len(dataset)}")# 划分训练集(80%)和测试集(20%)
train_count = int(0.8 * len(dataset))
test_count = len(dataset) - train_count
train_dataset, test_dataset = data.random_split(dataset, [train_count, test_count])
print(f"训练集: {len(train_dataset)}, 测试集: {len(test_dataset)}")# 创建DataLoader
BATCH_SIZE = 16
train_dl = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)# 可视化首个batch的图像
imgs_batch, labels_batch = next(iter(train_dl))
plt.figure(figsize=(12,8))
for i, (img, label) in enumerate(zip(imgs_batch[:6], labels_batch[:6])):img = (img.permute(1,2,0).numpy() + 1)/2  # 反归一化+通道重排plt.subplot(2,3,i+1)plt.title(idx_to_species.get(label.item()))  # 显示类别名plt.imshow(img)

🧠 6.2 基础卷积模型

📐 网络结构设计

class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 卷积层定义self.conv1 = nn.Conv2d(3, 16, 3)    # 输入通道3, 输出16, 卷积核3x3self.conv2 = nn.Conv2d(16, 32, 3)   # 通道数16→32self.conv3 = nn.Conv2d(32, 64, 3)   # 通道数32→64# 全连接层定义self.fc1 = nn.Linear(64*10*10, 1024)  # 展平后输入self.fc2 = nn.Linear(1024, 4)        # 输出4分类def forward(self, x):# [batch, 3, 96, 96] → 卷积1 → [batch, 16, 94, 94]x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)  # → [batch, 16, 47, 47]# → 卷积2 → [batch, 32, 45, 45]x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)  # → [batch, 32, 22, 22] (45/2取整)# → 卷积3 → [batch, 64, 20, 20]x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2)  # → [batch, 64, 10, 10]# 展平 → 全连接x = x.view(-1, 64*10*10)x = F.relu(self.fc1(x))  # → [batch, 1024]x = self.fc2(x)          # → [batch, 4]return x

⚙️ 训练配置

# 设备选择
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")# 模型初始化
model = Net().to(device)
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.0005)  # Adam优化器

🔁 训练与测试函数

def train(dataloader, model, loss_fn, optimizer):model.train()size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, correct = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 前向传播pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 统计指标correct += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_loss /= num_batchescorrect /= sizereturn train_loss, correctdef test(dataloader, model):model.eval()size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizereturn test_loss, correct

📈 模型训练与评估

epochs = 20
train_loss, train_acc = [], []
test_loss, test_acc = [], []for epoch in range(epochs):# 训练周期epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)# 测试周期epoch_test_loss, epoch_test_acc = test(test_dl, model)# 记录指标train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 打印日志template = ("epoch:{:2d}, train_loss:{:.5f}, train_acc:{:.1f}%, ""test_loss:{:.5f}, test_acc:{:.1f}%")print(template.format(epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))print("训练完成!")

📉 结果可视化

# 损失曲线
plt.plot(range(1, epochs+1), train_loss, label='train_loss')
plt.plot(range(1, epochs+1), test_loss, label='test_loss')
plt.legend()
plt.title("训练与测试损失对比")
plt.show()# 准确率曲线
plt.plot(range(1, epochs+1), train_acc, label='train_acc')
plt.plot(range(1, epochs+1), test_acc, label='test_acc')
plt.legend()
plt.title("训练与测试准确率对比")
plt.show()

🚨 关键问题:过拟合现象

下一讲将介绍卷积网络优化技术(Dropout、BN、学习率衰减)提升泛化能力


关键词:多分类 卷积神经网络 自定义数据集 过拟合 PyTorch

http://www.lryc.cn/news/587492.html

相关文章:

  • 怎么解决数据库幻读问题
  • 【图片识别改名】水印相机拍的照片如何将照片的名字批量改为水印内容?图片识别改名的详细步骤和注意事项
  • 设计模式笔记_结构型_桥接模式
  • vscode 安装 esp ide环境
  • 基于MATLAB的LSTM长短期记忆神经网络的数据回归预测方法应用
  • 02 51单片机之LED闪烁
  • 前端同学,你能不能别再往后端传一个巨大的JSON了?
  • 构建完整工具链:GCC/G++ + Makefile + Git 自动化开发流程
  • 前端接入海康威视摄像头的三种方案
  • autoware激光雷达和相机标定
  • JAVA 设计模式 工厂
  • Docker搭建Redis分片集群
  • 鸿蒙应用开发: 鸿蒙项目中使用私有 npm 插件的完整流程
  • Kotlin集合接口
  • 常用的OTP语音芯片有哪些?
  • 前端性能与可靠性工程系列: 渲染、缓存与关键路径优化
  • Spring Boot - Spring Boot 集成 MyBatis 分页实现 PageHelper
  • 【React Native】环境变量和封装 fetch
  • 智源:LLM指令数据建设框架
  • VR样板间:房产营销新变革
  • Cesium 9 ,Cesium 离线地图本地实现与服务器部署( Vue + Cesium 多项目共享离线地图切片部署实践 )
  • 谷歌开源库gtest 框架安装与使用
  • VR全景制作流程?什么是全景?
  • ELK、Loki、Kafka 三种日志告警联动方案全解析(附实战 Demo)
  • EVOLVEpro安装使用教程-蛋白质语言模型驱动的快速定向进化
  • 快速搭建Maven仓库服务
  • 前端面试十二之vue3基础
  • 从代码学习深度强化学习 - DDPG PyTorch版
  • CCPD 车牌数据集提取标注,并转为标准 YOLO 格式
  • MySQL 分表功能应用场景实现全方位详解与示例