当前位置: 首页 > news >正文

pytorch神经网络训练(AlexNet)

  • 导包
import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms
  • 定义自定义图像数据集
class CustomImageDataset(Dataset): 

定义一个自定义的图像数据集类,继承自Dataset

def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

        self.transform = transform

 图像转换方法,用于对图像进行预处理

        self.files = [] 

存储所有图像文件的路径

        self.labels = [] 

存储所有图像的标签

        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

        for index, label in enumerate(os.listdir(main_dir)):

 遍历主目录中的所有子目录

 

          self.label_to_index[label] = index label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

           if os.path.isdir(label_dir): for file in os.listdir(label_dir): self.files.append(os.path.join(label_dir, file))self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

定义数据集的长度

        return len(self.files) 

返回文件列表的长度

def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

        image = Image.open(self.files[idx]) label = self.labels[idx] if self.transform: image = self.transform(image) return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换
transform = transforms.Compose([transforms.Resize((227, 227)),  # AlexNet的输入图像大小transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化])

  • 创建数据集
dataset = CustomImageDataset(main_dir="D:\\图像处理、深度学习\\flowers", transform=transform)
  • 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 加载预训练的AlexNet模型
alexnet_model = models.alexnet(pretrained=True)
  • 修改最后几层以适应新的分类任务
num_ftrs = alexnet_model.classifier[6].in_featuresalexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))
  • 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)
  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")alexnet_model.to(device)                                                               

  • 模型评估
def evaluate_model(model, data_loader, device):model.eval()  # 将模型设置为评估模式correct = 0total = 0with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度for images, labels in data_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalreturn accuracy
  • 训练模型
num_epochs = 10for epoch in range(num_epochs):alexnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)

前向传播

        outputs = alexnet_model(images)loss = criterion(outputs, labels)

反向传播和优化

        optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()

在每个epoch结束后评估模型

    train_accuracy = evaluate_model(alexnet_model, data_loader, device)print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

http://www.lryc.cn/news/372282.html

相关文章:

  • 构建大语言模型友好型网站
  • Git代码冲突原理与三路合并算法
  • 聆思CSK6大模型开发板英语评测类开源SDK详解
  • 通用大模型VS垂直大模型,你更青睐哪一方?
  • Python第二语言(十四、高阶基础)
  • python脚本之调用其他目录脚本
  • C# 事件(Event)定义及其使用
  • 2.负载压力测试
  • 【AI工具】jupyter notebook和jupyterlab对比和安装
  • Linux 基本指令3
  • 在Linux系统中,可以使用OpenSSL来生成CSR(Certificate Signing Request)、PEM格式的公钥和PEM格式的私钥。
  • 【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 团队派遣(100分) - 三语言AC题解(Python/Java/Cpp)
  • Python数据分析与机器学习在医疗诊断中的应用
  • vite.config.js如何使用env的环境变量
  • MySql几十万条数据,同时新增或者修改
  • 如何提高MySQL DELETE 速度
  • 本地Zabbix开源监控系统安装内网穿透实现远程访问详细教程
  • 从Android刷机包提取System和Framework
  • 分布式光纤测温DTS与红外热成像系统的主要区别是什么?
  • python数据分析-问卷数据分析(地理课)
  • 【ARM64 常见汇编指令学习 19.3 -- ARMv8 三目运算指令 csel 详细介绍】
  • Docker 安装部署(CentOS 8)
  • Python自动化
  • 自然语言处理领域的重大挑战:解码器 Transformer 的局限性
  • 【机器学习】机器学习赋能医疗健康:从诊断到治疗的智能化革命
  • Elasticsearch6.7版本,内网中其他电脑无法连接
  • 交友系统定制版源码 相亲交友小程序源码全开源可二开 打造独特的社交交友系统
  • 数据结构笔记39-48
  • 2-3 基于matlab的NSCT-PCNN融合和创新算法(NSCT-ML-PCNN )图像融合
  • 机器学习笔记 - LoRA:大型语言模型的低秩适应