当前位置: 首页 > news >正文

【python】pytorch包(第二章)API使用与介绍

1> nn.Module (用于构建模型的底层逻辑)

介绍

nn.Module 是 torch.nn 中的一个类,是pytorch中自定义网络的基类

  1. __init__需要调用super方法,继承父类属性和方法
  2. forward方法必须实现,用来定义网络的向前计算的过程

例:y = w*x + b 的拟合模型

构建

from torch import nn
class Lr(nn.Module): #构建模型逻辑def __init__(self): #定义该层super(Lr,self).__init__() #继承父类的init参数self.linear = nn.Linear( aa , bb ) #该层网络的输入数据的维度为aa,输出数据的维度为bbdef forward(self,x): #即 如何由输入的数据x得到输出的结果outout = self.linear(x)return out

使用

#实例化模型
model = Lr()
#传入数据,计算结果
pred_y = model(x)

2> 优化器类 optimizer

介绍

优化器是torch为我们封装的用来更新参数的方法

设定优化器

  1. torch.optim.SGD(参数, lr=学习率)
    SGD(stochastic gradient descent, 随机梯度下降)
    ”参数“指: 模型中需要被更新的参数;
    ”参数“一般用model.parameters()函数来获取,会获取所有requires_grad=True的参数
    ”学习率“:默认为0.001

  2. torch.optim.Adam(参数, lr=学习率)

使用优化器

1. 步骤:
step 1. 优化器实例化
step 2. 将所有参数的梯度的值,初始化为0
step 3. 反向传播,更新梯度的值
step 4. 参数值被更新
2. 代码样例:

import optim from torch
#step 1. 优化器实例化
optimizer = optim.SGD(model.parameters(),lr=1e-3)
#待更新参数为model.parameters()
#学习率learning rate = 1e-3
#step 2. 将所有参数的梯度的值,初始化为0
optimizer.zero_grad() #参数归零函数
#step 3. 反向传播,更新梯度的值
loss.backward()
#step 4. 更新参数值
optimizer.step()

优化器的算法介绍

1> 梯度下降法

(1) BGD 梯度下降法 (batch gradient descent)

每次迭代都将所有样本送入,将全局样本的均值作为参考。
简称为:全局优化
缺点: 每次都要跑全部样本,速度慢

(2) SGD 随机梯度下降法(Stochastic gradient descent)

每次从所有样本中,随机抽取一个样本进行学习
优点: 解决了BGD算法 速度慢的问题
缺点: 可能被某个单个异常数据点影响
Python的torch包中的API调用方法: torch.optim.SGD()

(3) MBGD 小批量梯度下降法(Mini-batch gradient descent)

介于(1)和(2)之间的算法,每次选取一组样本进行学习

梯度下降法的劣势:

过于依赖于合适的学习率
学习率较小时,会导致收敛速度慢;
学习率较大时,会导致有可能跳过最优解,在最值点左右摆动幅度较大

2> AdaGrad

采取动态调整学习率的方法,解决梯度下降法的劣势
【个人理解:就是把 爬山算法 换成了 模拟退火算法

3> 动量法 和 RMSProp算法

采取动态调整梯度的移动指数,解决梯度下降法的劣势
【个人理解:也是把 爬山算法 换成了 模拟退火算法

4> Adam算法

相当于 AdaGrad法 和 RMSProp法 的结合
优势 更快达到最优解
劣势 有可能学习得更慢(因为最优解很难找到,而前面的算法不一定会找到最优解,而是误差较大的最优解)
Python的torch包中的API调用方法: torch.optim.Adam()

这下就可以看懂第一章的线性回归代码的意思是什么了

http://www.lryc.cn/news/70069.html

相关文章:

  • Linux驱动基础(SR501人体感应模块)
  • Android Studio Flamingo (火烈鸟) 升级踩坑记录
  • 【JAVA凝气】异常篇
  • C++中的函数模板
  • MapReduce【Shuffle-Combiner】
  • postman接口自动化测试
  • 历经70+场面试,我发现了大厂面试的套路都是···
  • 可视区域兼容性问题的思考及方法封装
  • 安全工具 | CMSeeK [指纹识别]
  • Android新logcat使用技巧
  • 使用Makefile笔记总结
  • npm下载依赖项目跑不起来--解决方案
  • SolVES模型生态系统服务功能社会价值评估
  • Godot引擎 4.0 文档 - 入门介绍 - 学习新功能
  • 如何进行MySQL漏洞扫描
  • C语言函数大全-- x 开头的函数(3)
  • 计算机图形学-GAMES101-12阴影
  • iOS_Swift高阶函数
  • 探索Vue的组件世界-组件复用
  • OMA通道-2
  • SAP 用CO13冲销工序报工,但是没有产生货物移动(TCODE:CO1P 、 SE38 :CORUPROC,CORUAFWP)
  • 信息收集-服务器信息
  • 连续签到积分兑换试用流量主小程序开发
  • C语言—自定义类型(结构体、枚举、联合)
  • Node.js博客项目开发思路笔记
  • python 之 shutil 文件的复制、删除、移动文件以及目录,并支持文件的归档、压缩和解压
  • jface
  • 六级备考28天|CET-6|听力第一讲|基本做题步骤与方法|13:30~14:30
  • 系统设计 - 设计一个速率限制器
  • [技术分享]Android平台实时音视频录像模块设计之道