当前位置: 首页 > news >正文

深度学习入门简单实现一个神经网络

实现一个三层神经网络

  • 引言
    • 测试数据
  • 代码

引言

今天我们实现一个简单的神经网络
俩个输入神经元 隐藏层两个神经元 一个输出神经元
激活函数我们使用sigmoid
优化方法使用梯度下降

在这里插入图片描述

我们前期准备是需要把这些神经元的关系理清楚
x1:第一个输入
x2:第二个输入
w11_1:第一层的第一个神经元在第一个输入上的权重
w12_1:第一层的第二个神经元在第一个输入上的权重
b1_1:第一层第一个神经元的偏置项(截距)
z1_1:第一层第一个神经元的线性函数
a1_1:第一层第一个神经元的激活函数
w21_1:第一层的第一个神经元在第一个输入上的权重
w22_1:第一层的第一个神经元在第一个输入上的权重
b2_1:第一层第二个神经元的偏置项(截距)
z1_1:第一层第二个神经元的线性函数
a1_1:第一层第二个神经元的激活函数
w11_2:第二层的第一个神经元在第一个输入上的权重
w21_2:第二层的第二个神经元在第一个输入上的权重
b1_2:第二层第一个神经元的偏置项(截距)
z1_1:第二层第一个神经元的线性函数
a1_1:第二层第一个神经元的激活函数
e:损失函数

测试数据

dataset 可以使用西瓜书89页的西瓜数据集3.0α

代码

import numpy as np
import sympy
import dataset
from matplotlib import pyplot as pltdef sigmod(b):return 1 / (1 + np.exp(-b))xs, ys = dataset.get_beans(100)  # 获取数据
plt.title("Size-Toxicity Funciton", fontsize=12)  # 设置图片的标题
plt.xlabel("Bean Size")  # 设置行标签
plt.ylabel("Toxicity")  # 设置列标签plt.scatter(xs, ys)  # 画散点图"""
命名规则
下划线后面的数字表示被输入的神经元所在的层数
字母后面的数字表示第一个数字表示第几个输入
第二个数字表示被输入的神经元在他所在层数的位置
"""
# 第一层
# 第一个神经元
w11_1 = np.random.rand()
b1_1 = np.random.rand()
# 第二个神经元
w12_1 = np.random.rand()
b2_1 = np.random.rand()
# 第二层
w11_2 = np.random.rand()
w21_2 = np.random.rand()
b1_2 = np.random.rand()# 前向传播 代价函数 y0 = 1/(1+e^(-(wx+b)))
def forward_propgation(xs):z1_1 = w11_1 * xs + b1_1a1_1 = sigmod(z1_1) # 第一层第一个神经元的代价函数值z2_1 = w12_1 * xs + b2_1a2_1 = sigmod(z2_1)  # 第一层第二个神经元的代价函数值z1_2 = w11_2 * a1_1 + w21_2 * a2_1 + b1_2a1_2 = sigmod(z1_2)   # 第二层第一个神经元的代价函数值return a1_2, z1_2, a2_1, z2_1, a1_1, z1_1a1_2, z1_2, a2_1, z2_1, a1_1, z1_1 = forward_propgation(xs)
# plt.plot(xs, a1_2)
# plt.show()# 随机梯度下降
for j in range(5000):for i in range(100):x = xs[i]y = ys[i]# 先来一次前向传播a1_2, z1_2, a2_1, z2_1, a1_1, z1_1 = forward_propgation(x)# 开始反向传播# 误差代价函数e"""z1_1 = w11_1 * xs + b1_1a1_1 = sigmod(z1_1) # 第一层第一个神经元的代价函数值z2_1 = w12_1 * xs + b2_1a2_1 = sigmod(z2_1)  # 第一层第二个神经元的代价函数值z1_2 = w11_2 * a1_1 + w21_2 * a2_1a1_2 = sigmod(z1_2)   # 第二层第一个神经元的代价函数值"""e = (y - a1_2) ** 2  # 误差e = (y - 最后一个神经元得出的值)^2deda1_2 = -2*(y - a1_2)  # 对a1_2 第二层的第一个神经元的函数求导da1_2dz1_2 = a1_2 * (1 - a1_2)  # da1_2对dz1_2求导数dz1_2dw11_2 = a1_1  # dz1_2对w11_2求导数dz1_2dw21_2 = a2_1  # dz1_2对dw21_2求导dedw11_2 = deda1_2 * da1_2dz1_2 * dz1_2dw11_2  # de对dw11_2求偏导dedw21_2 = deda1_2 * da1_2dz1_2 * dz1_2dw21_2  # de对dw21_2求偏导dz1_2db1_2 = 1  # z1_2对db1_2求偏导dedb1_2 = deda1_2 * da1_2dz1_2 * dz1_2db1_2  # de对db1_2求偏导dz1_2da1_1 = w11_2  # dz1_2对da1_1求偏导da1_1dz1_1 = a1_1 * (1 - a1_1)  # da1_1对dz1_1 求偏导dz1_1dw11_1 = x  # dz1_1对dw11_1求偏导dedw11_1 = deda1_2 * da1_2dz1_2 * dz1_2da1_1 * da1_1dz1_1 * dz1_1dw11_1  # e对w11_1求导dz1_1db1_1 = 1  # z1_1对b1_1求导dedb1_1 = deda1_2 * da1_2dz1_2 * dz1_2da1_1 * da1_1dz1_1 * dz1_1db1_1  # e对b1_1求导dz1_2da2_1 = w21_2  # z1_2 对a2_1 求导da2_1dz2_1 = a2_1 * (1 - a2_1)  # a2_1 对z2_1求导dz2_1dw12_1 = x  # z2_1对w12_1dedw12_1 = deda1_2 * da1_2dz1_2 * dz1_2da1_1 * da2_1dz2_1 * dz2_1dw12_1  # e对w12_1求导dz2_1db2_1 = 1  # z2_1 对 b2_1求导dedb2_1 = deda1_2 * da1_2dz1_2 * dz1_2da1_1 * da2_1dz2_1 * dz2_1db2_1  # e 对 b2_1求导alpha = 0.03w11_2 = w11_2 - alpha * dedw11_2  # 调整w11_2w21_2 = w21_2 - alpha * dedw21_2  # 调整21_2b1_2 = b1_2 - alpha * dedb1_2  # 调整b1_2w12_1 = w12_1 - alpha * dedw12_1  # 调整w12_1b2_1 = b2_1 - alpha * dedb2_1  # 调整 b2_1w11_1 = w11_1 - alpha * dedw11_1  # 调整 w11_1b1_1 = b1_1 - alpha * dedb1_1  # 调整b1_1if j % 100 == 0:plt.clf()  # 清空窗口plt.scatter(xs, ys)a1_2, z1_2, a2_1, z2_1, a1_1, z1_1 = forward_propgation(xs)plt.plot(xs, a1_2)plt.pause(0.01)  # 暂停0.01秒
http://www.lryc.cn/news/328519.html

相关文章:

  • win11 环境配置 之 Jmeter(JDK17版本)
  • Windows下载使用nc(netcat)命令
  • istio 设置 istio-proxy sidecar 的 resource 的 limit 和 request
  • flutter弹框
  • 2013年认证杯SPSSPRO杯数学建模B题(第一阶段)流行音乐发展简史全过程文档及程序
  • 代码随想录算法训练营第39天 | 62.不同路径, 63不同路径II
  • Redis 的慢日志
  • 第十四届蓝桥杯第十题:蜗牛分享
  • 不懂技术的老板,如何避免过度依赖核心技术人员
  • Vue系列-el挂载
  • python--os和os.path模块
  • 前端通用命名规范和Vue项目命名规范
  • NTP服务搭建
  • Linux离线安装mysql,node,forever
  • WPF中获取TreeView以及ListView获取其本身滚动条进行滚动
  • C语言: 指针讲解
  • C#使用Stopwatch类来实现计时功能
  • ubuntu18.04安装qt
  • ElasticSearch、java的四大内置函数式接口、Stream流、parallelStream背后的技术、Optional类
  • 深入MNN:开源深度学习框架的介绍、安装与编译指南
  • [LeetCode][400]第 N 位数字
  • clickhouse 查询group 分组最大值的一行数据。
  • Python装饰器与生成器:从原理到实践
  • python-函数引入模块面向对象编程创建类继承
  • Spring:面试八股
  • Flask Python:请求上下文和应用上下文
  • 哔哩哔哩直播姬有线投屏教程
  • 您现在可以在家训练 70b 语言模型
  • 算法题剪格子使我重视起了编程命名习惯
  • P19:注释