当前位置: 首页 > news >正文

pytorch小记(三十一):深入解析 PyTorch 权重初始化:`xavier_normal_` 与 `constant_`

pytorch小记(三十一):深入解析 PyTorch 权重初始化:`xavier_normal_` 与 `constant_`

    • 深入解析 PyTorch 权重初始化:`xavier_normal_` 与 `constant_`
    • 一、为什么要初始化?
    • 二、`xavier_normal_` (Glorot 正态初始化)
      • 1. 原理
      • 2. PyTorch 接口
      • 3. 示例
    • 三、`constant_` (常数初始化)
      • 1. 原理与作用
      • 2. PyTorch 接口
      • 3. 示例
    • 四、何时选择哪种初始化?
    • 五、小结


深入解析 PyTorch 权重初始化:xavier_normal_constant_

在深度学习模型中,权重初始化 对训练的稳定性、收敛速度以及最终性能都有重要影响。PyTorch 提供了丰富的初始化函数,其中 xavier_normal_(也称 Glorot 正态初始化)和 constant_(常数初始化)是最常见的两种方式。本文将分别介绍它们的原理、应用场景与示例代码,帮助你在项目中灵活选用。


一、为什么要初始化?

未经合理初始化的权重可能会导致:

  • 梯度消失或爆炸:过小的初始值让信号逐层衰减,过大的初始值容易在反向传播时不断膨胀。
  • 收敛缓慢:不合适的分布可能使优化器难以找到合适的下降方向。
  • 对称破缺不足:所有权重相同或为零会让网络各神经元学习到相同特征。

因此,好的初始化策略能平衡信号流动,让训练更高效、更稳定。


二、xavier_normal_ (Glorot 正态初始化)

1. 原理

Xavier 初始化由 Xavier Glorot & Yoshua Bengio 在 2010 年提出,目标是在前向传播和反向传播时维持激活和梯度的方差不变。对于形状为 (out_features, in_features) 的权重矩阵 W

Wij∼N(0,2in_features+out_features).W_{ij} \sim \mathcal{N}\Bigl(0,\; \frac{2}{\text{in\_features} + \text{out\_features}}\Bigr). WijN(0,in_features+out_features2).

这里的方差:

Var(Wij)=2fan_in+fan_out.\mathrm{Var}(W_{ij}) = \frac{2}{\text{fan\_in} + \text{fan\_out}}. Var(Wij)=fan_in+fan_out2.

  • fan_in: 神经元输入通道数。
  • fan_out: 神经元输出通道数。

这种初始化能让网络信号保持在合理范围,防止深层网络训练困难。

2. PyTorch 接口

import torch
from torch.nn import init# 对某个层的权重做 Xavier 正态初始化
init.xavier_normal_(layer.weight, gain=1.0)
  • gain (float):缩放因子。对于非线性激活函数,可以根据激活类型传入不同的增益。例如 nn.ReLU 推荐 gain=\sqrt{2}
  • 如果网络层带有偏置(bias),常使用 init.constant_(layer.bias, 0) 将其初始化为 0。

3. 示例

import torch.nn as nn
from torch.nn import initclass MyModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(256, 128)# Xavier 正态初始化init.xavier_normal_(self.fc.weight, gain=init.calculate_gain('relu'))init.constant_(self.fc.bias, 0.0)def forward(self, x):return nn.functional.relu(self.fc(x))

三、constant_ (常数初始化)

1. 原理与作用

constant_ 将张量的所有元素都赋值为用户指定的常数 val

W_{ij} = val, \quad \forall i,j.
  • 常用于偏置(bias)初始化为 0 或小正值(如 0.01)。
  • 在 BatchNorm、LayerNorm 等操作中,将缩放参数 weight 初始化为 1,将偏移参数 bias 初始化为 0,也是常见做法。

2. PyTorch 接口

# 对所有元素赋常数
init.constant_(tensor, val)

3. 示例

import torch.nn as nn
from torch.nn import init# 假设一个卷积层
conv = nn.Conv2d(3, 16, kernel_size=3)
# 将卷积核权重初始化为常数 0.001
init.constant_(conv.weight, 0.001)
# 偏置初始化为 0
init.constant_(conv.bias, 0.0)

四、何时选择哪种初始化?

场景建议初始化方式
普通全连接/卷积层xavier_normal_(或 xavier_uniform_
ReLU/ReLU-like 激活gain=√2xavier_normal_
偏置(Bias)constant_(..., 0)
归一化层的缩放参数constant_(..., 1)
自定义小尺度初始值constant_(..., 小值)

五、小结

  • xavier_normal_:保持前向/反向方差平衡,适合大多数网络层,传入 gain 可适配不同激活函数。
  • constant_:最简单的全常数赋值,常用于偏置、归一化层参数的初始化。

合理的初始化能大幅提升训练稳定性和收敛速度,希望这篇文章能帮助你在项目中灵活选用并写出高效、鲁棒的模型代码。欢迎在评论区交流!

http://www.lryc.cn/news/591436.html

相关文章:

  • cuda编程笔记(8)--线程束warp
  • imx6ull-系统移植篇9——bootz启动 Linux 内核
  • Java全栈工程师面试实录:从电商支付到AI大模型架构的深度技术挑战
  • 软件项目管理学习笔记
  • S7-1200 模拟量模块全解析:从接线到量程计算
  • FreeRTOS学习笔记——常用函数说明
  • MQTT之CONNECT报文和CONNACK报文
  • Qwen3-8B Dify RAG环境搭建
  • @fullcalendar/vue 日历组件
  • SpringCloud面试笔记
  • 【每日刷题】跳跃游戏
  • Apache DolphinScheduler介绍与部署
  • 分布式光伏发电系统中的“四可”指的是什么?
  • 解读PLM系统软件在制造企业研发管理中的应用
  • 18650锂电池点焊机:新能源制造的精密纽带
  • AR智能巡检:制造业零缺陷安装的“数字监工”
  • Git仓库核心概念与工作流程详解:从入门到精通
  • 【java面试day6】redis缓存-数据淘汰策略
  • 二刷 黑马点评 秒杀优化
  • 全面升级!WizTelemetry 可观测平台 2.0 深度解析:打造云原生时代的智能可观测平台
  • Netty-基础知识
  • 【前端如何利用 localStorage 存储 Token 及跨域问题解决方案】
  • 7.17 Java基础 | 集合框架(下)
  • 【unitrix】 6.5 基础整数类型特征(base_int.rs)
  • 对比分析:给数据找个 “参照物”,让孤立数字变 “决策依据”
  • 数据呈现进阶:漏斗图与雷达图的实战指南
  • SQLite的可视化界面软件的安装
  • H3CNE 综合实验二解析与实施指南
  • 医院各类不良事件上报,PHP+vscode+vue2+element+laravel8+mysql5.7不良事件管理系统源代码,成品源码,不良事件管理系统
  • ASP .NET Core 8实现实时Web功能