当前位置: 首页 > news >正文

【第四章:大模型(LLM)】01.神经网络中的 NLP-(2)Seq2Seq 原理及代码解析

第四章:大模型(LLM)

第二部分:神经网络中的 NLP

第二节:Seq2Seq 原理及代码解析

1. Seq2Seq(Sequence-to-Sequence)模型原理

Seq2Seq 是一种处理序列到序列任务(如机器翻译、文本摘要、对话生成等)的深度学习架构,最早由 Google 在 2014 年提出。其核心思想是使用 编码器(Encoder) 将输入序列编码为上下文向量,再通过 解码器(Decoder) 逐步生成输出序列。

1.1 架构组成

  1. 编码器(Encoder)

    • 通常是 RNN、LSTM 或 GRU。

    • 输入:序列 x = (x_1, x_2, ..., x_T)

    • 输出:隐藏状态 h_T​,作为上下文向量。

  2. 解码器(Decoder)

    • 结构类似于编码器。

    • 输入:编码器输出的上下文向量 + 上一步预测的输出。

    • 输出:目标序列 y = (y_1, y_2, ..., y_T)

  3. 上下文向量(Context Vector)

    • 编码器最后一个隐藏状态 h_T​ 作为整个输入序列的信息摘要。


2. 数学公式

  • 编码器:

h_t = f(h_{t-1}, x_t)

  • 解码器:

s_t = f(s_{t-1}, y_{t-1}, c)
P(y_t|y_{<t}, x) = \text{softmax}(W s_t)

其中 c 是上下文向量。


3. 经典 Seq2Seq 训练流程

  1. 输入序列通过编码器,生成上下文向量。

  2. 解码器利用上下文向量和前一时刻的预测结果,逐步生成输出。

  3. 使用 教师强制(Teacher Forcing) 技术,训练时将真实标签输入解码器。


4. 改进:Attention 机制

Seq2Seq 传统模型存在 长序列信息丢失 问题。
Attention 通过在每一步解码时为输入序列不同部分分配权重,解决了这个问题。
公式:

c_t = \sum_{i=1}^{T_x} \alpha_{t,i} h_i

其中 \alpha_{t,i}​ 是注意力权重。


5. PyTorch 代码解析:Seq2Seq 示例

import torch
import torch.nn as nn
import torch.optim as optim# Encoder
class Encoder(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers=1):super(Encoder, self).__init__()self.rnn = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)def forward(self, x):outputs, hidden = self.rnn(x)return hidden# Decoder
class Decoder(nn.Module):def __init__(self, output_dim, hidden_dim, num_layers=1):super(Decoder, self).__init__()self.rnn = nn.GRU(output_dim, hidden_dim, num_layers, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x, hidden):output, hidden = self.rnn(x, hidden)pred = self.fc(output)return pred, hidden# Seq2Seq
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderdef forward(self, src, trg):hidden = self.encoder(src)outputs, _ = self.decoder(trg, hidden)return outputs# Example usage
input_dim, output_dim, hidden_dim = 10, 10, 32
encoder = Encoder(input_dim, hidden_dim)
decoder = Decoder(output_dim, hidden_dim)
model = Seq2Seq(encoder, decoder)src = torch.randn(16, 20, input_dim)  # batch=16, seq_len=20
trg = torch.randn(16, 20, output_dim)
output = model(src, trg)
print(output.shape)  # [16, 20, 10]


6. 应用场景

  • 机器翻译(Google Translate)

  • 文本摘要(新闻摘要生成)

  • 对话系统(聊天机器人)

  • 语音识别(语音到文本)

http://www.lryc.cn/news/602248.html

相关文章:

  • 从0到500账号管理:亚矩阵云手机多开组队与虚拟定位实战指南
  • 【归并排序】排序数组(medium)
  • Rust/Tauri 优秀开源项目推荐
  • C/C++ 调用lua脚本,lua脚本调用另一个lua脚本
  • 最新的前端技术和趋势(2025)
  • Maven中的bom和父依赖
  • Nginx HTTP 反向代理负载均衡实验
  • YOLO11 改进、魔改|低分辨率自注意力机制LRSA ,提取全局上下文建模与局部细节,提升小目标、密集小目标的检测能力
  • 免费 SSL 证书申请简明教程,让网站实现 HTTPS 访问
  • ADAS测试:如何用自动化手段提升VV效率
  • 【CDA干货】金融超市电商App经营数据分析案例
  • unbuntn 22.04 coreutils文件系统故障
  • GaussDB as的用法
  • 亚马逊广告关键词优化:如何精准定位目标客户
  • MyBatis中#{}与${}的实战避坑指南
  • 性能测试-技术指标的含义和计算
  • Leetcode_242.有效的字母异位词
  • Apache Commons VFS:Java内存虚拟文件系统,屏蔽不同IO细节
  • python入门篇12-虚拟环境conda的安装与使用
  • 深入Go并发编程:Channel、Goroutine与Select的协同艺术
  • 博士申请 | 荷兰阿姆斯特丹大学 招收计算机视觉(CV)方向 全奖博士生
  • 达梦有多少个模式
  • 亚马逊地址关联暴雷:新算法下的账号安全保卫战
  • 四、计算机组成原理——第6章:总线
  • 基于Hadoop3.3.4+Flink1.17.0+FlinkCDC3.0.0+Iceberg1.5.0整合,实现数仓实时同步mysql数据
  • [VLDB 2025]面向Flink集群巡检的交叉对比学习异常检测
  • SVN与GIT的区别,分别使用与哪些管理场景?
  • Go-Elasticsearch Typed Client查询请求的两种写法强类型 Request 与 Raw JSON
  • 正则表达式 速查速记
  • 10、Docker Compose 安装 MySQL