当前位置: 首页 > news >正文

学习pytorch15 优化器

优化器

  • 官网
  • 如何构造一个优化器
  • 优化器的step方法
  • code
  • running log
    • 出现下面问题如何做反向优化?

官网

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述
提问:优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题
优化模型参数

如何构造一个优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # momentum SGD优化算法用到的参数
optimizer = optim.Adam([var1, var2], lr=0.0001)
  1. 选择一个优化器算法,如上 SGD 或者 Adam
  2. 第一个参数 需要传入模型参数
  3. 第二个及后面的参数是优化器算法特定需要的,lr 学习率基本每个优化器算法都会用到

优化器的step方法

会利用模型的梯度,根据梯度每一轮更新参数
optimizer.zero_grad() # 必须做 把上一轮计算的梯度清零,否则模型会有问题

for input, target in dataset:optimizer.zero_grad()  # 必须做 把上一轮计算的梯度清零,否则模型会有问题output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

or 把模型梯度包装成方法再调用

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

code

import torch
import torchvision
from torch import nn, optim
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 定义loss
loss = nn.CrossEntropyLoss()
# 搭建网络
myseq = MySeq()
print(myseq)
# 定义优化器
optmizer = optim.SGD(myseq.parameters(), lr=0.001, momentum=0.9)
for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = data# print(imgs.shape)output = myseq(imgs)optmizer.zero_grad()  # 每轮训练将梯度初始化为0  上一次的梯度对本轮参数优化没有用result_loss = loss(output, targets)result_loss.backward()  # 优化器需要每个参数的梯度, 所以要在backward() 之后执行optmizer.step()  # 根据梯度对每个参数进行调优# print(result_loss)# print(result_loss.grad)# print("ok")running_loss += result_lossprint(running_loss)

running log

loss由小变大最后到nan的解决办法:

  1. 降低学习率
  2. 使用正则化技术
  3. 增加训练数据
  4. 检查网络架构和激活函数

出现下面问题如何做反向优化?

Files already downloaded and verified
MySeq((model1): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)
tensor(18622.4551, grad_fn=<AddBackward0>)
tensor(16121.4092, grad_fn=<AddBackward0>)
tensor(15442.6416, grad_fn=<AddBackward0>)
tensor(16387.4531, grad_fn=<AddBackward0>)
tensor(18351.6152, grad_fn=<AddBackward0>)
tensor(20915.9785, grad_fn=<AddBackward0>)
tensor(23081.5254, grad_fn=<AddBackward0>)
tensor(24841.8359, grad_fn=<AddBackward0>)
tensor(25401.1602, grad_fn=<AddBackward0>)
tensor(26187.4961, grad_fn=<AddBackward0>)
tensor(28283.8633, grad_fn=<AddBackward0>)
tensor(30156.9316, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
http://www.lryc.cn/news/222979.html

相关文章:

  • [算法日志]图论刷题 沉岛思想的运用
  • Web服务器的搭建
  • 如何使用 GTX750 或 1050 显卡安装 CUDA11+
  • 跟着森老师学React Hooks(1)——使用Vite构建React项目
  • 强力解决使用node版本管理工具 NVM 出现的问题(找不到 node,或者找不到 npm)
  • Docker指定容器使用内存
  • 做什么数据表格啊,要做就做数据可视化
  • CSS特效003:太阳、地球、月球的旋转
  • 云计算的大模型之争,亚马逊云科技落后了?
  • 【form校验】3.0项目多层list嵌套
  • 公共功能测试用例
  • 【电路笔记】-并联RLC电路分析
  • ros1 client
  • 射频功率放大器应用中GaN HEMT的表面电势模型
  • CSP(Common Spatial Patterns)——EEG特征提取方法详解
  • 【Git】Git 学习笔记_操作本地仓库
  • 杂记(3):在Pytorch中如何操作将数据集分为训练集和测试集?
  • 【MySQL篇】数据库角色
  • c++ 信奥赛编程 2050:【例5.20】字串包含
  • 用dbeaver创建一个enum类型,并讲述一部分,mysql的enum类型的知识
  • Paste v4.1.2(Mac剪切板)
  • 事件绑定-回调函数
  • Makefile 总述
  • 写给新用户-Mac软件指南篇:让你的Mac更好用
  • 03运算符综合
  • LeetCode刷题--思路总结记录
  • Nodejs
  • 【面经】spring,springboot,springcloud有什么区别和联系
  • SpringBoot Kafka消费者 多kafka配置
  • git 标签相关命令