当前位置: 首页 > news >正文

Pytorch个人学习记录总结 10

目录

优化器


优化器

官方文档地址:torch.optimicon-default.png?t=N6B9https://pytorch.org/docs/stable/optim.html 

Debug过程中查看的grad所在的位置:

model --> Protected Atributes --> _modules --> ‘model’ --> Protected Atributes --> _modules --> ‘0’(任选一个conv层) --> weight(查看weight下的data和grad的变化)

 简易训练代码,添加了Loss、Optim。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transformsdataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.model = Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),MaxPool2d(kernel_size=2, stride=2),Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),MaxPool2d(kernel_size=2, stride=2),Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),MaxPool2d(kernel_size=2, stride=2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):  # 模型前向传播return self.model(x)model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数
optim = torch.optim.SGD(model.parameters(), lr=0.01)  # lr不能过大或者过小。刚开始的lr可设置得较大一点,后面再对lr进行调节
len = len(dataloader)for epoch in range(20):total_loss = 0.0for imgs, targets in dataloader:outputs = model(imgs)res_loss = loss_cross(outputs, targets)optim.zero_grad()  # 优化器对model中的每一个参数进行梯度清零res_loss.backward()  # 损失反向传播optim.step()  # 对model参数开始调优total_loss += res_lossprint('epoch:{}\ttotal_loss:{}\tmean_loss:{}.'.format(epoch, total_loss, total_loss / len))
# epoch:0	total_loss:9374.806640625	mean_loss:1.8749613761901855.
# epoch:1	total_loss:7721.240234375	mean_loss:1.544248104095459.
# epoch:2	total_loss:6830.775390625	mean_loss:1.3661550283432007.

http://www.lryc.cn/news/106211.html

相关文章:

  • 18款奔驰S320升级后排座椅加热功能,提升后排乘坐舒适性
  • Vue中的插值表达式
  • 背包问题(模板)
  • docker容器创建私有仓库(第三篇)
  • Eureka 学习笔记4:客户端 DiscoveryClient
  • 【方法】PDF可以转换成Word文档吗?如何操作?
  • AlphaControls crack
  • 论文笔记——Influence Maximization in Undirected Networks
  • Stable Diffusion - SDXL 1.0 全部样式设计与艺术家风格的配置与提示词
  • Hbase pe 压测 OOM问题解决
  • 问题解决——datagrip远程连接虚拟机中ubuntu的mysql失败
  • 【晚风摇叶之随机密码生成器】随机生成密码
  • Spring Cache
  • em3288 linux_4.19 sd卡调试
  • 前端vue uni-app cc-countdown倒计时组件
  • fifo读写的数据个数
  • Java之Map接口
  • windows系统中的命令行可以用python,pip等命令(已在系统中添加过python环境变量),但是pycharm的terminal中无法使用。
  • 编译 OneFlow 模型
  • 【kubernetes】k8s单master集群环境搭建及kuboard部署
  • 0802|IO进程线程 day5 进程概念
  • 4 Promethues监控主机和容器
  • 亚马逊买家账号ip关联怎么处理
  • NO4 实验四:生成Web工程
  • 【linux】进程
  • 电商高并发设计之SpringBoot整合Redis实现布隆过滤器
  • SpringBoot第25讲:SpringBoot集成MySQL - MyBatis 注解方式
  • 服务器返回 413 Request Entity Too Large
  • 如何一目了然地监控远程 Linux 系统
  • 9.环境对象和回调函数