当前位置: 首页 > news >正文

深度学习-03 Pytorch

损失函数是用来衡量模型预测结果与真实值之间的差异,并用来优化模型的指标。在机器学习和神经网络中,常用的损失函数包括均方误差(Mean Squared Error,MSE)、交叉熵(Cross-Entropy)等。
反向传播(Backpropagation)是一种基于梯度下降算法的优化方法,用来计算模型中每个参数对于损失函数的梯度,从而更新参数以最小化损失函数。反向传播通过链式法则将损失函数的梯度传递回每个参数,从输出层向输入层反向计算梯度。具体而言,反向传播算法可以分为两个步骤:
前向传播(Forward Propagation):将输入数据通过模型的参数计算出预测值,并计算出损失函数的值。
反向传播(Backward Propagation):通过链式法则计算出损失函数对于每个参数的梯度,并更新参数。
反向传播的过程中,需要根据损失函数的类型来计算梯度。例如,对于均方误差损失函数,梯度的计算可以通过对每个参数的偏导数进行求解;对于交叉熵损失函数,梯度的计算可以通过softmax函数的导数进行求解。
反向传播算法的实现主要包括两个步骤:计算梯度和参数更新。在计算梯度时,通过对损失函数进行求导,得到每个参数的梯度;在参数更新时,根据梯度和学习率进行参数的更新。这个过程不断迭代,直到达到收敛条件或达到一定的迭代次数为止。

优化器是机器学习中一个重要的组件,用于调整模型的参数以使其最优化。在机器学习任务中,目标就是最小化或最大化一个特定的损失函数。优化器的作用就是通过调整模型的参数,使得损失函数的值最小化或最大化。
常见的优化器有梯度下降法(Gradient Descent)、随机梯度下降法(Stochastic Gradient Descent)、动量法(Momentum)、Nesterov Accelerated Gradient(NAG)、Adagrad、RMSprop、Adam等。

常见的Pytorch模型有:
1.线性回归模型(Linear Regression Model):用于拟合线性关系数据的模型。
2.逻辑回归模型(Logistic Regression Model):用于分类问题的模型。
3.多层感知机模型(Multi-Layer Perceptron Model):由多个全连接层组成的深度神经网络模型。
4.卷积神经网络模型(Convolutional Neural Network Model):用于处理图像和视觉数据的模型。
5.循环神经网络模型(Recurrent Neural Network Model):用于处理序列数据的模型。
6.长短期记忆网络模型(Long Short-Term Memory Model):一种循环神经网络的变种,用于处理长序列数据的模型。
7.生成对抗网络模型(Generative Adversarial Network Model):由生成器和判别器组成的模型,用于生成新的数据样本。
8.注意力机制模型(Attention Mechanism Model):用于处理序列数据的模型,通过对输入序列的不同部分赋予不同的注意权重来提升模型性能。
9.Transformer模型:基于注意力机制的模型,用于处理序列数据的模型,如自然语言处理任务中的机器翻译和文本生成等。

模型保存

vgg16=torchvision.models.vgg16(pretrained=False)# 保存方式一,保存模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")# 保存方式二,保存模型参数  (推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")# 保存方式一:加载模型
model=torch.load("vgg16_method1.pth")# 保存方式二:加载模型
vgg16=torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict()
model=torch.load("vgg16_method2.pth")
http://www.lryc.cn/news/442731.html

相关文章:

  • GRU(门控循环单元)的原理与代码实现
  • 【医疗大数据】医疗保健领域的大数据管理:采用挑战和影响
  • gevent + flask 接口会卡住
  • SpringCloud Alibaba五大组件之——Sentinel
  • brpc之io事件分发器
  • MySQL | 知识 | 从底层看清 InnoDB 数据结构
  • es的封装
  • 写一个自动化记录鼠标/键盘的动作,然后可以重复执行的python程序
  • Spring Boot-热部署问题
  • 深度学习——管理模型的参数
  • 芯片验证板卡设计原理图:372-基于XC7VX690T的万兆光纤、双FMC扩展的综合计算平台 RISCV 芯片验证平台
  • 【软设】 系统开发基础
  • Linux移植之系统烧写
  • 【数据结构与算法】LeetCode:双指针法
  • Istio下载及安装
  • Redis基础数据结构之 Sorted Set 有序集合 源码解读
  • 蓝队技能-应急响应篇Web内存马查杀JVM分析Class提取诊断反编译日志定性
  • 递归快速获取机构树型图
  • [Web安全 网络安全]-XSS跨站脚本攻击
  • 数据库数据恢复—SQL Server附加数据库出现“错误823”怎么恢复数据?
  • Vscode 中新手小白使用 Open With Live Server 的坑
  • 【深度学习 transformer】Transformer与ResNet50在自定义数据集图像分类中的效果比较
  • 【系统架构设计师】专业英语90题(附答案详解)
  • ItemXItemEffect | ItemEffect
  • web 动画库
  • 我的AI工具箱Tauri版-MicrosoftTTS文本转语音
  • 【Webpack--013】SourceMap源码映射设置
  • 创新驱动,技术引领:2025年广州见证汽车电子技术新高度
  • Spring Boot框架在心理教育辅导系统中的应用案例
  • Shiro-550—漏洞分析(CVE-2016-4437)