当前位置: 首页 > news >正文

门控线性单元GLU (Gated Linear Unit)

文章目录

    • 门控线性单元GLU (Gated Linear Unit)
      • 函数表达式
      • 与 Swish 的对比
      • PyTorch 中的 GLU 实现
      • TensorFlow 中的 GLU 实现

门控线性单元GLU (Gated Linear Unit)

  • 论文

    https://arxiv.org/abs/1612.08083

  • 门控线性单元(GLU)最初在《Language Modeling with Gated Convolutional Networks》提出,设计灵感来自门控控制,通过引入门控操作来控制信息的流动,它巧妙地将线性变换门控机制结合起来,通过可学习的门控信号来控制信息流,可以看做是引入了一种动态的选择极值,以在模型中选择性地传递信息

  • GLU 的有效性来源于其直观的工作流程:

    1. 双线性变换: 首先,输入 x 会并行地进行两次独立的线性变换。一次生成“候选”输出 (xW+b),这是潜在需要传递的信息。
    2. 门控过滤: 同时,另一个线性变换 (xV+c) 的结果会通过 Sigmoid 函数生成一个介于 0 和 1 之间的门控信号。这个门像一个智能开关,决定了“候选”输出中每个维度的信息应该有多少被保留,有多少应该被抑制。
    3. 残差友好: 由于门控输出的平均值大约为 0.5,这使得 GLU 具有一种天然的“残差”特性,有助于缓解深度网络训练中的梯度消失问题

函数表达式

  • GLU函数
    GLU(x)=(xW+b)⊗σ(xV+c)\begin{aligned} \mathrm{GLU(x)}=(xW+b)\otimesσ(xV+c) \end{aligned} GLU(x)=(xW+b)σ(xV+c)

    其中

    • x∈Rdx \in \mathbb{R}^dxRd 为输入向量
    • W、V∈Rd×dW、V \in \mathbb{R}^{d \times d}WVRd×db、c∈Rdb、c \in \mathbb{R}^dbcRd 为可学习的权重矩阵与偏置向量
    • ⊗\otimes 表示逐元素乘积(哈达玛乘积)
    • σ(⋅)\sigma(\cdot)σ() 为 sigmoid 门控,将后面 xV+cxV+cxV+c 值压缩到 (0, 1) 区间,作为门控信号,决定信息通过比例

与 Swish 的对比

  • 与swish对比

    特性SwishGLU
    参数标量 β\betaβ(固定或可学习)全连接权重 WWW、偏置 bbb(可学习)
    门控方式输入自身经过 sigmoid 缩放输入经线性变换后再经 sigmoid 门控
    参数量每通道 0/1 个标量每通道 d+1d+1d+1 个参数
    计算复杂度低(一次 sigmoid)高(一次矩阵乘 + sigmoid)
    表达能力中等

PyTorch 中的 GLU 实现

  • 代码(以 nn.GLU 为例,针对通道维度切分)

    注意:使用官方的GLU函数,输出维度是减半的

    import torch
    import torch.nn as nntorch.manual_seed(1024)batch_size = 8
    seq_len = 64
    d_model = 512x = torch.randn(batch_size, seq_len, d_model)# 官方 GLU 沿指定维度将输入一分为二
    glu = nn.GLU(dim=-1)          # dim 指定切分维度
    out = glu(x)                  # 输出 [batch_size, seq_len//2, d_model]print("Input shape :", x.shape)
    print("Output shape:", out.shape)"""输出"""
    Input shape : torch.Size([8, 64, 512])
    Output shape: torch.Size([8, 64, 256])
    

    若输入通道维度(seq_len)为偶数,可直接使用 nn.GLU(dim=channel_dim),此时将输入均分两份:前一半做值、后一半做门控。

  • 自定义 GLU(任意线性映射 + 门控)

    注意:输出维度可以

    import torch
    import torch.nn as nn
    torch.manual_seed(1024)class GLU(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.w1 = nn.Linear(d_in, d_out, bias=False)self.w2 = nn.Linear(d_in, d_out, bias=False)self.w3 = nn.Linear(d_out, d_in, bias=False)  # 可选:再投影回 d_indef forward(self, x):# x: [batch_size, seq_len, d_model]gate = torch.sigmoid(self.w2(x))   # [batch_size, seq_len, d_in]out  = self.w1(x) * gate           # [batch_size, seq_len, d_out]return self.w3(out)                # [batch_size, seq_len, d_in]# 使用示例
    batch_size = 8
    seq_len = 64
    d_model = 512
    d_ff = 4 * d_modelx = torch.randn(batch_size, seq_len, d_model)
    layer = GLU(d_in=d_model, d_out=d_ff)
    print(layer(x).shape)   # 维度不变"""输出"""
    torch.Size([8, 64, 512])
    

TensorFlow 中的 GLU 实现

  • 代码(tf.keras 自定义层)

    import tensorflow as tfclass GLU(tf.keras.layers.Layer):"""典型 Transformer-FFN 中的 GLU 层:GLU(x) = (x W_gate) ⊙ σ(x W_up)  再投影回 d_in,维度不变"""def __init__(self, d_in, d_out, **kwargs):super().__init__(**kwargs)self.d_in = d_inself.d_out = d_out# 两路线性映射self.w_gate = tf.keras.layers.Dense(d_out, use_bias=False)self.w_up   = tf.keras.layers.Dense(d_out, use_bias=False)self.w_down = tf.keras.layers.Dense(d_in, use_bias=False)def call(self, x):gate = tf.nn.sigmoid(self.w_gate(x))   # [batch_size, seq_len, d_out]up   = self.w_up(x)                    # [batch_size, seq_len, d_out]return self.w_down(gate * up)          # [batch_size, seq_len, d_in]# 使用示例
    batch_size = 8
    seq_len = 64
    d_model = 512
    d_ff = 4 * d_modelx = tf.random.normal([batch_size, seq_len, d_model])
    glu = GLU(d_in=d_model, d_out=d_ff)
    print(glu(x).shape)   # (4, 64, 512)  维度不变"""输出"""
    (8, 64, 512)
    

http://www.lryc.cn/news/592929.html

相关文章:

  • Zabbix安装-Server
  • 暑期自学嵌入式——Day05补充(C语言阶段)
  • 百炼MCP与IoT实战(三):手搓自定义MCP Server与阿里云FC配置
  • 「Java案例」判断是否是闰年的方法
  • 【JS笔记】Java Script学习笔记
  • stm32f4 dma的一些问题
  • 20250718-4-Kubernetes 应用程序生命周期管理-Pod对象:实现机制_笔记
  • CAD 约束求解:核心技术原理、流程及主流框架快速解析
  • Python 使用期物处理并发(使用concurrent.futures模块下载)
  • TF-IDF(Term Frequency - Inverse Document Frequency)
  • 7.19 pq | 并查集模板❗|栈循环|数组转搜索树
  • SpringBoot项目创建,三层架构,分成结构,IOC,DI相关,@Resource与@Autowired的区别
  • 如何下载并安装AIGCPanel
  • Maven私服仓库,发布jar到私服仓库,依赖的版本号如何设置,规范是什么
  • 四、CV_GoogLeNet
  • LT8644EX-矩阵芯片-富利威
  • 麒麟操作系统unity适配
  • 【科研绘图系列】R语言绘制分组箱线图
  • 闭包的定义和应用场景
  • Nestjs框架: 基于TypeORM的多租户功能集成和优化
  • RPG59.玩家拾取物品三:可拾取物品的提示UI
  • 如何写python requests?
  • [特殊字符] Spring Boot 常用注解全解析:20 个高频注解 + 使用场景实例
  • Linux基础IO通关秘籍:从文件描述符到重定向
  • 龙虎榜——20250718
  • Redis高频面试题:利用I/O多路复用实现高并发
  • 服务端高并发方案设计
  • Linux操作系统之线程:分页式存储管理
  • ARINC818航空总线机载视频处理系统设计
  • stm32驱动双步进电机