当前位置: 首页 > news >正文

详解三种常用标准化:Batch Norm、Layer Norm和RMSNorm

在深度学习中,标准化技术是提升模型训练速度、稳定性和性能的重要手段。本文将详细介绍三种常用的标准化方法:Batch Normalization(批量标准化)、Layer Normalization(层标准化)和 RMS Normalization(RMS标准化),并对其原理、实现和应用场景进行深入分析。

一、Batch Normalization

1.1 Batch Normalization的原理

Batch Normalization(BN)通过在每个小批量数据的每个神经元输出上进行标准化来减少内部协变量偏移。具体步骤如下:

  1. 计算小批量的均值和方差
    对于每个神经元的输出,计算该神经元在当前小批量中的均值和方差。

    [
    \muB = \frac{1}{m} \sum{i=1}^m x_i
    ]

    [
    \sigmaB^2 = \frac{1}{m} \sum{i=1}^m (x_i - \mu_B)^2
    ]

  2. 标准化
    使用计算得到的均值和方差对数据进行标准化。

    [
    \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
    ]

  3. 缩放和平移
    引入可学习的参数进行缩放和平移。

    [
    y_i = \gamma \hat{x}_i + \beta
    ]

    其中,(\gamma)和(\beta)是可学习的参数。

1.2 Batch Normalization的实现

在PyTorch中,Batch Normalization可以通过 torch.nn.BatchNorm2d实现。

import torch
import torch.nn as nn# 创建BatchNorm层
batch_norm = nn.BatchNorm2d(num_features=64)# 输入数据
x = torch.randn(16, 64, 32, 32)  # (batch_size, num_features, height, width)# 应用BatchNorm
output = batch_norm(x)
​

1.3 Batch Normalization的优缺点

优点
  • 加速训练:通过减少内部协变量偏移,加快了模型收敛速度。
  • 稳定性提高:减小了梯度消失和爆炸的风险。
  • 正则化效果:由于引入了噪声,有一定的正则化效果。
缺点
  • 依赖小批量大小:小批量大小过小时,均值和方差估计不准确。
  • 训练和推理不一致:训练时使用小批量的均值和方差,推理时使用整个数据集的均值和方差。

二、Layer Normalization

2.1 Layer Normalization的原理

Layer Normalization(LN)通过在每一层的神经元输出上进行标准化,独立于小批量的大小。具体步骤如下:

  1. 计算每一层的均值和方差
    对于每一层的神经元输出,计算其均值和方差。

    [
    \muL = \frac{1}{H} \sum{i=1}^H x_i
    ]

    [
    \sigmaL^2 = \frac{1}{H} \sum{i=1}^H (x_i - \mu_L)^2
    ]

  2. 标准化
    使用计算得到的均值和方差对数据进行标准化。

    [
    \hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}}
    ]

  3. 缩放和平移
    引入可学习的参数进行缩放和平移。

    [
    y_i = \gamma \hat{x}_i + \beta
    ]

    其中,(\gamma)和(\beta)是可学习的参数。

2.2 Layer Normalization的实现

在PyTorch中,Layer Normalization可以通过 torch.nn.LayerNorm实现。

import torch
import torch.nn as nn# 创建LayerNorm层
layer_norm = nn.LayerNorm(normalized_shape=64)# 输入数据
x = torch.randn(16, 64)# 应用LayerNorm
output = layer_norm(x)
​

2.3 Layer Normalization的优缺点

优点
  • 与小批量大小无关:适用于小批量训练和在线学习。
  • 更适合RNN:在循环神经网络中表现更好,因为它独立于时间步长。
缺点
  • 计算开销较大:每一层都需要计算均值和方差,计算开销较大。
  • 对CNN效果不明显:在卷积神经网络中效果不如BN明显。

三、RMS Normalization

3.1 RMS Normalization的原理

RMS Normalization(RMSNorm)通过标准化每一层的RMS值,而不是均值和方差。具体步骤如下:

  1. 计算RMS值
    对于每一层的神经元输出,计算其RMS值。

    [
    \text{RMS}(x) = \sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2}
    ]

  2. 标准化
    使用计算得到的RMS值对数据进行标准化。

    [
    \hat{x}_i = \frac{x_i}{\text{RMS}(x) + \epsilon}
    ]

  3. 缩放和平移
    引入可学习的参数进行缩放和平移。

    [
    y_i = \gamma \hat{x}_i + \beta
    ]

    其中,(\gamma)和(\beta)是可学习的参数。

3.2 RMS Normalization的实现

在PyTorch中,RMS Normalization没有直接的内置实现,可以通过自定义层来实现。

import torch
import torch.nn as nnclass RMSNorm(nn.Module):def __init__(self, normalized_shape, epsilon=1e-8):super(RMSNorm, self).__init__()self.epsilon = epsilonself.gamma = nn.Parameter(torch.ones(normalized_shape))self.beta = nn.Parameter(torch.zeros(normalized_shape))def forward(self, x):rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.epsilon)x = x / rmsreturn self.gamma * x + self.beta# 创建RMSNorm层
rms_norm = RMSNorm(normalized_shape=64)# 输入数据
x = torch.randn(16, 64)# 应用RMSNorm
output = rms_norm(x)
​

3.3 RMS Normalization的优缺点

优点
  • 计算效率高:计算RMS值相对简单,计算开销较小。
  • 稳定性好:在某些任务中可以表现出更好的稳定性。
缺点
  • 应用较少:相较于BN和LN,应用场景和研究较少。
  • 效果不确定:在某些情况下效果可能不如BN和LN显著。

四、比较与应用场景

4.1 比较

特性Batch NormLayer NormRMSNorm
标准化维度小批量内各特征维度每层各特征维度每层各特征维度的RMS
计算开销中等较大较小
对小批量大小依赖依赖不依赖不依赖
应用场景CNN、MLPRNN、Transformer各类神经网络
正则化效果有一定正则化效果无显著正则化效果无显著正则化效果

4.2 应用场景

  • Batch Normalization

    • 适用于卷积神经网络(CNN)和多层感知机(MLP)。
    • 对小批量大小有依赖,不适合小批量和在线学习。
  • Layer Normalization

    • 适用于循环神经网络(RNN)和Transformer。
    • 独立于小批量大小,适合小批量和在线学习。
  • RMS Normalization

    • 适用于各种神经网络,尤其在计算效率和稳定性有要求的任务中。
    • 相对较新,应用场景和研究较少,但在某些任务中可能表现优异。

五、总结

Batch Normalization

、Layer Normalization和RMS Normalization是深度学习中常用的标准化技术。它们各有优缺点,适用于不同的应用场景。通过理解其原理和实现,您可以根据具体需求选择合适的标准化方法,提升模型的训练速度和性能。

http://www.lryc.cn/news/525844.html

相关文章:

  • linux+docker+nacos+mysql部署
  • 如何实现gitlab和jira连通
  • 利用ML.NET精准提取人名
  • Node.js的解释
  • Macos下交叉编译安卓的paq8px压缩算法
  • 如何在data.table中处理缺失值
  • 从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑
  • SQL-leetcode—1164. 指定日期的产品价格
  • [Day 15]54.螺旋矩阵(简单易懂 有画图)
  • HTTP 配置与应用(不同网段)
  • Quartus:开发使用及 Tips 总结
  • VSCode下EIDE插件开发STM32
  • Golang并发机制及CSP并发模型
  • HTML 文本格式化详解
  • 我谈《概率论与数理统计》的知识体系
  • 五、华为 RSTP
  • 基于Java Web的网上房屋租售网站
  • Pyside6(PyQT5)中的QTableView与QSqlQueryModel、QSqlTableModel的联合使用
  • git常用命令学习
  • 【优选算法】7----三数之和
  • 分子动力学模拟里的术语:leap-frog蛙跳算法和‌Velocity-Verlet算法
  • 2025年数学建模美赛:A题分析(1)Testing Time: The Constant Wear On Stairs
  • 利用 SoybeanAdmin 实现前后端分离的企业级管理系统
  • 996引擎 - 前期准备-配置开发环境
  • Tensor 基本操作4 理解 indexing,加减乘除和 broadcasting 运算 | PyTorch 深度学习实战
  • 【Uniapp-Vue3】request各种不同类型的参数详解
  • 【Prometheus】Prometheus如何监控Haproxy
  • SSM开发(一)JAVA,javaEE,spring,springmvc,springboot,SSM,SSH等几个概念区别
  • HTML5 常用事件详解
  • TCP全连接队列