当前位置: 首页 > news >正文

反向传播(backward propagation,BP) python实现

BP算法就是反向传播,要输入的数据经过一个前向传播会得到一个输出,但是由于权重的原因,所以其输出会和你想要的输出有差距,这个时候就需要进行反向传播,利用梯度下降,对所有的权重进行更新,这样的话在进行前向传播就会发现其输出和你想要的输出越来越接近了。

# 
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt# 生成权重以及偏执项layers_dim代表每层的神经元个数,
#比如[2,3,1]代表一个三成的网络,输入为2层,中间为3层输出为1层
def init_parameters(layers_dim):L = len(layers_dim)parameters ={}for i in range(1,L):parameters["w"+str(i)] = np.random.random([layers_dim[i],layers_dim[i-1]])parameters["b"+str(i)] = np.zeros((layers_dim[i],1))return parametersdef sigmoid(z):return 1.0/(1.0+np.exp(-z))# sigmoid的导函数
def sigmoid_prime(z):return sigmoid(z) * (1-sigmoid(z))# 前向传播,需要用到一个输入x以及所有的权重以及偏执项,都在parameters这个字典里面存储
# 最后返回会返回一个caches里面包含的 是各层的a和z,a[layers]就是最终的输出
def forward(x,parameters):a = []z = []caches = {}a.append(x)z.append(x)layers = len(parameters)//2# 前面都要用sigmoidfor i in range(1,layers):z_temp =parameters["w"+str(i)].dot(x) + parameters["b"+str(i)]z.append(z_temp)a.append(sigmoid(z_temp))# 最后一层不用sigmoidz_temp = parameters["w"+str(layers)].dot(a[layers-1]) + parameters["b"+str(layers)]z.append(z_temp)a.append(z_temp)caches["z"] = zcaches["a"] = a    return  caches,a[layers]# 反向传播,parameters里面存储的是所有的各层的权重以及偏执,caches里面存储各层的a和z
# al是经过反向传播后最后一层的输出,y代表真实值
# 返回的grades代表着误差对所有的w以及b的导数
def backward(parameters,caches,al,y):layers = len(parameters)//2grades = {}m = y.shape[1]# 假设最后一层不经历激活函数# 就是按照上面的图片中的公式写的grades["dz"+str(layers)] = al - ygrades["dw"+str(layers)] = grades["dz"+str(layers)].dot(caches["a"][layers-1].T) /mgrades["db"+str(layers)] = np.sum(grades["dz"+str(layers)],axis = 1,keepdims = True) /m# 前面全部都是sigmoid激活for i in reversed(range(1,layers)):grades["dz"+str(i)] = parameters["w"+str(i+1)].T.dot(grades["dz"+str(i+1)]) * sigmoid_prime(caches["z"][i])grades["dw"+str(i)] = grades["dz"+str(i)].dot(caches["a"][i-1].T)/mgrades["db"+str(i)] = np.sum(grades["dz"+str(i)],axis = 1,keepdims = True) /mreturn grades   # 就是把其所有的权重以及偏执都更新一下
def update_grades(parameters,grades,learning_rate):layers = len(parameters)//2for i in range(1,layers+1):parameters["w"+str(i)] -= learning_rate * grades["dw"+str(i)]parameters["b"+str(i)] -= learning_rate * grades["db"+str(i)]return parameters
# 计算误差值
def compute_loss(al,y):return np.mean(np.square(al-y))# 加载数据
def load_data():"""加载数据集"""x = np.arange(0.0,1.0,0.01)y =20* np.sin(2*np.pi*x)# 数据可视化plt.scatter(x,y)return x,y
#进行测试
x,y = load_data()
x = x.reshape(1,100)
y = y.reshape(1,100)
plt.scatter(x,y)
parameters = init_parameters([1,25,1])
al = 0
for i in range(4000):caches,al = forward(x, parameters)grades = backward(parameters, caches, al, y)parameters = update_grades(parameters, grades, learning_rate= 0.3)if i %100 ==0:print(compute_loss(al, y))
plt.scatter(x,al)
plt.show()

运行结果:

在这里插入图片描述

http://www.lryc.cn/news/318432.html

相关文章:

  • 简单算命脚本
  • Lua-掌握Lua语言基础1
  • python-0003-pycharm开发虚拟环境中的项目
  • 修改 MySQL update_time 默认值的坑
  • 基于亚马逊云EC2+Docker搭建nextcloud私有化云盘
  • 用try...catch进行判断
  • 服务器遭遇挖矿病毒syst3md及其伪装者rcu-sched:原因、症状与解决方案
  • 此机非彼机,工业计算机在工业行业的特殊地位
  • Python使用lxml解析XML格式化数据
  • CDA-LevelⅡ【考题整理-带答案】
  • 20240304 json可以包含复杂数组(数组里面套数组)
  • 算法50:动态规划专练(力扣514题:自由之路-----4种写法)
  • 重学SpringBoot3-集成Thymeleaf
  • 【数据可视化】Echarts最常用图表
  • flink:通过table api把文件中读取的数据写入MySQL
  • 【Java 多线程 哈希表】 HashTable, HashMap, ConcurrentHashMap 之间的区别
  • 有趣之matlab-烟花
  • C语言指针与数组(不适合初学者版):一篇文章带你深入了解指针与数组!
  • springboot Mongo大数据查询优化方案
  • Ollama管理本地开源大模型,用Open WebUI访问Ollama接口
  • Linux--基本知识入门
  • 基于springboot+vue实现的大学计算机课程管理平台的设计与实现(全套资料)
  • LeetCode2115. 从给定原材料中找到所有可以做出的菜
  • 项目性能优化—性能优化的指标、目标
  • 蓝桥杯刷题(三)
  • 20240312-算法复习打卡day21||● 530.二叉搜索树的最小绝对差 ● 501.二叉搜索树中的众数 ● 236. 二叉树的最近公共祖先
  • 今天我们来学习一下关于MySQL数据库
  • 长期护理保险可改善老年人心理健康 | CHARLS CLHLS CFPS 公共数据库周报(3.6)...
  • 49、C++/友元、常成员函数和常对象、运算符重载学习20240314
  • SQL Server错误:15404