当前位置: 首页 > news >正文

丹摩 | 基于PyTorch的CIFAR-10图像分类实现


从创建实例开始的新项目流程

第一步:创建实例

  1. 登录 DAMODEL 平台。
  2. 创建一个 GPU 实例:
    在这里插入图片描述
    • GPU 配置:选择 NVIDIA H800 或其他可用高性能 GPU。
      在这里插入图片描述

    • 系统配置:推荐使用 Ubuntu 20.04,内存 16GB,硬盘 50GB。

    • 启动实例后,获取实例的 IP 地址。

    • 选择镜像
      在这里插入图片描述


第二步:连接实例

在这里插入图片描述

  1. 登录成功后,你会进入实例的终端界面。
    在这里插入图片描述
    在这里插入图片描述

第三步:更新系统和安装基础工具

  1. 更新系统:

    sudo apt update && sudo apt upgrade -y
    
  2. 安装 Python 和基础工具:

    sudo apt install python3 python3-pip git -y
    
  3. (可选)安装文本编辑器:

    sudo apt install vim nano -y
    

第四步:创建项目目录并配置环境

  1. 创建项目目录:

    mkdir ~/workspace/cifar10_project
    cd ~/workspace/cifar10_project
    
  2. 创建并激活虚拟环境:

    python3 -m venv venv
    source venv/bin/activate
    

    在这里插入图片描述
    前面出现venu则表示已经激活虚拟环境了

  3. 安装必要的 Python 包:

    pip install torch torchvision matplotlib
    

在这里插入图片描述

第五步:下载数据并初始化项目代码

  1. 创建 Python 脚本:

    vim train_cifar10.py
    
  2. 在文件中输入以下代码,加载 CIFAR-10 数据集并定义简单模型:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import torch.nn as nn
    import torch.optim as optim# 数据预处理
    transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])# 加载 CIFAR-10 数据集
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)# 定义简单卷积神经网络
    class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32 * 16 * 16, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = x.view(-1, 32 * 16 * 16)x = self.fc1(x)return x# 初始化模型、损失函数和优化器
    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)# 模型训练
    for epoch in range(5):  # 训练 5 个周期running_loss = 0.0for inputs, labels in trainloader:optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader)}")print("Finished Training")
    
  3. 保存并退出(按下 Esc,然后输入 :wq)。


第六步:运行训练脚本

运行脚本进行模型训练:

python train_cifar10.py
  • 脚本会下载 CIFAR-10 数据集并训练模型。
  • 训练完成后会输出每个 epoch 的损失值。
    在这里插入图片描述

第七步:保存和测试模型

  1. 保存模型:在脚本末尾添加代码以保存训练好的模型:

    torch.save(net.state_dict(), "cifar10_model.pth")
    print("Model saved as cifar10_model.pth")
    
  2. 重新运行脚本以保存模型:

    python train_cifar10.py
    
  3. 检查是否生成了 cifar10_model.pth 文件:

    ls
    
  4. 测试模型(可选):加载保存的模型并在测试集上评估准确率:

    net.load_state_dict(torch.load("cifar10_model.pth"))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():for inputs, labels in testloader:outputs = net(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy on test dataset: {100 * correct / total}%")
    

第八步:清理和扩展

  1. 扩展功能

    • 使用更复杂的模型(如 ResNet)。
    • 尝试使用 Adam 优化器提高性能。
    • 可视化训练过程或模型预测结果。
  2. 清理资源

    • 如果完成训练并不再需要 GPU 计算,记得停止或删除实例以节省费用。

\

http://www.lryc.cn/news/490892.html

相关文章:

  • C#变量和函数如何和unity组件绑定
  • AI模型---安装cuda与cuDNN
  • 【大数据学习 | Spark-Core】Spark提交及运行流程
  • 内网渗透横向移动1
  • 现代密码学
  • Pod 动态分配存储空间实现持久化存储
  • Jackson、Gson、FastJSON三款JSON利器比拼
  • php:nginx如何配置WebSocket代理?
  • 3349、检测相邻递增子数组 Ⅰ
  • C++笔记之函数入参传递std::unique_ptr 时使用 std::move的场景
  • 怎么只提取视频中的声音?从视频中提取纯音频技巧
  • 数仓工具—Hive语法之窗口函数中的 case when
  • 基于微信小程序的酒店客房管理系统+LW示例参考
  • Elasticsearch客户端在和集群连接时,如何选择特定的节点执行请求的?
  • 【AI最前线】DP双像素sensor相关的AI算法全集:深度估计、图像去模糊去雨去雾恢复、图像重建、自动对焦
  • CTF之密码学(Polybius密码)
  • 【C++篇】从售票窗口到算法核心:C++队列模拟全解析
  • clipboard
  • 【Mac】VMware Fusion Pro 安装 CentOS 7
  • 游戏引擎学习第22天
  • 洛谷 B2038:奇偶 ASCII 值判断
  • APIRouter
  • 算法模板2:位运算+离散化+区间合并
  • 钉钉授权登录
  • 【视频】二维码识别:libzbar-dev、zbar-tools(zbarimg )
  • C语言中的结构体,指针,联合体的使用
  • 基于卡尔曼滤波器的 PID 控制
  • CVE-2022-26201
  • 海信Java后端开发面试题及参考答案
  • 传智杯 3-初赛:终端