当前位置: 首页 > news >正文

深度学习 | 表示学习 | 卷积神经网络 | Batch Normalization 在 CNN 中的示例 | 20

如是我闻: 让我们来用一个具体的例子说明 Batch Normalization 在 CNN 里的计算过程,特别是如何对每个通道(channel)进行归一化

在这里插入图片描述


1. 假设我们有一个 CNN 层的输出

假设某个 CNN 层的输出是一个 4D 张量,形状为:
X = ( m , C , H , W ) X = (m, C, H, W) X=(m,C,H,W)
其中:

  • m = 2 m = 2 m=2(batch 大小 = 2,即有 2 张图片)
  • C = 3 C = 3 C=3(通道数 = 3,比如 RGB 三个通道)
  • H = 2 , W = 2 H = 2, W = 2 H=2,W=2(特征图大小是 2 × 2 2 \times 2 2×2

现在,我们假设输入数据如下(仅展示一个通道的数据):

X = [ 样本 1: [ [ 1 , 2 ] [ 3 , 4 ] [ 5 , 6 ] [ 7 , 8 ] ] 样本 2: [ [ 2 , 3 ] [ 4 , 5 ] [ 6 , 7 ] [ 8 , 9 ] ] ] X = \begin{bmatrix} \text{样本 1:} & \begin{bmatrix} [1, 2] & [3, 4] \\ [5, 6] & [7, 8] \end{bmatrix} \\ \text{样本 2:} & \begin{bmatrix} [2, 3] & [4, 5] \\ [6, 7] & [8, 9] \end{bmatrix} \end{bmatrix} X= 样本 1:样本 2:[[1,2][5,6][3,4][7,8]][[2,3][6,7][4,5][8,9]]

这个数据表示的是 一个 batch(2 张图片),每张图片有一个 2 × 2 2 \times 2 2×2 特征图


2. 计算均值和方差

(1) 计算均值

我们需要对所有样本的这个通道进行归一化,所以我们计算该通道在所有样本上的均值:

μ B = 1 m × H × W ∑ i = 1 m ∑ j = 1 H ∑ k = 1 W x i , j , k \mu_B = \frac{1}{m \times H \times W} \sum_{i=1}^{m} \sum_{j=1}^{H} \sum_{k=1}^{W} x_{i, j, k} μB=m×H×W1i=1mj=1Hk=1Wxi,j,k

代入数据:
μ B = 1 2 × 2 × 2 ( 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 ) \mu_B = \frac{1}{2 \times 2 \times 2} (1+2+3+4+5+6+7+8 + 2+3+4+5+6+7+8+9) μB=2×2×21(1+2+3+4+5+6+7+8+2+3+4+5+6+7+8+9)
= 76 8 = 9.5 = \frac{76}{8} = 9.5 =876=9.5

(2) 计算方差

方差的计算公式:
σ B 2 = 1 m × H × W ∑ i = 1 m ∑ j = 1 H ∑ k = 1 W ( x i , j , k − μ B ) 2 \sigma_B^2 = \frac{1}{m \times H \times W} \sum_{i=1}^{m} \sum_{j=1}^{H} \sum_{k=1}^{W} (x_{i, j, k} - \mu_B)^2 σB2=m×H×W1i=1mj=1Hk=1W(xi,j,kμB)2

代入计算:
σ B 2 = 1 8 ( ( 1 − 4.75 ) 2 + ( 2 − 4.75 ) 2 + ( 3 − 4.75 ) 2 + ( 4 − 4.75 ) 2 + ⋯ + ( 9 − 4.75 ) 2 ) \sigma_B^2 = \frac{1}{8} \left( (1-4.75)^2 + (2-4.75)^2 + (3-4.75)^2 + (4-4.75)^2 + \dots + (9-4.75)^2 \right) σB2=81((14.75)2+(24.75)2+(34.75)2+(44.75)2++(94.75)2)

= 1 8 ( 14.06 + 7.56 + 3.06 + 0.56 + 0.56 + 3.06 + 7.56 + 14.06 ) = \frac{1}{8} \left( 14.06 + 7.56 + 3.06 + 0.56 + 0.56 + 3.06 + 7.56 + 14.06 \right) =81(14.06+7.56+3.06+0.56+0.56+3.06+7.56+14.06)

= 50.44 8 = 6.305 = \frac{50.44}{8} = 6.305 =850.44=6.305


3. 归一化数据

标准化公式:
x ^ i , j , k = x i , j , k − μ B σ B 2 + ϵ \hat{x}_{i,j,k} = \frac{x_{i,j,k} - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i,j,k=σB2+ϵ xi,j,kμB

假设 ϵ = 1 0 − 5 \epsilon = 10^{-5} ϵ=105很小,可以忽略不计,那么:
x ^ i , j , k = x i , j , k − 4.75 6.305 \hat{x}_{i,j,k} = \frac{x_{i,j,k} - 4.75}{\sqrt{6.305}} x^i,j,k=6.305 xi,j,k4.75

计算部分归一化的值(只展示部分):
x ^ 1 , 1 , 1 = 1 − 4.75 6.305 ≈ − 3.75 2.51 ≈ − 1.49 \hat{x}_{1,1,1} = \frac{1 - 4.75}{\sqrt{6.305}} \approx \frac{-3.75}{2.51} \approx -1.49 x^1,1,1=6.305 14.752.513.751.49
x ^ 1 , 1 , 2 = 2 − 4.75 2.51 ≈ − 1.10 \hat{x}_{1,1,2} = \frac{2 - 4.75}{2.51} \approx -1.10 x^1,1,2=2.5124.751.10
x ^ 2 , 2 , 2 = 9 − 4.75 2.51 ≈ 1.70 \hat{x}_{2,2,2} = \frac{9 - 4.75}{2.51} \approx 1.70 x^2,2,2=2.5194.751.70

经过这个过程,所有特征都会变成均值 0,方差 1


4. 通过可学习参数进行缩放和平移

为了让网络有更强的表达能力,BN 引入了两个可学习参数:
y i , j , k = γ x ^ i , j , k + β y_{i,j,k} = \gamma \hat{x}_{i,j,k} + \beta yi,j,k=γx^i,j,k+β

  • γ \gamma γ 控制缩放(scale)。
  • β \beta β 控制偏移(shift)。

如果 γ = 2 , β = 0.5 \gamma = 2, \beta = 0.5 γ=2,β=0.5,那么:
y 1 , 1 , 1 = 2 × ( − 1.49 ) + 0.5 = − 2.48 y_{1,1,1} = 2 \times (-1.49) + 0.5 = -2.48 y1,1,1=2×(1.49)+0.5=2.48
y 1 , 1 , 2 = 2 × ( − 1.10 ) + 0.5 = − 1.70 y_{1,1,2} = 2 \times (-1.10) + 0.5 = -1.70 y1,1,2=2×(1.10)+0.5=1.70
y 2 , 2 , 2 = 2 × ( 1.70 ) + 0.5 = 3.90 y_{2,2,2} = 2 \times (1.70) + 0.5 = 3.90 y2,2,2=2×(1.70)+0.5=3.90


5. 结果解释

(1) 归一化后,所有数据均值接近 0,方差接近 1

  • 这样可以稳定训练过程,防止梯度消失或梯度爆炸。

(2) 通过 γ \gamma γ β \beta β 让网络恢复部分信息

  • 这样可以确保 BN 不会限制网络的表达能力,同时还能优化训练。

6. 总的来说

  1. Batch Normalization 在 CNN 里是对每个通道单独归一化,而不是整个输入张量归一化
  2. 计算过程
    • 计算当前 batch 每个通道的均值和方差
    • 对该通道的所有数据进行归一化,使其均值为 0,方差为 1
    • 通过可学习参数 γ \gamma γ β \beta β 进行缩放和平移,使得网络仍然能够学习适应的特征分布。
  3. 最终作用
    • 减少 Internal Covariate Shift(内部协变量偏移)
    • 加速收敛,提高稳定性
    • 降低对超参数(如学习率、初始化)的依赖

以上

http://www.lryc.cn/news/531962.html

相关文章:

  • 最短木板长度
  • 团体程序设计天梯赛-练习集——L1-034 点赞
  • 利用腾讯云cloud studio云端免费部署deepseek-R1
  • LabVIEW的智能电源远程监控系统开发
  • Docker深度解析:安装各大环境
  • 牛客 - 链表相加(二)
  • GPU 硬件原理架构(一)
  • C/C++编译器
  • Immutable设计 SimpleDateFormat DateTimeFormatter
  • 最新EFK(Elasticsearch+FileBeat+Kibana)日志收集
  • Vue 3 30天精进之旅:Day 15 - 插件和指令
  • 【实战篇】Android安卓本地离线实现视频检测人脸
  • 【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-Chapter3-语言基础
  • (dpdk f-stack)-堆栈溢出-野指针-内存泄露(问题定位)
  • HTML5 教程之标签(3)
  • 【蓝桥】动态规划-简单-破损的楼梯
  • 如何自定义软件安装路径及Scoop包管理器使用全攻略
  • 107,【7】buuctf web [CISCN2019 华北赛区 Day2 Web1]Hack World
  • STM32 ADC单通道配置
  • 【技海登峰】Kafka漫谈系列(二)Kafka高可用副本的数据同步与选主机制
  • Spring的三级缓存如何解决循环依赖问题
  • Ext文件系统
  • 回溯算法---数独问题
  • 蓝桥杯python基础算法(2-1)——排序
  • 【课程笔记】信息隐藏与数字水印
  • Page Assist实现deepseek离线部署的在线搜索功能
  • composeUI中Box 和 Surface的区别
  • 【LeetCode】5. 贪心算法:买卖股票时机
  • MySQL表的CURD
  • Java 如何覆盖第三方 jar 包中的类