当前位置: 首页 > news >正文

BP神经网络的反向传播算法

BP神经网络(Backpropagation Neural Network)是一种常用的多层前馈神经网络,通过反向传播算法进行训练。反向传播算法的核心思想是通过计算损失函数对每个权重的偏导数,从而调整权重,使得网络的预测输出与真实输出之间的误差最小。下面是反向传播算法的公式推导过程:

1. 前向传播(Forward Propagation)

假设我们有一个三层神经网络(输入层、隐藏层和输出层),并且每层的激活函数为 sigmoid 函数。

- 输入层:\mathbf{x} = (x_1, x_2, \ldots, x_n)
- 隐藏层:\mathbf{h} = (h_1, h_2, \ldots, h_m)
- 输出层:\mathbf{y} = (y_1, y_2, \ldots, y_k)

各层之间的权重分别为:
- 输入层到隐藏层的权重:\mathbf{W}^{(1)}
- 隐藏层到输出层的权重:\mathbf{W}^{(2)}

对于第 j 个隐藏层神经元,其输入为:

z_j^{(1)} = \sum_{i=1}^n W_{ji}^{(1)} x_i + b_j^{(1)}

其输出为:

h_j = \sigma(z_j^{(1)})

对于第 l 个输出层神经元,其输入为:

z_l^{(2)} = \sum_{j=1}^m W_{lj}^{(2)} h_j + b_l^{(2)}

其输出为:

y_l = \sigma(z_l^{(2)})

其中,\sigma(z) 是激活函数(sigmoid 函数):

\sigma(z) = \frac{1}{1 + e^{-z}}

2. 计算损失函数(Loss Function)

假设损失函数为均方误差(MSE):

L = \frac{1}{2} \sum_{l=1}^k (y_l - \hat{y}_l)^2

其中,\hat{y}_l 是网络的预测输出,y_l 是真实输出。

 3. 反向传播(Backpropagation)

反向传播的目标是计算损失函数对每个权重的偏导数,并根据梯度下降法更新权重。

3.1 输出层的误差项

首先计算输出层的误差项:

\delta_l^{(2)} = \frac{\partial L}{\partial z_l^{(2)}} = \frac{\partial L}{\partial \hat{y}_l} \cdot \frac{\partial \hat{y}_l}{\partial z_l^{(2)}}

由于:

\frac{\partial L}{\partial \hat{y}_l} = \hat{y}_l - y_l
\frac{\partial \hat{y}_l}{\partial z_l^{(2)}} = \hat{y}_l (1 - \hat{y}_l)

所以:

\delta_l^{(2)} = (\hat{y}_l - y_l) \hat{y}_l (1 - \hat{y}_l)

3.2 隐藏层的误差项

接下来计算隐藏层的误差项:

\delta_j^{(1)} = \frac{\partial L}{\partial z_j^{(1)}} = \sum_{l=1}^k \frac{\partial L}{\partial z_l^{(2)}} \cdot \frac{\partial z_l^{(2)}}{\partial h_j} \cdot \frac{\partial h_j}{\partial z_j^{(1)}}

其中:

\frac{\partial z_l^{(2)}}{\partial h_j} = W_{lj}^{(2)}
\frac{\partial h_j}{\partial z_j^{(1)}} = h_j (1 - h_j)

所以:

\delta_j^{(1)} = \left( \sum_{l=1}^k \delta_l^{(2)} W_{lj}^{(2)} \right) h_j (1 - h_j)

3.3 更新权重

根据梯度下降法更新权重:

W_{lj}^{(2)} \leftarrow W_{lj}^{(2)} - \eta \frac{\partial L}{\partial W_{lj}^{(2)}} = W_{lj}^{(2)} - \eta \delta_l^{(2)} h_j
W_{ji}^{(1)} \leftarrow W_{ji}^{(1)} - \eta \frac{\partial L}{\partial W_{ji}^{(1)}} = W_{ji}^{(1)} - \eta \delta_j^{(1)} x_i

其中,\eta 是学习率。

http://www.lryc.cn/news/514860.html

相关文章:

  • [实用指南]如何将视频从iPhone传输到iPad
  • Linux Snipaste 截图闪屏/闪烁
  • 【YOLOv5】源码(common.py)
  • Node 如何生成 RSA 公钥私钥对
  • 瑞_Linux中部署配置Java服务并设置开机自启动
  • javaEE-多线程进阶-JUC的常见类
  • Flume拦截器的实现
  • Swift Combine 学习(四):操作符 Operator
  • leetcode 173.二叉搜索树迭代器栈绝妙思路
  • df.groupby([pd.Grouper(freq=‘1M‘, key=‘Date‘), ‘Buyer‘]).sum()
  • LLM - 使用 LLaMA-Factory 部署大模型 HTTP 多模态服务 (4)
  • icp备案网站个人备案与企业备案的区别
  • 如何不修改模型参数来强化大语言模型 (LLM) 能力?
  • AF3 AtomAttentionEncoder类的init_pair_repr方法解读
  • DDoS攻击防御方案大全
  • Vue中常用指令
  • Servlet解析
  • 带虚继承的类对象模型
  • 深度学习中的离群值
  • 如何利用Logo设计免费生成器创建专业级Logo
  • Mysql SQL 超实用的7个日期算术运算实例(10k)
  • 运算指令(PLC)
  • 「Mac畅玩鸿蒙与硬件49」UI互动应用篇26 - 数字填色游戏
  • 机器学习经典算法——逻辑回归
  • 【数据仓库金典面试题】—— 包含详细解答
  • 【UE5 C++课程系列笔记】19——通过GConfig读写.ini文件
  • JS 中 json数据 与 base64、ArrayBuffer之间转换
  • USB 驱动开发 --- Gadget 驱动框架梳理
  • 细说STM32F407单片机中断方式CAN通信
  • Python应用指南:高德交通态势数据