当前位置: 首页 > news >正文

神经网络基础-神经网络补充概念-43-梯度下降法

概念

梯度下降法(Gradient Descent)是一种优化算法,用于在机器学习和深度学习中最小化(或最大化)目标函数。它通过迭代地调整模型参数,沿着梯度方向更新参数,以逐步接近目标函数的最优解。梯度下降法在训练神经网络等机器学习模型时非常常用,可以帮助模型学习数据中的模式和特征。

基本原理和步骤

目标函数定义:首先,需要定义一个目标函数(损失函数),它用来衡量模型预测与实际值之间的差异。通常目标是最小化损失函数。

参数初始化:初始化模型的参数,这些参数将在优化过程中被逐步调整。

计算梯度:计算损失函数对于模型参数的梯度(导数)。梯度表示了目标函数在当前参数值处的变化率,它指示了在哪个方向上参数应该更新以减小损失。

参数更新:通过梯度下降公式,沿着梯度的反方向更新模型的参数。更新步长由学习率(learning rate)控制,学习率越大,参数更新越大;学习率越小,参数更新越小。

重复迭代:重复执行步骤 3 和 4,直到达到预定的迭代次数(epochs)或收敛条件。通常,随着迭代次数的增加,模型的损失逐渐减小,参数逐渐趋于收敛到最优值。

梯度下降法可以分为多种变体,包括批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)和小批量梯度下降(Mini-Batch Gradient Descent)。随机梯度下降和小批量梯度下降在实际应用中更为常见,因为它们可以更快地收敛并适应大规模数据。

代码实现(SGD)

import numpy as np
import matplotlib.pyplot as plt# 生成一些随机数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 初始化参数
theta = np.random.randn(2, 1)# 学习率
learning_rate = 0.01# 迭代次数
n_iterations = 1000# 随机梯度下降
for iteration in range(n_iterations):random_index = np.random.randint(100)xi = X_b[random_index:random_index+1]yi = y[random_index:random_index+1]gradients = 2 * xi.T.dot(xi.dot(theta) - yi)theta = theta - learning_rate * gradients# 绘制数据和拟合直线
plt.scatter(X, y)
plt.plot(X, X_b.dot(theta), color='red')
plt.xlabel('X')
plt.ylabel('y')
plt.title('Linear Regression with Stochastic Gradient Descent')
plt.show()print("Intercept (theta0):", theta[0][0])
print("Slope (theta1):", theta[1][0])
http://www.lryc.cn/news/133406.html

相关文章:

  • Reids之Set类型解读
  • 【网络基础】数据链路层
  • 云计算|OpenStack|使用VMware安装华为云的R006版CNA和VRM---初步使用(二)
  • Python typing函式庫和torch.types
  • UE5 编程规范
  • 交互消息式IMessage扩展开发记录
  • 软件团队降本增效-建立需求评估体系
  • npm yarn pnpm 命令集
  • python 开发环境(PyCharm)搭建指南
  • springboot里 运用 easyexcel 导出
  • 一“码”当先,PR大征集!2023 和RT-Thread一起赋能开源!
  • jmeter模拟多用户并发
  • 澎峰科技|邀您关注2023 RISC-V中国峰会!
  • 【系统架构】系统架构设计之数据同步策略
  • Linux内核学习笔记——ACPI命名空间
  • 使用 OpenCV Python 实现自动图像注释工具的详细步骤--附完整源码
  • RunnerGo中WebSocket、Dubbo、TCP/IP三种协议接口测试详解
  • 【Java 动态数据统计图】动态数据统计思路案例(动态,排序,数组)一(112)
  • kafka踩坑
  • 让你专注于工作的电脑桌面日程提醒软件
  • 62页智慧产业园区数字化综合解决方案PPT
  • 苹果开发者账号注册方法简明指南
  • SQL-每日一题【1321. 餐馆营业额变化增长】
  • PyCharm PyQt5 开发环境搭建
  • 2023-08-17 Untiy进阶 C#知识补充8——C#中的日期与时间
  • SPSS--如何使用分层分析以及分层分析案例分享
  • 时序数据库influxdb笔记
  • 8月18日上课内容 Haproxy搭建Web群集
  • 【高阶数据结构】红黑树详解
  • 树莓牌4B安装Centos8