当前位置: 首页 > news >正文

人工智能之Tensorflow批标准化

批标准化(Batch Normalization,BN)是为了克服神经网络层数加深导致难以训练而诞生的。

随着神经网络的深度加深,训练会越来越困难,收敛速度会很慢,常常会导致梯度消失问题。梯度消失问题是在神经网络中,当前隐藏层的学习速率低于后面隐藏层的学习速率,即随着隐藏层数目的增加分类准确率反而下降,这种现象叫梯度消失问题。

传统机器学习中有一个ICS理论,这是一个经典假设:源域(Source Domain)和目标域(Target Domain)的数据分布是一致的,也就是说,训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障

协变量转移(Covariate Shift)是指当训练集的样本数据和目标集样本分布不一致时,训练得到的模型无法很好地泛化。它是分布不一致假设之下的分支,也就是之源域和目标域的条件概率是一致的,但是边缘概率不同。
对于神经网络的各层输出,在经过层内操作之后,各层输出分布就会与对应的输入信号分布不同,而且差异会随着网络深度增大而增大,但是每一层所指向的样本标记仍然是不变的。

解决思路:根据训练样本的比例对训练样本做一个矫正,因此,通过引入批标准化来规范某些层或者所有层的输入,从而固定每层输入信号的均值与方差

批标准化一般用在非线性映射(激活函数)之前,对于 x = W u + b x=Wu+b x=Wu+b做规范化,使结果(输出信号各个维度)的均值为0,方差为1。让每一层的输入有一个稳定的分布会有利于网络的训练。批标准化通过规范化让激活函数分布在线性区间,结果就是加大了梯度,让模型更加大胆地进行梯度下降。

批标准化具有以下几个优点:

  1. 加大探索的步长,从而加快收敛的速度。
  2. 更容易跳出局部最小值。
  3. 破坏原来的数据分布,在一定程度上缓解过拟合。

对每一次的Wx_plus_b 进行批标准化,这个步骤放在激活函数之前,示例片段如下:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
out_size=10
w=tf.Variable(tf.ones([out_size]))
u=tf.Variable(tf.ones([out_size]))
b=tf.Variable(tf.ones([out_size]))
Wx_plus_b=w*u+b
#计算Wx_plus_b的均值和方差,其中axex=[0]表示想要标准化的维度
fc_mean,fc_var=tf.nn.moments(Wx_plus_b,axes=[0])
scale=tf.Variable(tf.ones([out_size]))
shift=tf.Variable(tf.zeros([out_size]))
epsilon=0.001Wx_plus_b=tf.nn.batch_normalization(Wx_plus_b,fc_mean,fc_var,shift,scale,epsilon)
#下面两步等同用于上面一步
#Wx_plus_b=(Wx_plus_b-fc_mean)/tf.sqrt(fc_var+0.001)
#Wx_plus_b=Wx_plus_b?scale+shift

在这里插入图片描述

http://www.lryc.cn/news/324271.html

相关文章:

  • 自动化的免下车服务——银行、餐厅、快餐店、杂货店
  • Git常用指令总结
  • 水果软件FL Studio 21 for mac 21.2.3.3586破解版的最新版本2024介绍安装
  • 【保姆级】前端使用node.js基础教程
  • xilinx的高速接口构成原理和连接结构
  • git 上传文件夹至远端仓库的方法
  • 【鸿蒙系统】 ---OpenHarmony加快本地编译(二)
  • centos配置natapp 自动配置
  • sell脚本多行合成一行
  • 部署prometheus 监控k8s集群
  • 两个基本功不足导致的bug
  • 【算法每日一练]-图论(保姆级教程篇16 树的重心 树的直径)#树的直径 #会议 #医院设置
  • Qt播放音乐代码示例
  • 多线程应用中的性能优化:创建合适的线程数
  • 本地运行环境工具UPUPWANK(win)和Navicat数据库管理工具
  • LeetCode 每日一题 2024/3/18-2024/3/24
  • Unity 鼠标拖拽3D物体跟随移动的方法
  • 数据分析-Pandas分类数据的类别排序和顺序
  • 利用 Claude 3 on Amazon Bedrock 和 Streamlit 的“终极组合”,开发智能对话体验
  • Golang基础 Label标签与goto跳转
  • 二进制王国(蓝桥杯备赛)【sort/cmp的灵活应用】
  • 活用C语言之宏定义应用大全
  • 【源码】I.MX6ULL移植OpenCV
  • pytorch深度学习——dataset(附数据集下载)
  • springboot+vue考试管理系统
  • 自动驾驶建图--道路边缘生成方案探讨
  • 图片编辑器中实现文件上传的三种方式和二进制流及文件头校验文件类型
  • 深度学习,CRNN+CTC和Attention OCR你更青睐哪一种?
  • 飞桨AI应用@riscv OpenKylin
  • 在MongoDB建模1对N关系的基本方法