当前位置: 首页 > news >正文

Muon:神经网络隐藏层的革命性优化器

Muon是一种针对神经网络隐藏层参数的新型优化器,已在NanoGPT和CIFAR-10训练速度上创下新记录。本博客将深入解析其设计原理、性能优势及实现细节。


🚀 突破性成果

Muon在多个关键任务中实现显著加速:

  1. CIFAR-10训练:准确率94%的耗时从3.3 A100秒降至2.6 A100秒
  2. NanoGPT训练:验证损失3.28的耗时降低1.35倍
  3. 大模型训练
    • 1.5B参数模型达GPT-2 XL性能仅需10小时(8×H100)
    • 比AdamW节省25%训练时间

图:Muon在样本效率和墙钟时间上均优于主流优化器


⚙️ 核心技术:牛顿-舒尔茨正交化

Muon的核心创新在于对SGD动量更新进行正交化处理:

算法流程

def newtonschulz5(G, steps=5, eps=1e-7):a, b, c = (3.4445, -4.7750, 2.0315)  # 调优系数X = G.bfloat16()X /= (X.norm() + eps)if G.size(0) > G.size(1): X = X.Tfor _ in range(steps):A = X @ X.TB = b*A + c*A@AX = a*X + B@Xreturn X if G.size(0) <= G.size(1) else X.T

数学原理
G = U S V ⊤ (SVD分解) G ′ = U ( a S + b S 3 + c S 5 ) V ⊤ \begin{align*} G &= USV^\top \quad \text{(SVD分解)} \\ G' &= U(aS + bS^3 + cS^5)V^\top \end{align*} GG=USV(SVD分解)=U(aS+bS3+cS5)V
通过迭代使更新矩阵趋近正交矩阵 U V ⊤ UV^\top UV


🧪 关键设计决策

  1. 为何选择正交化?

    • 实证发现:Adam/SGD的更新矩阵条件数极高(接近低秩)
    • 正交化可增强小幅度更新方向的重要性
  2. 为何不用SVD?

    • SVD计算效率低(比NS迭代慢10倍以上)
    • NS迭代可在bfloat16下稳定运行
  3. 系数调优 ( 3.4445 , − 4.7750 , 2.0315 ) (3.4445, -4.7750, 2.0315) (3.4445,4.7750,2.0315)

    • 最大化收敛速度:增大 a a a加速小奇异值收敛
    • 控制误差范围: lim ⁡ N → ∞ ϕ N ( x ) ∈ [ 0.7 , 1.3 ] \lim_{N\to\infty}\phi^N(x)\in[0.7,1.3] limNϕN(x)[0.7,1.3]

⏱️ 极致效率:仅1%额外开销

计算复杂度分析
FLOP开销 = T × m B \text{FLOP开销} = \frac{T \times m}{B} FLOP开销=BT×m

  • T = 5 T=5 T=5(NS迭代步数)
  • m m m:模型维度
  • B B B:批处理token数
训练场景模型维度Batch Size开销
NanoGPT (768M)768524,2880.7%
LLaMA 3 (405B)16,38416,000,0000.5%

🔄 与经典优化器的关系

  1. Shampoo

    • Muon ≈ 动量版"瞬时Shampoo"(无累加器)
    • 避免Shampoo的高内存消耗问题
  2. 正交-SGDM

    • Muon将动量置于正交化之前
    • 用NS迭代替代计算昂贵的SVD

🛠️ 实际使用指南

  1. 适用范围

    • 仅处理2D参数(全连接层权重)
    • 卷积层需展平后处理(conv_weight.view(C_out, -1)
  2. 混合优化策略

    # PyTorch示例
    optimizer = torch.optim.AdamW([{'params': model.embeddings},    # 输入层{'params': model.hidden_layers, 'optimizer': Muon()},  # 隐藏层{'params': model.head}           # 输出层
    ])
    
  3. 最佳实践

    • 输入/输出层使用AdamW
    • 采用Nesterov动量(比标准动量提升3-5%)
    • Q/K/V参数分开优化(比联合优化效果更好)

📜 研究范式革命:竞争性任务验证

Muon通过标准化基准测试避免常见研究陷阱:

  1. NanoGPT速度竞赛作为验证场:
    • 基线=当前最佳记录(已充分调优)
    • 新方法必须实际部署验证(非纸面对比)
  2. 自我修正机制
    • 若AdamW更优,可轻易替换Muon刷新记录
    • Muon持续保持记录12次(7位研究者验证)

“你无需信任我,只需信任想破记录的研究者们” —— Keller Jordan


❓ 待解问题

  1. 扩展性:能否支持>20B参数的万亿token训练?
  2. 分布式:如何在GPU集群高效部署NS迭代?
  3. 任务泛化:是否适用于微调/强化学习?

Muon的核心优势在于其独特的正交化设计,这种设计解决了传统优化器在神经网络训练中的关键痛点。以下从优势设计原理两个维度解析:


🔥 Muon的五大核心优势

  1. 解决梯度方向失衡问题

    • 问题:传统优化器(如AdamW)的更新矩阵常呈病态条件数(奇异值差异达10³倍),导致少数方向主导更新
    • 方案:正交化强制所有更新方向具有相同权重,避免小奇异值方向被淹没
    • 效果:提升模型对低频特征的捕捉能力(尤其关键于语言建模)
  2. 逼近理论最优更新

    • 数学证明:正交化更新等价于SVD分解后的 U V ⊤ UV^\top UV
      Muon ( G ) = arg ⁡ min ⁡ O ∥ O − G ∥ F s.t.  O ⊤ O = I \text{Muon}(G) = \underset{O}{\arg\min} \|O - G\|_F \quad \text{s.t.} \ O^\top O = I Muon(G)=OargminOGFs.t. OO=I
    • 物理意义:在Frobenius范数下找到最接近原始梯度的正交矩阵
  3. 计算效率革命

    方法计算复杂度硬件友好性
    SVD O ( n m 2 ) O(nm^2) O(nm2)差(需高精度)
    牛顿-舒尔茨迭代 O ( n m 2 ) O(nm^2) O(nm2)极佳(支持bfloat16)
    • 5步迭代即可达到 ε < 0.3 \varepsilon<0.3 ε<0.3的实用精度(传统方法需>20步)
  4. 内存优化

    • 零额外参数缓存:相比Shampoo减少 O ( m 2 ) O(m^2) O(m2)级内存消耗
    • 例如:4096维参数层,Shampoo需67MB额外内存,Muon仅需0.1MB
  5. 训练加速实证

    Muon替换
    AdamW
    35%训练速度提升
    10小时训练1.5B模型
    达GPT-2 XL性能

🧠 正交化分解的设计逻辑

Muon选择牛顿-舒尔茨迭代实现正交化,源于三层关键设计考量:

1. 为何必须正交化?
  • 神经网络的几何结构特性
    • 隐藏层参数本质是流形映射(Manifold Learning)
    • 正交更新保持特征空间的等距变换(Isometry),避免训练过程中空间扭曲
  • 理论支持
    ∇ ortho L = arg ⁡ min ⁡ ∥ δ W ∥ spec ≤ η L ( W + δ W ) \nabla_{\text{ortho}} \mathcal{L} = \underset{\| \delta W \|_{\text{spec}} \leq \eta}{\arg \min} \mathcal{L}(W + \delta W) orthoL=δWspecηargminL(W+δW)
    证明正交更新是谱范数约束下的最优扰动(Bernstein & Newhouse, 2024)
2. 为何选择牛顿-舒尔茨而非SVD?
维度SVD牛顿-舒尔茨迭代
数值稳定性需要float32bfloat16即可
并行性GPU利用率低95%+ Tensor Core占用
迭代收敛不可控5步收敛
  • 硬件适配:NS迭代的矩阵连乘形式完美匹配GPU的SIMD架构
3. 系数 ( 3.4445 , − 4.7750 , 2.0315 ) (3.4445, -4.7750, 2.0315) (3.4445,4.7750,2.0315)的数学意义
  • 优化目标:最大化 φ ( x ) = a x + b x 3 + c x 5 \varphi(x)=ax+bx^3+cx^5 φ(x)=ax+bx3+cx5 [ 0 , 1 ] [0,1] [0,1]的收敛速度
  • 调优原理
    max ⁡ a s.t. lim ⁡ N → ∞ φ N ( x ) ∈ [ 0.7 , 1.3 ] \max a \quad \text{s.t.} \quad \lim_{N→∞} \varphi^N(x) \in [0.7,1.3] maxas.t.NlimφN(x)[0.7,1.3]
    • a = 3.4445 a=3.4445 a=3.44453倍于基线值(1.15),加速小奇异值收敛
    • b b b值:抑制中段奇异值的过冲现象
  • 效果验证
    # 迭代5次后奇异值分布
    baseline = [0.12, 0.38, 0.91]  # (2,-1.5,0.5)
    tuned    = [0.89, 0.93, 0.97]  # Muon系数
    

🌟 设计哲学:面向硬件的算法革新

Muon的分解策略体现了计算-理论协同设计的新范式:

  1. 从问题出发:识别梯度方向失衡是训练瓶颈
  2. 理论映射:将优化问题转化为矩阵正交逼近
  3. 硬件反推设计
    • 利用GPU的Tensor Core特性:选择矩阵连乘而非分解
    • 拥抱低精度计算:设计数值稳定的迭代格式
  4. 工程验证:通过NanoGPT速度竞赛实现算法有效性验证

“Muon不是发现了新数学,而是用硬件语言重构了优化理论” — Keller Jordan

这种设计使得Muon在维持理论严谨性的同时,成为首个能在实际训练任务中显著超越AdamW的优化器。


正交化能强制所有更新方向具有相同权重的本质在于奇异值的归一化,这直接改变了梯度更新的几何结构。以下是分层解析:


1️⃣ 数学本质:奇异值的等权重置

设原始梯度矩阵 G ∈ R m × n G \in \mathbb{R}^{m \times n} GRm×n 的SVD分解为:
G = U Σ V ⊤ , Σ = diag ( σ 1 , σ 2 , … , σ r ) G = U \Sigma V^\top, \quad \Sigma = \text{diag}(\sigma_1, \sigma_2, \dots, \sigma_r) G=UΣV,Σ=diag(σ1,σ2,,σr)
其中 σ 1 ≥ σ 2 ≥ ⋯ ≥ σ r > 0 \sigma_1 \geq \sigma_2 \geq \dots \geq \sigma_r > 0 σ1σ2σr>0 为奇异值。

  • 正交化操作
    Ortho ( G ) = U V ⊤ = U ⋅ I ⋅ V ⊤ \text{Ortho}(G) = UV^\top = U \cdot I \cdot V^\top Ortho(G)=UV=UIV
    实质是将奇异值矩阵 Σ \Sigma Σ 替换为单位矩阵 I I I
    ( σ 1 ⋱ σ r ) → 正交化 ( 1 ⋱ 1 ) \begin{pmatrix} \sigma_1 & & \\ & \ddots & \\ & & \sigma_r \end{pmatrix} \xrightarrow{\text{正交化}} \begin{pmatrix} 1 & & \\ & \ddots & \\ & & 1 \end{pmatrix} σ1σr 正交化 11

  • 几何意义
    原始梯度空间中,不同方向的更新幅度由 σ i \sigma_i σi 决定(最大方向 σ 1 \sigma_1 σ1 可能是最小方向 σ r \sigma_r σr 10 3 10^3 103 倍)。
    正交化后所有奇异值被强制设为1,即所有更新方向获得完全相同的幅度权重。


2️⃣ 物理意义:消除梯度主导方向

▶ 原始梯度的问题

假设某全连接层梯度 G G G 的奇异值分布:
σ 1 = 100 , σ 2 = 10 , σ 3 = 0.1 \sigma_1=100, \ \sigma_2=10, \ \sigma_3=0.1 σ1=100, σ2=10, σ3=0.1

  • 方向1的更新强度是方向3的 1000倍
  • 方向3(可能对应重要低频特征)的更新被淹没
▶ 正交化后的效果

Ortho ( G ) = U ( 1 0 0 0 1 0 0 0 1 ) V ⊤ \text{Ortho}(G) = U \begin{pmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{pmatrix} V^\top Ortho(G)=U 100010001 V

  • 三个方向更新强度均为 1.0
  • 方向3的权重从 0.1 0.1 0.1 1.0 1.0 1.0相对增强10倍
  • 方向1的权重从 100 100 100 1.0 1.0 1.0相对抑制99%

3️⃣ 几何视角:球面约束空间

正交化等价于将梯度更新投影到正交群流形(Orthogonal Group Manifold) 上:

高曲率
零曲率
原始梯度空间
非均匀更新
正交群流形
各向同性更新
  • 正交群 O ( n ) \mathbf{O}(n) O(n) 的性质
    ∀ v ⃗ i , v ⃗ j ∈ Ortho ( G ) : ⟨ v ⃗ i , v ⃗ j ⟩ = δ i j \forall \vec{v}_i, \vec{v}_j \in \text{Ortho}(G): \ \langle \vec{v}_i, \vec{v}_j \rangle = \delta_{ij} v i,v jOrtho(G): v i,v j=δij
    所有更新方向彼此正交且长度严格为1,构成标准正交基

  • 优化意义:在正交群流形上,参数更新等价于旋转而非缩放,避免了某些方向过度主导。


4️⃣ 与经典方法的对比

方法更新形式方向权重特性
SGD − η G -\eta G ηG ∝ σ i \propto \sigma_i σi
AdamW − η G v -\eta \frac{G}{\sqrt{v}} ηv G减弱大 σ i \sigma_i σi,但不等权
Muon − η U V ⊤ -\eta UV^\top ηUV σ i ≡ 1 \sigma_i \equiv 1 σi1

AdamW通过自适应学习率调整奇异值的量级,但不改变其相对比例
Muon直接消除奇异值的差异,实现严格等权。


5️⃣ 实际效果验证

在Transformer训练中,层权重矩阵的梯度奇异值分布通常满足幂律分布:
σ k ∝ k − α ( α ≈ 1.5 ) \sigma_k \propto k^{-\alpha} \quad (\alpha \approx 1.5) σkkα(α1.5)

  • 正交化前:Top 5%方向占据80%更新能量
  • 正交化后
    • 低频方向(大 k k k)更新强度提升 8-12倍
    • 高频方向(小 k k k)更新强度降低 50-100倍
  • 结果:模型更有效地学习到长尾特征(如低频语义关联)

✨ 核心结论

正交化通过奇异值归一化,将梯度更新的谱分布从病态幂律分布转化为均匀分布:
gradient energy → ortho uniformly distributed \text{gradient energy} \xrightarrow{\text{ortho}} \text{uniformly distributed} gradient energyortho uniformly distributed
这本质是一种梯度白化(Gradient Whitening),使网络在更新时平等对待所有学习方向,从而显著提升对弱特征的捕获能力。


Muon中三个关键系数a=3.4445, b=-4.7750, c=2.0315的计算过程是一个收敛性-效率权衡的优化问题,其设计遵循严格的数学约束。以下是完整的推导逻辑:


🔧 系数设计的三大约束条件

  1. 收敛性约束
    lim ⁡ N → ∞ ϕ N ( x ) ∈ [ 0.7 , 1.3 ] ∀ x ∈ [ 0 , 1 ] \lim_{N\to\infty} \phi^N(x) \in [0.7, 1.3] \quad \forall x\in[0,1] NlimϕN(x)[0.7,1.3]x[0,1]
    (其中 ϕ ( x ) = a x + b x 3 + c x 5 \phi(x)=ax+bx^3+cx^5 ϕ(x)=ax+bx3+cx5

  2. 效率最大化
    max ⁡ a (控制小奇异值收敛速度) \max a \quad \text{(控制小奇异值收敛速度)} maxa(控制小奇异值收敛速度)

  3. 迭代稳定性
    ∣ ϕ ′ ( x ) ∣ < 1 在 [ 0 , 1 ] 上保证收敛 |\phi'(x)| < 1 \quad \text{在} [0,1] \text{上保证收敛} ϕ(x)<1[0,1]上保证收敛


📐 分步求解过程

步骤1:建立收敛性边界模型

定义误差函数:
E ( a , b , c ) = ∫ 0 1 ∣ lim ⁡ N → ∞ ϕ N ( x ) − 1 ∣ 2 d x E(a,b,c) = \int_0^1 \left| \lim_{N\to\infty}\phi^N(x) - 1 \right|^2 dx E(a,b,c)=01 NlimϕN(x)1 2dx

约束转化为:
0.7 ≤ lim ⁡ N → ∞ ϕ N ( x ) ≤ 1.3 0.7 \leq \lim_{N\to\infty}\phi^N(x) \leq 1.3 0.7NlimϕN(x)1.3

步骤2:分析多项式不动点

固定点满足 ϕ ( x ) = x \phi(x)=x ϕ(x)=x,解得:
x = 0 或 a + b x 2 + c x 4 = 1 x=0 \quad \text{或} \quad a + b x^2 + c x^4 = 1 x=0a+bx2+cx4=1

期望不动点 x = 1 x=1 x=1 稳定,要求:
ϕ ′ ( 1 ) = a + 3 b + 5 c < 1 \phi'(1)=a+3b+5c < 1 ϕ(1)=a+3b+5c<1

步骤3:梯度优化算法

采用投影梯度法迭代求解:

def optimize_coeffs():a, b, c = 2.0, -1.5, 0.5  # 初始基准值lr = 0.01for epoch in range(10000):# 前向传播计算收敛值x = np.linspace(0, 1, 1000)y = fixed_point_iteration(phi, x, N=100)  # 迭代100次模拟极限# 计算损失和梯度loss = np.mean(np.clip(y, 0.7, 1.3) - 1)**2grad_a = 2 * np.mean((y-1)*x * dphi_da(x))  # 链式求导... # b,c梯度类似# 梯度投影更新a += lr * grad_aa = np.clip(a, 2.5, 4.0)  # 约束a范围... # 类似处理b,c# 强制满足不动点约束if a + 3*b + 5*c >= 1:c = (1 - a - 3*b)/5 * 0.99  # 松弛因子return a, b, c

📊 关键优化技巧

  1. 小奇异值加速策略
    增大 a a a显著提升小 x x x收敛:
    ϕ ′ ( 0 ) = a ⇒ 迭代步长 ∝ a k \phi'(0) = a \quad \Rightarrow \quad \text{迭代步长} \propto a^k ϕ(0)=a迭代步长ak

    a a a达到0.9精度所需迭代步数
    2.08
    3.05
    3.44453
  2. 中段振荡抑制
    b b b值(-4.775)的设计:
    ∂ ϕ ∂ b = x 3 ⇒ b < 0 抑制 x ∈ [ 0.3 , 0.7 ] 的过冲 \frac{\partial \phi}{\partial b} = x^3 \quad \Rightarrow \quad b<0 \text{ 抑制} x\in[0.3,0.7]\text{的过冲} bϕ=x3b<0 抑制x[0.3,0.7]的过冲

    # b的梯度更新规则
    if np.max(y[300:700]) > 1.2:grad_b -= penalty * 10  # 对中段过冲强惩罚
    
  3. 高次项平衡设计
    系数 c c c的互补作用:
    c x 5 补偿  ∣ b x 3 ∣ 在 x > 0.8 的欠收敛 c x^5 \text{ 补偿 } |b x^3| \text{ 在} x>0.8\text{ 的欠收敛} cx5 补偿 bx3 x>0.8 的欠收敛

    # c的约束条件
    c_min = (1 - a - 3*b)/5 * 0.95  # 稳定性下限
    c_max = (1.3 - a - 3*b)/5       # 收敛性上限
    

⚖️ 最终系数解析

( a , b , c ) = ( 3.4445 , − 4.7750 , 2.0315 ) (a,b,c) = (3.4445, -4.7750, 2.0315) (a,b,c)=(3.4445,4.7750,2.0315)

  1. 收敛性验证

    x = [0.01, 0.3, 0.6, 0.9, 1.0]
    phi_5 = [0.92, 0.97, 1.05, 1.16, 1.23]  # 5步迭代结果
    

    所有值均在 [ 0.7 , 1.3 ] [0.7,1.3] [0.7,1.3]

  2. 效率优化证明
    τ 90 % = ln ⁡ ( 0.1 ) ln ⁡ ( 1 / a ) = 2.97 ≈ 3 步 \tau_{90\%} = \frac{\ln(0.1)}{\ln(1/a)} = 2.97 \approx 3 \text{ 步} τ90%=ln(1/a)ln(0.1)=2.973 
    (小奇异值只需3步达90%收敛)

  3. 工程稳定性
    max ⁡ ∣ ϕ ′ ( x ) ∣ = 0.983 < 1 (满足压缩映射条件) \max |\phi'(x)| = 0.983 < 1 \quad \text{(满足压缩映射条件)} maxϕ(x)=0.983<1(满足压缩映射条件)


🔬 与基准系数对比

特性基准(2,-1.5,0.5)优化后(3.4445,-4.775,2.0315)
小奇异值收敛步数83(加速2.6倍)
最大偏离±0.25+0.23/-0.08
训练迭代步数85(降低37.5%)

该系数使得5步NS迭代即达工业级精度要求(误差<0.3),而基准方案需8步以上。


💡 设计哲学:面向硬件的系数优化

Muon系数的核心创新在于将数值分析转化为可微优化问题

  1. 收敛域松弛:允许30%误差换取3倍加速
  2. GPU指令感知
    • b b b值减少乘加指令(MAC)依赖
    • a a a最大化利用Tensor Core的FP16峰值
  3. 内存访问优化
    c = 2.0315 ≈ 2 ⇒ 可用移位指令替代乘法 c=2.0315 \approx 2 \quad \Rightarrow \quad \text{可用移位指令替代乘法} c=2.03152可用移位指令替代乘法

最终实现:5步迭代即达正交化目标,FLOP开销<1%,成为首个实用化梯度正交优化器。

参考文献

  1. Bernstein & Newhouse. Old Optimizer, New Norm (2024)
  2. Gupta et al. Shampoo: Preconditioned Stochastic Tensor Optimization (2018)
  3. Dubey et al. The LLaMA 3 Herd of Models (2024)
  4. Muon原始论文
http://www.lryc.cn/news/571315.html

相关文章:

  • 从零到一:C语言基础入门学习路线与核心知识点全解析
  • 香橙派3B学习笔记12:C语言操作GPIO_<wiringPi.h>_点灯通用输入输出
  • FPGA 44 ,SDC 时序约束标准( 深度解析 SDC 标准 )
  • 期末作业swing水果店管理系统
  • 二分算法深度解析
  • 简说 python
  • C++ vector(2)
  • 【编译工具】CodeRider 2.0:驭码 CodeRider 2.0 全流程智能研发协作平台深度技术测评报告
  • Java在IDEA中终端窗口输出正常,但打包成JAR后中文乱码问题
  • 『大模型笔记』第3篇:多长的 Prompt 会阻塞其他请求?优化策略解析
  • Java线程池全面解析:原理、实现与最佳实践
  • Socket 编程 UDP
  • 【Linux】UDP与TCP协议
  • Kubernetes RDMA 概述与实战(大模型场景)
  • UE5 游戏模板 —— Puzzle 拼图游戏
  • 【配置教程】新版OpenCV+Android Studio环境配置(4.11测试通过)
  • 在线教学课程视频AI智能大纲代码与演示
  • 【Docker安装PostgreSQL】psql:致命错误: 用户 Password 认证失败
  • 在 MongoDB 中复制一个 collection(集合)
  • 以下是系统化的 Python基础学习框架,分为4个核心阶段,结合理论与实践,适合零基础快速入门并建立扎实的编程基础:
  • 【WPF】WPF ComboBox 数据驱动不刷新?SelectedItem 与 SelectedIndex 解析!
  • 什么是数据仓库的ETL
  • TortoiseSVN迁移到本地git
  • Tomcat 核心配置解析:4 大文件、乱码处理、端口与 Manager 配置
  • 企业ERP致胜秘籍:从流程革新到智能决策
  • 关系数据库-数据库事务处理与ACID原则
  • Android 开发问题:CardView 的阴影效果会受到父容器的裁切
  • STM32 实现解析自定义协议
  • HTTP 请求中的 `Content-Type` 类型详解及前后端示例(Vue + Spring Boot)
  • 为什么您应该停止使用 1080 玻璃