当前位置: 首页 > news >正文

深度学习之用PyTorch实现线性回归

代码

# 调用库
import torch# 数据准备
x_data = torch.Tensor([[1.0], [2.0], [3.0]])  # 训练集输入值
y_data = torch.Tensor([[2.0], [4.0], [6.0]])  # 训练集输出值# 定义线性回归模型
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()  # 调用父类构造函数self.linear = torch.nn.Linear(1, 1)  # 实例化torch库nn模块的Linear类def forward(self, x):"""前馈运算:param x: 输入值:return: 线性回归预测结果"""y_pred = self.linear(x)return y_predmodel = LinearModel()  # 实例化LinearModel类criterion = torch.nn.MSELoss(size_average=False)  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化器——梯度下降SGD
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 优化器——Adam
# optimizer = torch.optim.Adamax(model.parameters(), lr=0.01)  # 优化器——Adamax# 训练过程
for epoch in range(1000):  # epoch:训练轮次y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()  # 梯度归零loss.backward()  # 反向传播optimizer.step()  # 权重自动更新print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())# 预测过程
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print("y_pred = ", y_test.data)

结果

1 不同epoch结果

1.1 epoch = 100时

1.2 epoch = 1000时

 

 2 不同优化器

2.1 Adam优化器

 

 2.2 Adamax优化器 

 

3 不同学习率(梯度下降)

3.1 lr = 0.05

 3.2 lr = 0.1(loss函数结果发散)

遇见的问题

1 代码问题(已解决)

1.1 问题

 1.2 解决办法

 2 关于神经网络

代码中model.parameters()函数保存的是Weights和Bais参数的值。但是对于其他网络(非线性)来说这个函数可以用吗,里面也是保存的w和b吗?

http://www.lryc.cn/news/105248.html

相关文章:

  • 45.248.11.X服务器防火墙是什么,具有什么作用
  • 如何以无服务器方式运行 Go 应用程序
  • 小程序商城系统的开发方式及优缺点分析
  • [数据集][目标检测]城市道路井盖破损丢失目标检测1377张
  • 【Spring Cloud 三】Eureka服务注册与服务发现
  • WPF实战学习笔记21-自定义首页添加对话服务
  • AngularJS学习(一)
  • 918. 环形子数组的最大和
  • AI算法图形化编程加持|OPT(奥普特)智能相机轻松适应各类检测任务
  • C语言文件指针设置偏移量--fseek
  • 快速消除视频的原声的技巧分享
  • lua脚本实现Redis令牌桶限流
  • 最新 23 届计算机校招薪资汇总
  • BUU CODE REVIEW 1
  • django使用ztree实现树状结构效果,子节点实现动态加载(l懒加载)
  • 认识springboot 之 了解它的日志 -4
  • 关于大规模数据处理的解决方案
  • 免费快速下载省市区县行政区的Shp数据
  • MAC下配置android-sdk
  • Hive-数据倾斜
  • Java多线程(三)
  • Linux操作系统3-项目部署
  • 软件测试面试题——接口自动化测试怎么做?
  • 如何在医疗器械行业运用IPD?
  • 16. Spring Boot 统一功能处理
  • PostgreSQL-数据库命令
  • 面试题:说说JavaScript中内存泄漏的几种情况?垃圾回收机制
  • HTML基础介绍1
  • 【腾讯云 Cloud Studio 实战训练营】Redisgo_task 分布式锁实现
  • Linux CentOS系统怎么下载软件