当前位置: 首页 > article >正文

图解深度学习 - 基于梯度的优化(梯度下降)

在模型优化过程中,我们曾尝试通过手动调整单个标量系数来观察其对损失值的影响。具体来说,当初始系数为0.3时,损失值为0.5。随后,我们尝试增加系数至0.35,发现损失值上升至0.6;相反,当系数减小至0.25时,损失值下降至0.4。这一实验结果表明,在该特定情境下,减小系数值有助于降低模型的损失值。

然而,这种方法在实际应用中非常低效,因为模型通常包含大量的系数(可能达到上千个甚至上百万个),对每个系数进行两次前向传播来计算不同取值下的损失值,计算成本极高。

为了解决这个问题,引入了梯度下降法作为一种更高效的优化方法。梯度下降法通过计算损失函数对每个系数的梯度(即损失值对系数的导数),能够指导我们如何调整每个系数以最小化损失值,而无需对每个系数进行多次前向传播试验。

图片


资料分享

为了方便大家学习,我整理了一份深度学习资料+80G人工智能资料包(如下图)

不仅有入门级教程,配套课件,还有进阶实战,源码数据集,更有面试题帮你提升~

需要的兄弟可以按照这个图的方式免费获取


一、梯度下降

梯度下降(Gradient Descent是什么梯度下降是一种通过迭代计算损失函数梯度并沿其反方向更新参数以最小化损失值的优化算法。

梯度下降法基于这样一个观察:如果一个函数在某点处可微且有定义,那么函数在该点沿着梯度的反方向下降最快。因此,算法从初始估计的参数点开始,通过计算损失函数的梯度,并沿着梯度的反方向进行迭代搜索,逐步接近函数的局部极小值。

  1. 初始化参数:选择一个起始点作为初始参数,这些参数可以是任意值或随机选择的值。

  2. 计算梯度:计算当前参数点处的损失函数的梯度。梯度是一个向量,表示损失函数在每个参数维度上的变化率。

  3. 更新参数:使用梯度信息来更新参数,以使损失函数的值减小。这通常是通过沿着梯度的反方向进行调整来实现的,调整的大小由学习率决定

  4. 迭代更新:重复计算梯度和更新参数的步骤,直到满足停止准则,如达到预设的最大迭代次数或损失函数值减小到足够小的值。

图片

“一图 + 一句话”彻底搞懂梯度下降。

“梯度下降是一种通过迭代计算损失函数梯度并沿其反方向调整参数,以最小化损失值的优化算法,它避免了手动调整每个参数的繁琐和高昂计算成本

图片


二、BGD、SGD、MBGD

梯度下降算法有哪些批量梯度下降(BGD)利用全部数据计算梯度,收敛快但计算量大;随机梯度下降(SGD)每次仅使用一个样本,计算量小但收敛慢且可能震荡;小批量梯度下降(MBGD)则是两者的折中,选择部分样本计算梯度,既降低了计算量又保持了较快的收敛速度。

图片

1. 批量梯度下降(Batch Gradient Descent,BGD):在每次迭代中使用全部的训练数据来计算梯度,然后更新模型参数。

  • 优点:收敛速度相对较快,可以利用矩阵运算加速计算,且在凸优化问题中能保证收敛到全局最优解。

  • 缺点:在处理大规模数据集时,计算梯度的时间和空间复杂度较高,内存使用量可能过大。

2. 随机梯度下降(Stochastic Gradient Descent,SGD):在每次迭代中随机选择一个样本来计算梯度,然后更新模型参数。

  • 优点:计算梯度的时间和空间复杂度较低,适用于处理大规模数据集,且能跳出局部最优解(因为每次更新参数的方向不一定是相同的)。

  • 缺点:收敛速度较慢,且可能会出现震荡现象,对于稠密数据集的计算速度可能较慢。

3. 小批量梯度下降(Mini-Batch Gradient Descent):在每次迭代中选择一小部分样本来计算梯度,然后更新模型参数,是批量梯度下降和随机梯度下降的折中方案。

  • 优点:计算梯度的时间和空间复杂度较低,收敛速度较快,且可以利用矩阵运算的并行性加速计算,同时能跳出局部最优解。

  • 缺点:需要手动设置小批量大小,如果选择不当可能会影响收敛速度和精度。对于大规模、稀疏或实时数据流问题,其计算效率可能不如SGD,但比BGD要好。

图片

“一图 + 一句话”彻底搞懂BGD、SGD、MBGD。

“ 梯度下降算法主要包括批量梯度下降(BGD,利用全部数据,收敛快但计算量大)随机梯度下降(SGD,每次仅用一个样本,计算量小但收敛慢且可能震荡)小批量梯度下降(MBGD,部分样本折中方案,既降低计算量又保持较快收敛速度)

图片

http://www.lryc.cn/news/2393065.html

相关文章:

  • MySql--定义表存储引擎、字符集和排序规则
  • 【部署】在离线服务器的docker容器下升级dify-import程序
  • 优化版本,增加3D 视觉 查看前面的记录
  • 写作-- 复合句练习
  • WWW22-可解释推荐|用于推荐的神经符号描述性规则学习
  • Linux:shell脚本常用命令
  • 专业课复习笔记 11
  • OpenTelemetry × Elastic Observability 系列(一):整体架构介绍
  • STM32高级物联网通信之以太网通讯
  • 从Java的Jvm的角度解释一下为什么String不可变?
  • 从零开始的数据结构教程(四) ​​图论基础与算法实战​​
  • 历年西安交通大学计算机保研上机真题
  • 可视化与动画:构建沉浸式Vue应用的进阶实践
  • Python |GIF 解析与构建(3):简单哈希压缩256色算法
  • 蓝桥杯2114 李白打酒加强版
  • 基本数据指针的解读-C++
  • Android Studio里的BLE数据接收策略
  • 【Office】Excel两列数据比较方法总结
  • 基于多模态脑电、音频与视觉信号的情感识别算法【Nature核心期刊,EAV:EEG-音频-视频数据集】
  • 【QueryServer】dbeaver使用phoenix连接Hbase(轻客户端方式)
  • 数据湖 (特点+与数据仓库和数据沼泽的对比讲解)
  • 深入链表剖析:从原理到 C 语言实现,涵盖单向、双向及循环链表全解析
  • 编码总结如下
  • 《算力觉醒!ONNX Runtime + DirectML如何点燃Windows ARM设备的AI引擎》
  • [9-1] USART串口协议 江协科技学习笔记(13个知识点)
  • Oracle基础知识(五)——ROWID ROWNUM
  • 简述synchronized和java.util.concurrent.locks.Lock的异同 ?
  • OpenCV CUDA模块直方图计算------在 GPU 上计算图像直方图的函数calcHist()
  • EMS只是快递那个EMS吗?它跟能源有什么关系?
  • 日志技术-LogBack、Logback快速入门、Logback配置文件、Logback日志级别