当前位置: 首页 > news >正文

百度飞浆:paddle 线性回归模型

学习引用 参考视频:
https://www.bilibili.com/video/BV1oRtkeVEVx?spm_id_from=333.788.player.switch&vd_source=c7739de98d044e74cdc74d6e772bed5f&p=2

这段代码使用PaddlePaddle深度学习框架来实现一个简单的线性回归模型,旨在从给定的出租车行驶公里数和对应的支付费用中学习出租车的起步价和每公里行驶费用。下面我将逐行解释这段代码的功能:

  1. 导入数据

    x_data = paddle.to_tensor([[1.0], [3.0], [5.0], [9.0], [20.0]])
    y_data = paddle.to_tensor([[12.0],[16.0],[20.0],[28.0],[50.0]])
    

    这里,x_data表示行驶公里数,y_data表示对应的支付费用。它们都被转换为PaddlePaddle的张量(Tensor)格式,以便后续的计算。

  2. 定义线性模型

    linear = paddle.nn.Linear(in_features=1, out_features=1)
    

    定义一个线性模型(也称为全连接层或密集层),输入特征数为1(即公里数),输出特征数为1(即预测的费用)。

  3. 查看初始权重和偏置

    w_before_opt = linear.weight.numpy().item()
    b_before_opt = linear.bias.numpy().item()
    print(w_before_opt, b_before_opt)
    

    打印出模型初始化的权重和偏置值。这些值是随机初始化的。

  4. 定义损失函数和优化器

    mse_loss = paddle.nn.MSELoss()
    sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
    

    使用均方误差(MSE)作为损失函数,因为这是一个回归问题。选择随机梯度下降(SGD)作为优化器,并设置学习率为0.001。

  5. 训练循环

    total_epoch = 5000
    for i in range(total_epoch):y_predict = linear(x_data)loss = mse_loss(y_predict, y_data)loss.backward()sgd_optimizer.step()sgd_optimizer.clear_gradients()
    

    进行5000次迭代(或称为epoch)。在每次迭代中,首先计算预测值y_predict,然后计算损失值loss,接着通过loss.backward()计算梯度,sgd_optimizer.step()更新模型参数,最后通过sgd_optimizer.clear_gradients()清除梯度,为下一次迭代做准备。

  6. 每1000次迭代打印一次损失

    if i % 1000 == 0:print(i, loss.numpy())
    

    为了监控训练过程,每1000次迭代打印一次当前的损失值。

  7. 训练结束后的操作和打印

    print("finish training, loss = {}".format(loss.numpy()))
    w_after_opt = linear.weight.numpy().item()
    b_after_opt = linear.bias.numpy().item()
    print(w_after_opt, b_after_opt)
    

    打印出训练结束后的最终损失值,以及优化后的权重和偏置值。这些值代表了学习到的起步价(偏置)和每公里费用(权重)。

总结
这段代码通过线性回归模型,从给定的出租车行驶公里数和支付费用数据中学习出租车的起步价和每公里行驶费用。通过多次迭代,模型逐渐调整其权重和偏置,以最小化预测费用与实际费用之间的均方误差。最终,模型学习到的权重和偏置值可以被解释为出租车的每公里费用和起步价。


```python
import paddle
# 任务乘坐出租车起步价10元,每公里2元
def calculate_fee(distance_travelled):return 10 + 2 * distance_travelledfor x in [1.0, 3.0, 5.0, 9.0, 20.0]:print(calculate_fee(x))#知道乘客每次乘坐出租车公里数,也知道乘客每次下车支付费用
#求 起步价、以及每公里形式费用。目标让机器从这些数据当中学习出来计算费用的规则
x_data = paddle.to_tensor([[1.0], [3.0], [5.0], [9.0], [20.0]])
y_data = paddle.to_tensor([[12.0],[16.0],[20.0],[28.0],[50.0]])linear = paddle.nn.Linear(in_features=1, out_features=1)
w_before_opt = linear.weight.numpy().item()
b_before_opt = linear.bias.numpy().item()
print(w_before_opt, b_before_opt)mse_loss = paddle.nn.MSELoss()
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())total_epoch = 5000
for i in range(total_epoch):y_predict = linear(x_data)loss = mse_loss(y_predict, y_data)loss.backward()sgd_optimizer.step()sgd_optimizer.clear_gradients()if i % 1000 == 0:print(i, loss.numpy())print("finish training, loss = {}".format(loss.numpy()))w_after_opt = linear.weight.numpy().item()
b_after_opt = linear.bias.numpy().item()
print(w_after_opt, b_after_opt)

http://www.lryc.cn/news/488028.html

相关文章:

  • 【JavaSE】【网络编程】UDP数据报套接字编程
  • 45.坑王驾到第九期:Mac安装typescript后tsc命令无效的问题
  • 20241120-Milvus向量数据库快速体验
  • 【Golang】——Gin 框架中间件详解:从基础到实战
  • 量子计算来袭:如何保护未来的数字世界
  • VMware虚拟机(Ubuntu或centOS)共享宿主机网络资源
  • 光伏电站仿真系统的作用
  • Golang文件操作
  • 爬虫开发工具与环境搭建——使用Postman和浏览器开发者工具
  • React(二)
  • 同步原语(Synchronization Primitives)
  • SpringBoot服务多环境配置
  • STM32单片机CAN总线汽车线路通断检测-分享
  • 【环境搭建】使用IDEA远程调试Docker中的Java Web
  • 贴代码框架PasteForm特性介绍之select,selects,lselect和reload
  • STM32G4的数模转换器(DAC)的应用
  • SpringMVC跨线程获取requests请求对象(子线程共享servletRequestAttributes)和跨线程获取token信息
  • 提取repo的仓库和工作树(无效)
  • 力扣整理版七:二叉树(待更新)
  • 基于单片机的多功能环保宠物窝设计
  • HBase 基础操作
  • 小米顾此失彼:汽车毛利大增,手机却跌至低谷
  • PCL 三维重建 a-shape曲面重建算法
  • 【Android】线程池的解析
  • 集群聊天服务器(8)用户登录业务
  • Go语言中的错误嵌套
  • 51单片机基础 06 串口通信与串口中断
  • Elasticsearch:更好的二进制量化(BBQ)对比乘积量化(PQ)
  • 【GNU】gcc -g编译选项 -g0 -g1 -g2 -g3 -gdwarf
  • MySQL【六】