当前位置: 首页 > news >正文

神经网络基础-神经网络补充概念-49-adam优化算法

概念

Adam(Adaptive Moment Estimation)是一种优化算法,结合了动量梯度下降法和RMSProp的优点,用于在训练神经网络等深度学习模型时自适应地调整学习率。Adam算法在深度学习中广泛应用,通常能够加速收敛并提高模型性能。

Adam算法综合了动量(momentum)和均方梯度的移动平均(RMSProp)来更新模型参数。与传统的梯度下降法不同,Adam维护了一个每个参数的动量变量和均方梯度的移动平均变量,并在每个迭代步骤中使用这些变量来调整学习率。

步骤

1初始化参数:初始化模型的参数。

2初始化动量变量和均方梯度的移动平均:初始化动量变量为零向量,初始化均方梯度的移动平均为零向量。

3计算梯度:计算当前位置的梯度。

4更新动量变量:计算动量变量的移动平均。

momentum = beta1 * momentum + (1 - beta1) * gradient

其中,beta1 是用于计算动量变量移动平均的超参数。
5更新均方梯度的移动平均:计算均方梯度的移动平均。

moving_average = beta2 * moving_average + (1 - beta2) * gradient^2

其中,beta2 是用于计算均方梯度的移动平均的超参数
6修正偏差
对动量变量和均方梯度的移动平均进行偏差修正,以减轻初始迭代的影响。

corrected_momentum = momentum / (1 - beta1^t)
corrected_moving_average = moving_average / (1 - beta2^t)

7更新参数

parameter = parameter - learning_rate * corrected_momentum / (sqrt(corrected_moving_average) + epsilon)

其中,epsilon 是一个小的常数,防止分母为零。

8重复迭代:重复执行步骤 3 到 7,直到达到预定的迭代次数(epochs)或收敛条件。

代码实现

import numpy as np
import matplotlib.pyplot as plt# 生成随机数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 初始化参数
theta = np.random.randn(2, 1)# 学习率
learning_rate = 0.1# Adam参数
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
momentum = np.zeros_like(theta)
moving_average = np.zeros_like(theta)# 迭代次数
n_iterations = 1000# Adam优化
for iteration in range(n_iterations):gradients = 2 / 100 * X_b.T.dot(X_b.dot(theta) - y)momentum = beta1 * momentum + (1 - beta1) * gradientsmoving_average = beta2 * moving_average + (1 - beta2) * gradients**2corrected_momentum = momentum / (1 - beta1**(iteration+1))corrected_moving_average = moving_average / (1 - beta2**(iteration+1))theta = theta - learning_rate * corrected_momentum / (np.sqrt(corrected_moving_average) + epsilon)# 绘制数据和拟合直线
plt.scatter(X, y)
plt.plot(X, X_b.dot(theta), color='red')
plt.xlabel('X')
plt.ylabel('y')
plt.title('Linear Regression
http://www.lryc.cn/news/128391.html

相关文章:

  • Java:正则表达式书写规则及相关案例:检验QQ号码,校验手机号码,邮箱格式,当前时间
  • 图数据库_Neo4j_Centos7.9安装Neo4j社区版3.5.4_基于jdk1.8---Neo4j图数据库工作笔记0011
  • 使用Rust编写的一款使用遗传算法、神经网络、WASM技术的模拟生物进化的程序
  • UE4/UE5 “无法双击打开.uproject 点击无反应“解决
  • 【前端】深入理解CSS定位
  • 【问题】分布式事务的场景下如何保证读写分离的数据一致性
  • 常见的Web安全漏洞有哪些,Web安全漏洞常用测试方法介绍
  • 随机微分方程
  • 下载安装并使用小乌龟TortoiseGit
  • npm ERR!Cannot read properties of null(reading ‘pickAlgorithm’)报错问题解决
  • web前端tips:js继承——组合继承
  • (7)(7.3) 自动任务中的相机控制
  • Python 爬虫小练
  • vue3 事件处理 @click
  • 【第三阶段】kotlin语言使用replace完成加解密操作
  • springBoot是如何实现自动装配的
  • 基于python+MobileNetV2算法模型实现一个图像识别分类系统
  • 管理类联考——逻辑——真题篇——按知识分类——汇总篇——二、论证逻辑——归纳评价——归纳谬误
  • C++适配器模式
  • cocos creator 设置精灵镜像翻转效果
  • kafka的位移
  • 大数据平台运维实训室建设方案
  • dll调用nodejs的回调函数
  • 网络安全--linux下Nginx安装以及docker验证标签漏洞
  • 多维时序 | MATLAB实现WOA-CNN-BiGRU-Attention多变量时间序列预测
  • 金蝶软件实现Excel数据复制分录信息粘贴到单据体分录行中
  • 【Linux操作系统】深入探索Linux进程:创建、共享与管理
  • 【云原生、k8s】Calico网络策略
  • Unity3D 测试总结
  • 【无线点对点网络时延分析和可视化】模拟无线点对点网络中的延迟以及物理层和数据链路层之间的相互作用(Matlab代码实现)