当前位置: 首页 > news >正文

Pytorch学习--神经网络--优化器

一、头文件

torch.optim.Optimizer(params, defaults)
optim文档

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

二、代码

不带优化器的代码框架

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d, Conv2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("datasets",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return xYorelee = Mary()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(Yorelee.parameters(),lr=0.01)for epoch in range(20):total_loss = 0for data in dataloader:img,target = dataoutput = Yorelee(img)# print(output)# print(target)result_loss = loss(output,target)# print(result_loss)# print("***********************")optim.zero_grad()result_loss.backward()optim.step()total_loss += result_lossprint(total_loss)

输出:

tensor(18861.5215, grad_fn=<AddBackward0>)
tensor(16226.8633, grad_fn=<AddBackward0>)
tensor(15367.2148, grad_fn=<AddBackward0>)
http://www.lryc.cn/news/478432.html

相关文章:

  • w~自动驾驶合集11
  • 大数据新视界 -- 大数据大厂之 Impala 性能优化:解锁大数据分析的速度密码(上)(1/30)
  • GESP4级考试语法知识(算法概论(三))
  • x-cmd pkg | gum - 轻松构建美观实用的终端界面,解锁命令行新玩法
  • WMS系统打通仓储全链条数据势在必行,该如何做呢
  • 基于Python的校园爱心帮扶管理系统
  • 如何基于pdf2image实现pdf批量转换为图片
  • Tomcat(1) 什么是Tomcat?
  • 商务礼仪与职场沟通
  • C语言必做30道练习题
  • Linux信号_信号的产生
  • 数据库基础(7) . DML-基本操作
  • windows运行ffmpeg的脚本报错:av_ts2str、av_ts2timestr、av_err2str => E0029 C4576
  • [mysql]mysql的DML数据操作语言增删改,以及新特性计算列,阿里巴巴开发手册mysql相关
  • Github 2024-11-07 Go开源项目日报 Top10
  • 【黑盒测试】等价类划分法及实例
  • LeetCode17. 电话号码的字母组合(2024秋季每日一题 59)
  • SQLite数据库是什么?DB Browser for SQLite是什么?
  • 核心概念解析Caffeine 缓存模型与策略
  • ubuntu 22.04 防火墙
  • 【数据结构-合法括号字符串】力扣678. 有效的括号字符串
  • ThreadX在STM32上的移植:F1,F4通用启动文件tx_initialize_low_level.s
  • 【算法】递归+深搜:814.二叉树剪枝
  • spring Framework 特定条件下目录遍历漏洞(CVE-2024-38816)修复
  • ESP32-C3 入门笔记03:VScode + flash_download_tool 下载烧录程序(ESP-IDF + PlatformIO)
  • Node.js——fs模块-文件重命名和移动
  • vue2.0版本引入Element-ui问题解决
  • qt QTableView详解
  • 将Notepad++添加到右键菜单【一招实现】
  • Nature Methods | 基于流形约束的RNA速度推断精准解析细胞周期动态调节规律