当前位置: 首页 > article >正文

PyTorch——优化器(9)

优化器根据梯度调整参数,以达到降低误差

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)# 定义卷积神经网络模型
class TY(nn.Module):def __init__(self):super(TY, self).__init__()# 构建网络结构:3个卷积层+池化层组合,2个全连接层self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5MaxPool2d(2),                   # 最大池化,步长2Conv2d(32, 32, 5, padding=2),   # 第二层卷积MaxPool2d(2),                   # 第二次池化Conv2d(32, 64, 5, padding=2),   # 第三层卷积MaxPool2d(2),                   # 第三次池化Flatten(),                      # 将多维张量展平为向量Linear(1024, 64),               # 全连接层,输入1024维,输出64维Linear(64, 10),                 # 输出层,10个类别对应10个输出)def forward(self, x):# 定义前向传播路径x = self.model1(x)return x# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)# 训练20个完整轮次
for epoch in range(20):running_loss = 0.0  # 初始化本轮累计损失# 遍历数据加载器中的每个批次for data in dataloader:imgs, targets = data  # 获取图像和标签outputs = ty(imgs)    # 前向传播result_loss = loss(outputs, targets)  # 计算损失optim.zero_grad()     # 梯度清零,防止累积result_loss.backward()  # 反向传播计算梯度optim.step()          # 更新模型参数running_loss += result_loss  # 累加损失值# 打印本轮训练的累计损失print(f"Epoch {epoch+1}, Loss: {running_loss}")

http://www.lryc.cn/news/2399314.html

相关文章:

  • 07 APP 自动化- appium+pytest+allure框架封装
  • Postgresql常规SQL语句操作
  • 智能合约安全漏洞解析:从 Reentrancy 到 Integer Overflow
  • 英国2025年战略防御评估报告:网络与电磁域成现代战争核心
  • 基于QPSK调制解调+Polar编译码(SCL译码)的matlab性能仿真,并对比BPSK
  • go语言学习 第5章:函数
  • Qt Quick快速入门笔记
  • 《波段操盘实战技法》速读笔记
  • Glide NoResultEncoderAvailableException异常解决
  • 工厂模式与多态结合
  • 无人机巡检智能边缘计算终端技术方案‌‌——基于EFISH-SCB-RK3588工控机/SAIL-RK3588核心板的国产化替代方案‌
  • 相机--相机成像原理和基础概念
  • 2025-0604学习记录17——文献阅读与分享(2)
  • 图解浏览器多进程渲染:从DNS到GPU合成的完整旅程
  • 【计算机网络】第3章:传输层—TCP 拥塞控制
  • idea不识别lombok---实体类报没有getter方法
  • 【Hive入门】
  • 亚马逊站内信规则2025年重大更新:避坑指南与合规策略
  • 01 - AI 时代的操作系统课 [2025 南京大学操作系统原理]
  • 数组1 day7
  • SAP学习笔记 - 开发15 - 前端Fiori开发 Boostrap,Controls,MVC(Model,View,Controller),Modules
  • Redis中的过期策略与内存淘汰策略
  • 基于SDN环境下的DDoS异常攻击的检测与缓解
  • HarmonyOS 实战:给笔记应用加防截图水印
  • 如何轻松地将文件从 PC 传输到 iPhone?
  • 前端面试二之运算符与表达式
  • 【运维实战】使用Nvm配置多Node.js环境!
  • Bresenham算法
  • 【从GEO数据库批量下载数据】
  • day 44