当前位置: 首页 > news >正文

Pytorch线性回归实现(Pycharm实现)

步骤都在注释里写清楚了,可以自己调整循环的次数观察输出的w与b和loss的值

import torch#学习率,用来进行w和b的更新
learning_rate = 0.01
#1. 准备数据
#这里使用y=3x+0.8.也就是w=3,b=0.8.创造一个500行1列的数据
x=torch.rand([500,1])
y_true=x*0.3+0.8#2. 通过模型计算y_predict。x*w,所以w是1行1列的.torch.matmul是矩阵乘法.只有浮点数才能使用grad。修改dtype
w = torch.rand([1,1],requires_grad=True)
b = torch.tensor(0,requires_grad=True,dtype=torch.float32)#4. 通过循环,反向传播,更新参数
for i in range(5000):y_predict = torch.matmul(x, w) + b# 3. 计算loss.用平方来处理,这里mean不太清楚是什么意思。均方误差?这是什么?....每次都需要更新损失,所以把他放在循环里loss = (y_true - y_predict).pow(2).mean()#每次backward之前梯度置为0if w.grad is not None:w.grad.data.zero_()if b.grad is not None:b.grad.data.zero_()loss.backward() #反向传播.这时w和b的梯度就算出来了w.grad,b.gradw.data = w.data - learning_rate * w.gradb.data = b.data - learning_rate * b.grad  #要注意左边不要写成grad,写成grad之后b的内容就一直是0print("w,b,loss",w.item(),b.item(),loss.item())

输出:

可以观察到w接近0.3,b接近0.8。和预想值十分接近了。

问题:

这里的理解有欠缺。。。

http://www.lryc.cn/news/313424.html

相关文章:

  • 2024新疆专升本考试报名教程详解
  • unicloud 云数据库概念及创建一个云数据库表并添加记录(数据)
  • 想交易盈利?Anzo Capital昂首资本发现了一本畅销书
  • 美国站群服务器租用需要考虑哪些关键点
  • 如何构建Hive数据仓库Hive 、数据仓库的存储方式 以及hive数据的导入导出
  • 【Linux】软件管理器yum和编辑器vim
  • 怎么才能确定螺栓是拧紧了——SunTorque智能扭矩系统
  • 西门子S120故障报警F30003的解决办法总结
  • 探索vue框架的世界: 内部、外部样式和内联样式动态绑定的方法
  • 代码随想录算法训练营第三十八天|动态规划|理论基础、509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯
  • 运维知识点-JBoss
  • HarmonyOS—配置编译构建信息
  • Chrome浏览器好用的几个扩展程序
  • Enzo Life Sciences Cortisol(皮质醇) ELISA kit
  • 面试经典150题 -- 二分查找 (总结)
  • 蓝牙耳机怎么选择比较好?2024年热门机型推荐大揭秘!
  • 强制Unity崩溃的两个方法
  • 中间件 | Redis - [big-key hot-key]
  • STM32基础--自己构建库函数
  • 网站被插入虚假恶意链接怎么办?
  • ThreeJs限制模型拖动的范围
  • 关于JVM的小总结(待补充)
  • day37 贪心算法part6
  • 38女神节:剧情热梗小游戏新品!预售1折秒杀,手慢无
  • 岩土工程监测仪器振弦采集仪的发展历程与国内外研究现状
  • Git 掌握
  • 面试题之——事务失效的八大情况
  • 一些硬件知识(六)
  • 前端React篇之哪些方法会触发 React 重新渲染?重新渲染 render 会做些什么?
  • PHP伪协议是什么?