GLU 变种:ReGLU 、 GEGLU 、 SwiGLU
文章目录
- GLU 变种:ReGLU 、 GEGLU 、 SwiGLU
- 1. ReGLU(ReLU-GLU)
- 函数表达式
- 代码
- 2. GEGLU(Gaussian Error GLU)
- 函数表达式
- 代码
- 3. SwiGLU(Swish-GLU)
- 函数表达式
- 代码
- 合并代码
GLU 变种:ReGLU 、 GEGLU 、 SwiGLU
-
在 GLU 的基础上,陆续提出了若干“激活 + GLU”的混合门控单元。它们共享同一套“双线形投影 + 逐元素门控”范式,差别仅在于把 GLU 中的 Sigmoid 门控替换为其他非线性函数,从而在参数量几乎不变的前提下带来不同的归纳偏差与性能收益。
-
参考论文:GLU Variants Improve Transformer
https://arxiv.org/pdf/2002.05202
1. ReGLU(ReLU-GLU)
- 核心思想:把 Sigmoid 换成 ReLU,让门控也具备稀疏性,计算更便宜,且保留 GLU 的残差特性。
函数表达式
ReGLU(x)=(xW+b)⊗ReLU(xV+c)\text{ReGLU}(x) = (xW+b)\,\otimes\,\text{ReLU}(xV+c) ReGLU(x)=(xW+b)⊗ReLU(xV+c)
代码
-
代码
import torch from torch import nnclass ReGLU(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.w_gate = nn.Linear(d_in, d_out, bias=False)self.w_up = nn.Linear(d_in, d_out, bias=False)self.w_down = nn.Linear(d_out, d_in, bias=False)def forward(self, x):gate = F.relu(self.w_gate(x))up = self.w_up(x)return self.w_down(gate * up)
2. GEGLU(Gaussian Error GLU)
- 核心思想:用 GELU 取代 Sigmoid,兼顾稀疏与平滑,兼顾 ReLU 的低计算与 Swish 的高表达。
函数表达式
GEGLU(x)=(xW+b)⊗GELU(xV+c)\text{GEGLU}(x) = (xW+b)\,\otimes\,\text{GELU}(xV+c) GEGLU(x)=(xW+b)⊗GELU(xV+c)
代码
-
代码
import torch from torch import nnclass GEGLU(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.w_gate = nn.Linear(d_in, d_out, bias=False)self.w_up = nn.Linear(d_in, d_out, bias=False)self.w_down = nn.Linear(d_out, d_in, bias=False)def forward(self, x):gate = F.gelu(self.w_gate(x))up = self.w_up(x)return self.w_down(gate * up)
3. SwiGLU(Swish-GLU)
- 核心思想:将 Swish 引入门控;Swish 本身具备 可学习/常数 β,在深层网络中表现优于 ReLU/GELU。
函数表达式
SwiGLU(x)=(xW+b)⊗Swishβ(xV+c)Swishβ(z)=z⋅σ(βz)\text{SwiGLU}(x) = (xW+b)\,\otimes\,\text{Swish}_\beta(xV+c) \\ \text{Swish}_\beta(z)=z\cdot\sigma(\beta z) SwiGLU(x)=(xW+b)⊗Swishβ(xV+c)Swishβ(z)=z⋅σ(βz)
代码
-
固定swish函数中的参数 β=1\beta = 1β=1 (SiLU)
import troch from torch import nnclass SwiGLU(nn.Module):def __init__(self, d_in, d_out, beta=1.0):super().__init__()self.beta = betaself.w_gate = nn.Linear(d_in, d_out, bias=False)self.w_up = nn.Linear(d_in, d_out, bias=False)self.w_down = nn.Linear(d_out, d_in, bias=False)def forward(self, x):gate = self.w_gate(x)gate = gate * torch.sigmoid(self.beta * gate) # Swishup = self.w_up(x)return self.w_down(gate * up)
合并代码
-
torch封装
import torch from torch import nnclass GLUVariants(nn.Module):def __init__(self, d_in, d_out, variant="geglu"):super().__init__()self.variant = variant.lower()self.w_gate = nn.Linear(d_in, d_out, bias=False)self.w_up = nn.Linear(d_in, d_out, bias=False)self.w_down = nn.Linear(d_out, d_in, bias=False)def forward(self, x):gate = self.w_gate(x)up = self.w_up(x)if self.variant == "reglu":gate = F.relu(gate)elif self.variant == "geglu":gate = F.gelu(gate)elif self.variant == "swiglu":gate = gate * torch.sigmoid(gate) # β=1else:gate = torch.sigmoid(gate) # fallback to GLUreturn self.w_down(gate * up)
输出
torch.Size([8, 64, 512])
-
对比
特性 GLU ReGLU GEGLU SwiGLU 门控激活 Sigmoid ReLU GELU Swish 稀疏门控 否 是 部分 平滑稀疏 计算量 中 低 中 中 梯度平滑性 中 差 好 最好 实际效果(大模型) 基线 接近 GLU 略优于 GLU 最佳 是否需额外参数 否 否 否 可选 β