当前位置: 首页 > news >正文

pytorch学习9-优化器学习

系列文章目录

  1. pytorch学习1-数据加载以及Tensorboard可视化工具
  2. pytorch学习2-Transforms主要方法使用
  3. pytorch学习3-torchvisin和Dataloader的使用
  4. pytorch学习4-简易卷积实现
  5. pytorch学习5-最大池化层的使用
  6. pytorch学习6-非线性变换(ReLU和sigmoid)
  7. pytorch学习7-序列模型搭建
  8. pytorch学习8-损失函数与反向传播
  9. pytorch学习9-优化器学习
  10. pytorch学习10-网络模型的保存和加载
  11. pytorch学习11-完整的模型训练过程

文章目录

  • 系列文章目录
  • 一、优化器使用
  • 总结


一、优化器使用

import torch.optim
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Sequential, Flatten, Linear
from torch.utils.data import DataLoaderdataset=torchvision.datasets.CIFAR10("data",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=1)class Mynn(nn.Module):#这是使用序列的方法:def __init__(self):super(Mynn,self).__init__()self.model1=Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x=self.model1(x)return xmynn=Mynn()
loss=nn.CrossEntropyLoss()
optim=torch.optim.SGD(mynn.parameters(),lr=0.01)#调用SGD优化器。第一个参数是把模型的参数全输入进去for epoch in range(20):runing_loss=0.0for data in dataloader:imgs,target=dataoutputs=mynn(imgs)result_loss=loss(outputs,target)optim.zero_grad()#将上一次的梯度设置为0,这一步必须做result_loss.backward()#反向传播,计算出模型的参数optim.step()#进行完反向传播之后,模型的参数就计算出来了,就可以调用优化器了runing_loss=runing_loss+result_lossprint(runing_loss)#查看每一轮的损失之和

总结

以上就是今天要讲的内容,优化器的使用

http://www.lryc.cn/news/253981.html

相关文章:

  • MySQL之锁
  • 今日现货黄金最新建议
  • 基于混沌算法的图像加密解密系统
  • vscode插件离线下载
  • 第二十一章总结
  • 查看端口占用并杀死进程
  • 前后端数据传输格式(上)
  • maven的package和install命令有什么区别以及Maven常用命令与GAV坐标与Maven依赖范围与Maven依赖传递与依赖排除与统一声明版本号
  • 【动手学深度学习】(六)权重衰退
  • 动手学习深度学习-跟李沐学AI-自学笔记(3)
  • 3.2 Puppet 和 Chef 的比较与应用
  • promise使用示例
  • 一起学docker系列之十四Dockerfile微服务实践
  • Qt Creator 11.0.3同时使用Qt6.5和Qt5.14.2
  • Python中字符串列表的相互转换详解
  • 09、pytest多种调用方式
  • 分布式锁常见实现方案
  • 26、pytest使用allure解读
  • Uncle Maker: (Time)Stamping Out The Competition in Ethereum
  • 浅谈可重入与线程安全
  • 深入理解TDD(测试驱动开发):提升代码质量的利器
  • pyqt5使用pyqtgraph实现动态热力图
  • 【android开发-16】android中文件和sharedpreferences数据存储详解
  • 《当代家庭教育》期刊论文投稿发表简介
  • 【操作教程】如何将外省医保转入广州市区(医保转移接续手续办理)?
  • 【分布式系统学习】CAP原理详解
  • 【聚类】K-modes和K-prototypes——适合离散数据的聚类方法
  • Python-炸弹人【附完整源码】
  • [英语学习][5][Word Power Made Easy]的精读与翻译优化
  • Apache Doris 详细教程(一)