当前位置: 首页 > news >正文

不同优化器的应用

 简单用用,优化器具体参考

深度学习中的优化器原理(SGD,SGD+Momentum,Adagrad,RMSProp,Adam)_哔哩哔哩_bilibili

收藏版|史上最全机器学习优化器Optimizer汇总 - 知乎 (zhihu.com)

import numpy as np
import matplotlib.pyplot as plt
import torch
# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])y_data = torch.tensor([[2.0], [4.0], [6.0]])loss_SGD = []
loss_Adagrad = []
loss_Adam = []
loss_Adamax = []
loss_ASGD = []
loss_LBFGS = []
loss_RMSprop = []
loss_Rprop = []class LinearModel(torch.nn.Module):def __init__(self):super().__init__()self.Linear = torch.nn.Linear(1,1)def forward(self,x):y_pred = self.Linear(x)return y_predmodel = LinearModel()criterion = torch.nn.MSELoss(reduction='sum')
optimizer_SGD = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_Adagrad = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_Adam = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_Adamax = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_ASGD = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_LBFGS = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_RMSprop = torch.optim.SGD(model.parameters(),lr=0.01)
optimizer_Rprop = torch.optim.SGD(model.parameters(),lr=0.01)epoch_list = []# optimizer_SGD
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)epoch_list.append(epoch)loss_SGD.append(loss.data)optimizer_SGD.zero_grad()loss.backward()optimizer_SGD.step()# optimizer_Adagrad
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_Adagrad.append(loss.data)optimizer_Adagrad.zero_grad()loss.backward()optimizer_Adagrad.step()# optimizer_Adam
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_Adam.append(loss.data)optimizer_Adam.zero_grad()loss.backward()optimizer_Adam.step()# optimizer_Adamax
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_Adamax.append(loss.data)optimizer_Adamax.zero_grad()loss.backward()optimizer_Adamax.step()# optimizer_ASGD
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_ASGD.append(loss.data)optimizer_ASGD.zero_grad()loss.backward()optimizer_ASGD.step()# optimizer_LBFGS
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_LBFGS.append(loss.data)optimizer_LBFGS.zero_grad()loss.backward()optimizer_LBFGS.step()# optimizer_RMSprop
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_RMSprop.append(loss.data)optimizer_RMSprop.zero_grad()loss.backward()optimizer_RMSprop.step()# optimizer_Rprop
for epoch in range(100):y_pred = model(x_data)loss = criterion(y_pred,y_data)loss_Rprop.append(loss.data)optimizer_Rprop.zero_grad()loss.backward()optimizer_Rprop.step()x_test = torch.tensor([4.0])
y_test = model(x_test)print('y_pred = ', y_test.data)plt.subplot(241)
plt.title("SGD")
plt.plot(epoch_list,loss_SGD)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(242)
plt.title("Adagrad")
plt.plot(epoch_list,loss_Adagrad)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(243)
plt.title("Adam")
plt.plot(epoch_list,loss_Adam)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(244)
plt.title("Adamax")
plt.plot(epoch_list,loss_Adamax)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(245)
plt.title("ASGD")
plt.plot(epoch_list,loss_ASGD)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(246)
plt.title("LBFGS")
plt.plot(epoch_list,loss_LBFGS)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(247)
plt.title("RMSprop")
plt.plot(epoch_list,loss_RMSprop)
plt.ylabel('cost')
plt.xlabel('epoch')plt.subplot(248)
plt.title("Rprop")
plt.plot(epoch_list,loss_Rprop)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()

运行结果:

http://www.lryc.cn/news/228602.html

相关文章:

  • 学习网络编程No.9【应用层协议之HTTPS】
  • PSP - 蛋白质复合物结构预测 Template Pair 特征 Mask 可视化
  • RK3568开发笔记-amixer开机设置音量异常
  • STM32两轮平衡小车原理详解(开源)
  • 区间内的真素数问题(C#)
  • eclipse安装lombok插件
  • 故障演练 | 微服务架构下如何做好故障演练
  • Python爬虫-获取汽车之家车家号
  • No195.精选前端面试题,享受每天的挑战和学习
  • pytest与testNg自动化框架
  • 数据库安全:Hadoop 未授权访问-命令执行漏洞.
  • 前端---认识HTML
  • 竞赛 题目:基于FP-Growth的新闻挖掘算法系统的设计与实现
  • 保姆级jupyter lab配置清单
  • 数据结构预算法--链表(单链表,双向链表)
  • 数据结构线性表——栈
  • 自定义 springboot 启动器 starter 与自动装配原理
  • 16 _ 二分查找(下):如何快速定位IP对应的省份地址?
  • vb.net圣经带快捷键,用原装的数据库
  • Unity中Shader的雾效
  • 企业微信开发教程一:添加企微应用流程图解以及常见问题图文说明
  • 【LeetCode】67. 二进制求和
  • 【LeetCode刷题笔记】二叉树(一)
  • NativeScript开发ios应用,怎么生成测试程序?
  • Js面试题:说一下js的模块化?
  • 媒体转码软件Media Encoder 2024 mac中文版功能介绍
  • 整治PPOCRLabel中cv2文件读取问题(更新中)
  • 网络运维Day09-补充
  • 【C++】【Opencv】minMaxLoc()函数详解和示例
  • 用Go实现网络流量解析和行为检测引擎