当前位置: 首页 > news >正文

线形回归与小批量梯度下降实例

1、准备数据集

import numpy as np
import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader
from torch.utils.data import TensorDataset#########################################################################
#################准备若干个随机的x和y#####################################
#########################################################################
np.random.seed(100)     #使用random.seed,设置一个固定的随机种子,
data_size = 150         # 数据集大小
x_range = 5             # x的范围
iteration_count = 100   # 迭代次数# np.random.rand 是 NumPy 库中的一个函数,用于生成一个给定形状的数组,
# 数组中的元素是从一个均匀分布的样本中抽取的,这个均匀分布是在半开区间 [0, 1) 上。
# 这意味着产生的随机数将大于等于0且小于1。#随机生成data_size个横坐标x,范围在0到x_range之间
x=x_range * np.random.rand(data_size,1)#生成带有噪音y数据,基本分布在y=2x+6的附近
y=2*x + 6 + np.random.randn(data_size,1)*0.3plt.scatter(x,y,marker='x',color='green')#########################################################################
#################将训练数据转为张量#######################################
#########################################################################
#将训练数据转为张量
tensorX = torch.from_numpy(x).float()
tensorY = torch.from_numpy(y).float()#使用TensorDataset,将tensorX和tensorY组成训练集
dataset = TensorDataset(tensorX,tensorY)#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,batch_size = 20, #每一个小批量的数据规模是20shuffle =True )  #随机打乱数据的顺序
print("dataloader len =%d" %(len(dataloader)))for index,(data,label) in enumerate(dataloader):print("index=%d num = %d"%(index,len(data)))

2、线性回归模型的训练思路

2.1 初始化参数

设置是模型参数:权重w和偏置b

初始化为随机值

设置 requires_grad=True,PyTorch 将记录这些张量的操作历史,用于后续的自动求导

2.2 循环训练(epoch

epoch 变量用于控制整个训练过程的迭代轮数

在机器学习和深度学习中,“epoch” 是一个常用的术语,指的是在整个数据集上完整地运行一次(即正向传播和反向传播)训练算法的过程。

定义:一个 epoch 是指训练过程中,训练集中每个样本都被使用过一次来更新模型的权重。
训练过程:在训练一个模型时,通常会将数据集分成多个批次(batches)。每个批次包含一定数量的样本。一个 epoch 完成意味着所有批次都已经过模型处理。
迭代与epoch:在一个 epoch 内,模型可能会多次迭代,每次迭代处理一个批次的数据。因此,一个 epoch 包含多个迭代(iterations)。
目的:通过多个 epochs 的训练,模型可以逐渐学习数据集中的模式,从而提高其性能。
数量:训练一个模型所需的 epochs 数量取决于多种因素,包括数据集的大小、模型的复杂度以及问题的难度。有时可能只需要几个 epochs,而有时可能需要数百甚至数千个 epochs。
监控:在训练过程中,通常会监控每个 epoch 的性能指标(如损失函数的值或准确率),以评估模型的学习进度。
过拟合与欠拟合:如果训练过多的 epochs,模型可能会过拟合(即模型学习到了数据中的噪声而非潜在的模式),而训练不足的 epochs 则可能导致欠拟合(即模型未能捕捉到数据中的关键模式)。

2.3 数据加载

内层循环通过 dataloader 遍历训练数据集的小批量数据。dataloader 是一个数据加载器,通常由 DataLoader 类创建,用于批量加载数据。

2.4 前向传播

假设 tensorX 是当前批次的数据

tensorY 是对应的真实标签

使用当前参数 w 和 b 计算预测值 h=w*tensorX +b。

2.5 计算损失

计算预测值 h 和真实值 tensorY 之间的均方误差(MSE),并保存到 loss

loss = torch.mean((h - tensorY) ** 2)

2.6 反向传播

调用 loss.backward() 进行反向传播,计算损失关于参数 w 和 b 的梯度

设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导

2.7 更新参数

使用梯度下降算法更新参数 w 和 b。学习率设置为0.01

w.data -= 0.01 * w.grad.data

b.data -= 0.01 * b.grad.data

沿着当前小批量计算的得到的梯度(导数)更新w和b

如果导数为0,则w、b保存不变

2.8 梯度清零

在每次迭代后,需要清空参数的梯度信息,以便下一次迭代计算

3、线性回归模型的实现

# 待送代的参数为w和b
w = torch.randn(1,requires_grad=True)
b = torch.randn(1,requires_grad=True)#进入模型的循环迭代
for epoch in range(1,iteration_count):#代表了整个训练数据集的迭代轮数# 在一个迭代轮次中,以小批量的方式,使用dataloader对数据# batch_index表示当前遍历的批次# data和label表示这个批次的训练数据和标记for batch_index,(data, label)in enumerate(dataloader):h = tensorX * w + b #计算当前直线的预测值,保存到h#计算预测值h和真实值y之间的均方误差,保存到loss中loss=torch.mean((h-tensorY)**2)#计算代价1oss关于参数w和b的偏导数,设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导loss.backward()#进行梯度下降,沿着梯度的反方向,更新w和b的值#沿着当前小批量计算的得到的梯度(导数)更新w和b#如果导数为0(Δw,Δb为0),则w、b保存不变w.data -=0.01 * w.grad.datab.data -=0.01 * b.grad.dataprint("epoch(%d) batch(%d) lossΔw,Δb,w,b, = %.3lf,%.3lf,%.3lf,%.3lf,%.3lf," %(epoch,batch_index,loss.item(),w.grad.data,b.grad.data,w.data,b.data))#清空张量w和b中的梯度信息,为下一次迭代做准备w.grad.zero_()b.grad.zero_()#每次迭代,都打印当前迭代的轮数epoch#数据的批次batch idx和loss损失值

http://www.lryc.cn/news/520301.html

相关文章:

  • SpringCloud微服务:基于Nacos组件,整合Dubbo框架
  • Golang 简要概述
  • web前端第三次作业---制作可提交的用户注册表
  • 教育邮箱的魔力:免费获取Adobe和JetBrains软件
  • sympy常用函数与错误笔记
  • 47_Lua文件IO操作
  • nginx-lua模块处理流程
  • 【大数据】机器学习-----最开始的引路
  • 【前端】自学基础算法 -- 21.图的广度优先搜索
  • ChatGPT与Claude AI:两大生成式对话模型的比较分析
  • 前端开发:盒子模型、块元素
  • 升级 CentOS 7.x 系统内核到 4.4 版本
  • 播放音频文件同步音频文本
  • springboot使用Easy Excel导出列表数据为Excel
  • day07_Spark SQL
  • 高性能现代PHP全栈框架 Spiral
  • LeetCode - #182 Swift 实现找出重复的电子邮件
  • 《解锁鸿蒙Next系统人工智能语音助手开发的关键步骤》
  • 【Linux网络编程】数据链路层 | MAC帧 | ARP协议
  • 《自动驾驶与机器人中的SLAM技术》ch7:基于 ESKF 的松耦合 LIO 系统
  • 基于spingbott+html+Thymeleaf的24小时智能服务器监控平台设计与实现
  • 全栈面试(一)Basic/微服务
  • python安装完成后可以进行的后续步骤和注意事项
  • [Qt] 窗口 | 菜单栏MenuBar
  • [读书日志]从零开始学习Chisel 第十三篇:Scala的隐式参数与隐式转换(敏捷硬件开发语言Chisel与数字系统设计)
  • CMake学习笔记(1)
  • cursor+deepseek构建自己的AI编程助手
  • Kotlin实现DataBinding结合ViewModel的时候,提示找不到Unresolved reference: BR解决方案
  • java项目启动时,执行某方法
  • 详解如何自定义 Android Dex VMP 保护壳