当前位置: 首页 > news >正文

pytorch简单新型模型测试参数

import torch
from torch.nn import Conv2d,MaxPool2d,Sequential,Flatten,Linear
import torchvision
import torch.optim.optimizer
from torch.utils.data import DataLoader,dataset
from torch import nn
import torch.optim.optimizer# 建模
model = nn.Linear(2,1)#损失
loss = nn.MSELoss()
#优化
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)#定义输入和标签
input = torch.tensor([[2.,7.],[1.,6.]])
y = torch.tensor([[1.],[3.]])#输入模型数据
out= model(input)
print(out)
#计算损失
loss_fn = loss(y,out)
print(loss_fn.item())
#梯度清零
optimizer.zero_grad()
#反向传播
loss_fn.backward()
print(loss_fn.item())
#更新梯度
optimizer.step()# 再次进行前向传播和反向传播
x = torch.tensor([[5., 6.], [7., 8.]])
y_true = torch.tensor([[11.], [15.]])
y_pred = model(x)
loss = loss_fn(y_pred, y_true)
optimizer.zero_grad()
loss.backward()
optimizer.step()'''
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)# SGD 就是随机梯度下降
opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)
# momentum 动量加速,在SGD函数里指定momentum的值即可
opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
# RMSprop 指定参数alpha
opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
# Adam 参数betas=(0.9, 0.99)
opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))计算损失
w=w−l_r*dw
b=b-l_r*dbdw和db分别是权重和偏置的梯度,learning_rate是学习率,控制每次更新的步长
'''def  hook():# 定义模型参数w = torch.tensor([1.0], requires_grad=True) #requires_grad=True 的作用是让 backward 可以追踪这个参数并且计算它的梯度。b = torch.tensor([0.0], requires_grad=True) ##requires_grad=True 的作用是让 backward 可以追踪这个参数并且计算它的梯度。# 定义输入和目标输出x = torch.tensor([2.0])y_true = torch.tensor([4.0])# 定义损失函数loss_fn = torch.nn.MSELoss()# 定义优化器optimizer = torch.optim.SGD([w, b], lr=0.1)# 迭代训练for i in range(100):# 前向传播y_pred = w * x + bloss = loss_fn(y_pred, y_true)# 反向传播optimizer.zero_grad()loss.backward()# 提取梯度  我们使用loss.backward()计算损失函数对于模型参数的梯度,并将其保存在相应的张量的.grad属性中dw = w.graddb = b.gradprint("dw".format(dw))print("db".format(db))# 更新模型参数optimizer.step()# 输出模型参数print("w = ", w)print("b = ", b)

记录一些api:

表3-1: 常见新建tensor的方法

函数功能
Tensor(*sizes)基础构造函数
tensor(data,)类似np.array的构造函数
ones(*sizes)全1Tensor
zeros(*sizes)全0Tensor
eye(*sizes)对角线为1,其他为0
arange(s,e,step从s到e,步长为step
linspace(s,e,steps)从s到e,均匀切分成steps份
rand/randn(*sizes)均匀/标准分布
normal(mean,std)/uniform(from,to)正态分布/均匀分布
randperm(m)随机排列

http://www.lryc.cn/news/306426.html

相关文章:

  • Unity中URP下实现水体(水面高光)
  • 26.HarmonyOS App(JAVA)列表对话框
  • 五种主流数据库:常用字符函数
  • 软考笔记--企业资源规划和实施
  • React歌词滚动效果(跟随音乐播放时间滚动)
  • java面试题之mybatis篇
  • Java的编程之旅19——使用idea对面相对象编程项目的创建
  • docker build基本命令
  • nginx高级配置详解
  • 小程序--分包加载
  • R语言【base】——writeLines()
  • 微信小程序-人脸检测
  • 微信小程序自制动态导航栏
  • 金融知识分享系列之:五日线
  • 回归测试详解
  • 渲染效果图有哪几种分类?效果图为什么用云渲染更快
  • Docker镜像加速
  • 吴恩达deeplearning.ai:sigmoid函数的替代方案以及激活函数的选择
  • Alias许可分析中的数据可视化
  • 【计算机网络】数据链路层--以太网/MTU/ARP/RARP协议
  • typescript使用解构传参
  • CSP-J 2023 复赛第4题:旅游巴士
  • JAVA算法和数据结构
  • 每日五道java面试题之spring篇(七)
  • Keil编译GD32工程时找不到lib库文件
  • 测试C#使用ViewFaceCore实现图片中的人脸遮挡
  • 2.21 Qt day2 菜单栏/工具栏/状态栏/浮动窗口、UI界面、信号与槽
  • 300分钟吃透分布式缓存-16讲:常用的缓存组件Redis是如何运行的?
  • 上一篇文章补充:已经存在的小文件合并
  • 代码随想录训练营第三十期|第四十三天|动态规划 part05|1049. 最后一块石头的重量 II ● 494. 目标和 ● 474.一和零