一文读懂循环神经网络—门控循环单元
目录
重置门(Reset Gate)
定义
作用
数学公式
更新门(Update Gate)
定义
作用
数学公式
隐藏状态(Hidden State)
定义
数学表示
核心作用
候选隐藏状态(Candidate Hidden State)
定义
数学表示(以 GRU 为例)
核心作用
隐藏状态 vs 候选隐藏状态
直观理解
详解
为什么需要候选状态?
可视化流程(GRU 为例)
完整代码
实验结果
重置门(Reset Gate)
定义
重置门决定了如何将新的输入信息与之前的隐藏状态相结合。它可以 "重置" 历史隐藏状态的部分信息,允许模型有选择地遗忘过去。
作用
- 捕获短期依赖:通过控制过去信息的保留程度,帮助模型关注最近的输入。
- 防止梯度消失:允许梯度在需要时更有效地流动。
数学公式
是重置门输出
是上一时刻隐藏状态
为当前时刻输入
为权重矩阵
表示 sigmoid 函数
表示拼接操作
为偏置项
- 功能:决定如何 "重置" 历史隐藏状态,控制上一时刻的隐藏状态
对当前候选状态的影响程度。
- 输出范围:
,其中 0 表示完全忽略历史,1 表示保留全部历史。
更新门(Update Gate)
定义
更新门决定了新的隐藏状态中有多少来自过去的隐藏状态,以及多少来自当前输入的新信息。它类似于 LSTM 中的遗忘门和输入门的组合。
作用
- 捕获长期依赖:通过控制信息的更新程度,允许模型保留长期信息。
- 减少冗余计算:只有当更新门指示需要更新时,模型才会处理新输入。
数学公式
是更新门的输出(范围在 0 到 1 之间)- 其他符号含义与重置门相同
- 功能:决定新的隐藏状态中,有多少来自候选状态
,多少来自历史状态
。
- 输出范围:
,其中 0 表示完全使用新信息,1 表示完全保留历史信息。
隐藏状态(Hidden State)
定义
隐藏状态 是 RNN 在时间步 t 的内部表示,它融合了 历史信息 和 当前输入,并作为后续时间步的上下文。
数学表示
在标准 RNN 中:
:当前输入
:上一时刻的隐藏状态
:权重矩阵
:激活函数(如 tanh 或 ReLU)
核心作用
- 记忆功能:通过
传递历史信息,使模型能够处理序列中的长期依赖。
- 上下文整合:将历史信息与当前输入结合,形成对序列的动态理解。
候选隐藏状态(Candidate Hidden State)
定义
候选隐藏状态 是 临时计算的中间状态,用于生成下一时刻的实际隐藏状态
。它在门控循环单元(如 LSTM、GRU)中尤为重要。
数学表示(以 GRU 为例)
:重置门输出
:元素级乘法
:激活函数,将输出约束在 \([-1, 1]\)
核心作用
- 信息筛选:通过重置门
选择性地保留历史信息,避免无关信息干扰。
- 生成新状态:
作为 "候选",需要经过更新门的调控才能成为最终的隐藏状态
。
隐藏状态 vs 候选隐藏状态
对比项 | 隐藏状态 ( | 候选隐藏状态 ( |
---|---|---|
角色 | 最终的上下文表示,传递到下一时刻 | 生成新隐藏状态的中间计算结果 |
是否门控 | 是(通过更新门 | 是(通过重置门 |
信息来源 | 整合了历史状态 | 基于当前输入 |
范围 | 由更新门 | 由 |
直观理解
详解
-
候选隐藏状态
: 可以看作是 "建议更新内容",它根据当前输入和部分历史信息提出一个 "候选",但需要经过更新门的批准才能生效。
-
隐藏状态
: 可以看作是 "历史记忆 + 新信息的融合",它通过更新门权衡历史与当前的重要性,决定最终保留哪些信息。
- 重置门:类似于 "遗忘开关",决定是否忽略历史隐藏状态。当
≈ 0
时,模型几乎完全忽略历史,专注于当前输入。 - 更新门:类似于 "记忆开关",决定是否保留历史隐藏状态。当
时,模型主要使用新信息;当≈ 0
时,主要保留历史信息。≈ 1
为什么需要候选状态?
门控机制(如 GRU 的重置门和更新门)的核心目的是 选择性地记忆和遗忘:
- 重置门 通过
控制历史信息的哪些部分参与生成 \(\tilde{h}_t\),帮助模型关注短期信息。
- 更新门 通过
控制对
的影响程度,帮助模型保留长期信息。
这种设计使 RNN 能够有效处理 梯度消失 和 长期依赖 问题。
可视化流程(GRU 为例)
plaintext
输入序列: x_1 → x_2 → x_3 → ... → x_t1. 计算重置门:r_t = σ(W_r·[h_{t-1}, x_t] + b_r)2. 计算候选隐藏状态:h̃_t = tanh(W·[r_t⊙h_{t-1}, x_t] + b) # 基于部分历史和当前输入3. 计算更新门:z_t = σ(W_z·[h_{t-1}, x_t] + b_z)4. 更新隐藏状态:h_t = (1-z_t)⊙h̃_t + z_t⊙h_{t-1} # 融合候选状态和历史状态
完整代码
"""
文件名: 9.1
作者: 墨尘
日期: 2025/7/15
项目名: dl_env
备注: 基于GRU(门控循环单元)的字符级文本生成模型,以《时间机器》文本为训练数据
"""
# 基础工具库
import collections # 用于统计词频
import random # 随机抽样
import re # 文本清洗(正则表达式)
import requests # 下载数据集
from pathlib import Path # 文件路径处理
from d2l import torch as d2l # 深度学习工具库
import math # 数学运算
import torch # PyTorch框架
from torch import nn # 神经网络模块
from torch.nn import functional as F # 函数式API# 图像显示相关库(解决中文和符号显示问题)
import matplotlib.pyplot as plt
import matplotlib.text as text# -------------------------- 核心解决方案:解决文本显示问题 --------------------------
def replace_minus(s):"""解决Matplotlib中Unicode减号(U+2212)显示为方块的问题原理:将特殊减号替换为普通ASCII减号('-')"""if isinstance(s, str): # 仅处理字符串return s.replace('\u2212', '-') # 替换特殊减号return s # 非字符串直接返回# 重写matplotlib的Text类的set_text方法,全局生效
original_set_text = text.Text.set_text # 保存原始方法
def new_set_text(self, s):s = replace_minus(s) # 先处理减号return original_set_text(self, s) # 调用原始方法设置文本
text.Text.set_text = new_set_text # 应用重写后的方法# -------------------------- 字体配置(确保中文和数学符号正常显示)--------------------------
plt.rcParams["font.family"] = ["SimHei"] # 设置中文字体(支持中文显示)
plt.rcParams["text.usetex"] = True # 使用LaTeX渲染文本(提升数学符号美观度)
plt.rcParams["axes.unicode_minus"] = True # 确保负号正确显示(避免方块)
plt.rcParams["mathtext.fontset"] = "cm" # 数学符号使用Computer Modern字体(LaTeX标准字体)
d2l.plt.rcParams.update(plt.rcParams) # 让d2l库的绘图工具继承上述配置# -------------------------- 1. 读取数据 --------------------------
def read_time_machine():"""下载并读取《时间机器》数据集,返回清洗后的文本行列表"""data_dir = Path('./data') # 数据存储目录data_dir.mkdir(exist_ok=True) # 目录不存在则创建file_path = data_dir / 'timemachine.txt' # 数据集文件路径# 检查文件是否存在,不存在则下载if not file_path.exists():print("开始下载时间机器数据集...")# 从d2l官方地址下载文本response = requests.get('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')# 写入文件(utf-8编码)with open(file_path, 'w', encoding='utf-8') as f:f.write(response.text)print(f"数据集下载完成,保存至: {file_path}")# 读取文件并清洗文本with open(file_path, 'r', encoding='utf-8') as f:lines = f.readlines() # 按行读取print(f"文件读取成功,总行数: {len(lines)}")if len(lines) > 0:print(f"第一行内容: {lines[0].strip()}") # 打印首行验证# 清洗规则:保留字母,其他字符替换为空格,转小写,去除首尾空格cleaned_lines = [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines if line.strip()]print(f"清洗后有效行数: {len(cleaned_lines)}") # 清洗后非空行数量return cleaned_lines# -------------------------- 2. 词元化与词表构建 --------------------------
def tokenize(lines, token='char'):"""将文本行转换为词元列表(词元是文本的最小处理单位)参数:lines: 清洗后的文本行列表(如["abc def", "ghi jkl"])token: 词元类型('char'字符级/'word'单词级)返回:词元列表(如字符级:[['a','b','c',' ','d','e','f'], ...])"""if token == 'char':# 字符级词元化:将每行拆分为单个字符列表return [list(line) for line in lines]elif token == 'word':# 单词级词元化:按空格拆分每行(需确保文本已用空格分隔单词)return [line.split() for line in lines]else:raise ValueError('未知词元类型:' + token)class Vocab:"""词表类:实现词元与索引的双向映射,用于将文本转换为模型可处理的数字序列"""def __init__(self, tokens, min_freq=0, reserved_tokens=None):"""构建词表参数:tokens: 词元列表(可嵌套,如[[token1, token2], [token3]])min_freq: 最低词频阈值(低于此值的词元不加入词表)reserved_tokens: 预留特殊词元(如分隔符、填充符等)"""if reserved_tokens is None:reserved_tokens = [] # 默认为空# 统计词频:展平嵌套列表,用Counter计数counter = collections.Counter([token for line in tokens for token in line])# 按词频降序排序(便于后续按频率筛选)self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化词表:<unk>(未知词元)固定在索引0, followed by预留词元self.idx_to_token = ['<unk>'] + reserved_tokens# 构建词元到索引的映射(字典)self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 按词频添加词元(过滤低频词)for token, freq in self.token_freqs:if freq < min_freq:break # 低频词不加入词表if token not in self.token_to_idx: # 避免重复添加预留词元self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1 # 索引为当前长度-1def __len__(self):"""返回词表大小(词元总数)"""return len(self.idx_to_token)def __getitem__(self, tokens):"""词元→索引(支持单个词元或词元列表)未知词元返回<unk>的索引(0)"""if not isinstance(tokens, (list, tuple)):# 单个词元:查字典,默认返回<unk>的索引return self.token_to_idx.get(tokens, self.unk)# 词元列表:递归转换每个词元return [self.__getitem__(token) for token in tokens]def to_tokens(self, indices):"""索引→词元(支持单个索引或索引列表)"""if not isinstance(indices, (list, tuple)):# 单个索引:直接查列表return self.idx_to_token[indices]# 索引列表:递归转换每个索引return [self.idx_to_token[index] for index in indices]@propertydef unk(self):"""返回<unk>的索引(固定为0)"""return 0# -------------------------- 3. 数据迭代器(随机抽样) --------------------------
def seq_data_iter_random(corpus, batch_size, num_steps):"""随机抽样生成批量子序列(生成器),用于模型训练的批量输入原理:从语料中随机截取多个长度为num_steps的子序列,组成批次参数:corpus: 词元索引序列(1D列表,如[1,3,5,2,...])batch_size: 批量大小(每个批次包含的子序列数)num_steps: 子序列长度(时间步,即模型一次处理的序列长度)返回:生成器,每次返回(X, Y):X: 输入序列(batch_size, num_steps)Y: 标签序列(batch_size, num_steps),是X右移一位的结果"""# 检查数据是否足够生成至少一个子序列(子序列长度+1,因Y是X右移1位)if len(corpus) < num_steps + 1:raise ValueError(f"语料库长度({len(corpus)})不足,需至少{num_steps+1}")# 随机偏移起始位置(0到num_steps-1),增加数据随机性corpus = corpus[random.randint(0, num_steps - 1):]# 计算可生成的子序列总数:(语料长度-1) // num_steps(-1是因Y需多1个元素)num_subseqs = (len(corpus) - 1) // num_stepsif num_subseqs < 1:raise ValueError(f"无法生成子序列(语料库长度不足)")# 生成所有子序列的起始索引(间隔为num_steps)initial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices) # 打乱起始索引,实现随机抽样# 计算可生成的批次数:子序列总数 // 批量大小num_batches = num_subseqs // batch_sizeif num_batches < 1:raise ValueError(f"子序列数量({num_subseqs})不足,需至少{batch_size}个")# 生成批量数据for i in range(0, batch_size * num_batches, batch_size):# 当前批次的起始索引(从打乱的索引中取batch_size个)indices = initial_indices[i: i + batch_size]# 输入序列X:每个子序列从indices[j]开始,取num_steps个元素X = [corpus[j: j + num_steps] for j in indices]# 标签序列Y:每个子序列从indices[j]+1开始,取num_steps个元素(X右移1位)Y = [corpus[j + 1: j + num_steps + 1] for j in indices]# 转换为张量返回(便于模型处理)yield torch.tensor(X), torch.tensor(Y)# -------------------------- 4. 数据加载函数(关键修复:返回可重置的迭代器) --------------------------
def load_data_time_machine(batch_size, num_steps):"""加载《时间机器》数据,返回数据迭代器生成函数和词表修复点:返回迭代器生成函数(而非一次性迭代器),确保训练时可重复生成数据参数:batch_size: 批量大小num_steps: 子序列长度(时间步)返回:data_iter: 迭代器生成函数(调用时返回新的迭代器)vocab: 词表对象"""lines = read_time_machine() # 读取清洗后的文本行tokens = tokenize(lines, token='char') # 字符级词元化(每个字符为词元)vocab = Vocab(tokens) # 构建词表# 将所有词元转换为索引(展平为1D序列)corpus = [vocab[token] for line in tokens for token in line]print(f"语料库长度: {len(corpus)}(词元索引总数)")# 定义迭代器生成函数:每次调用生成新的随机抽样迭代器def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)return data_iter, vocab # 返回生成函数和词表# -------------------------- 5. GRU模型核心实现 --------------------------
def get_params(vocab_size, num_hiddens, device):"""初始化GRU模型参数包含:更新门、重置门、候选隐状态、输出层的权重和偏置参数:vocab_size: 词表大小(输入/输出维度)num_hiddens: 隐藏层维度(隐状态维度)device: 计算设备(CPU/GPU)返回:参数列表:[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]"""num_inputs = num_outputs = vocab_size # 输入/输出维度=词表大小# 正态分布初始化参数(均值0,标准差0.01)def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 生成三组参数(权重1、权重2、偏置),用于门控机制def three():return (normal((num_inputs, num_hiddens)), # 输入→隐藏权重normal((num_hiddens, num_hiddens)), # 隐藏→隐藏权重torch.zeros(num_hiddens, device=device)) # 偏置(初始化为0)# 更新门(Update Gate)参数:W_xz(输入→更新门)、W_hz(隐藏→更新门)、b_z(偏置)W_xz, W_hz, b_z = three()# 重置门(Reset Gate)参数:W_xr(输入→重置门)、W_hr(隐藏→重置门)、b_r(偏置)W_xr, W_hr, b_r = three()# 候选隐状态(Candidate Hidden State)参数:W_xh(输入→候选隐状态)、W_hh(隐藏→候选隐状态)、b_h(偏置)W_xh, W_hh, b_h = three()# 输出层参数:W_hq(隐藏→输出)、b_q(偏置)W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 所有参数附加梯度(允许反向传播更新)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):"""初始化GRU的隐藏状态(全零向量)返回元组形式,便于扩展(如LSTM有两个状态)参数:batch_size: 批量大小num_hiddens: 隐藏层维度device: 计算设备返回:隐藏状态元组:(H,),其中H形状为(batch_size, num_hiddens)"""return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):"""GRU前向传播(逐时间步计算)参数:inputs: 输入序列(num_steps, batch_size, vocab_size),已转换为one-hot编码state: 初始隐藏状态(batch_size, num_hiddens)params: GRU参数列表(见get_params)返回:outputs: 所有时间步的输出(num_steps*batch_size, vocab_size)state: 最终隐藏状态(batch_size, num_hiddens)"""# 解析参数W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state # 初始隐藏状态(从元组中取出)outputs = [] # 存储每个时间步的输出# 逐时间步计算for X in inputs: # X形状:(batch_size, vocab_size)(当前时间步的输入)# 1. 计算更新门 Z_t = σ(X_t·W_xz + H_{t-1}·W_hz + b_z)Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 2. 计算重置门 R_t = σ(X_t·W_xr + H_{t-1}·W_hr + b_r)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)# 3. 计算候选隐状态 Ĥ_t = tanh(X_t·W_xh + (R_t ⊙ H_{t-1})·W_hh + b_h)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)# 4. 计算最终隐状态 H_t = Z_t ⊙ H_{t-1} + (1-Z_t) ⊙ Ĥ_tH = Z * H + (1 - Z) * H_tilda# 5. 计算输出 Y_t = H_t·W_hq + b_qY = H @ W_hq + b_qoutputs.append(Y) # 保存当前时间步的输出# 拼接所有时间步的输出(形状:(num_steps*batch_size, vocab_size)),返回输出和最终状态return torch.cat(outputs, dim=0), (H,)# -------------------------- 6. RNN模型包装类 --------------------------
class RNNModelScratch: #@save"""从零实现的RNN模型包装类,统一模型调用接口"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):"""参数:vocab_size: 词表大小(输入/输出维度)num_hiddens: 隐藏层维度device: 计算设备get_params: 参数初始化函数(如get_params)init_state: 状态初始化函数(如init_gru_state)forward_fn: 前向传播函数(如gru)"""self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device) # 模型参数self.init_state, self.forward_fn = init_state, forward_fn # 状态初始化和前向传播函数def __call__(self, X, state):"""模型调用接口(前向传播入口)参数:X: 输入序列(batch_size, num_steps),元素为词元索引state: 初始隐藏状态返回:y_hat: 输出(num_steps*batch_size, vocab_size)state: 最终隐藏状态"""# 处理输入:# 1. X.T:转置为(num_steps, batch_size)(便于逐时间步处理)# 2. F.one_hot:转换为one-hot编码(num_steps, batch_size, vocab_size)# 3. type(torch.float32):转换为浮点型(适配后续矩阵运算)X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 调用前向传播函数return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):"""获取初始隐藏状态(调用初始化函数)"""return self.init_state(batch_size, self.num_hiddens, device)# -------------------------- 7. 预测函数(文本生成) --------------------------
def predict_ch8(prefix, num_preds, net, vocab, device): #@save"""根据前缀生成后续字符(文本生成)参数:prefix: 前缀字符串(如"time traveller")num_preds: 要生成的字符数net: 训练好的GRU模型vocab: 词表device: 计算设备返回:生成的字符串(前缀+预测字符)"""# 初始化状态(批量大小为1,因仅生成一条序列)state = net.begin_state(batch_size=1, device=device)# 记录输出索引:初始为前缀首字符的索引outputs = [vocab[prefix[0]]]# 辅助函数:获取当前输入(最后一个输出的索引,形状(1,1))def get_input():return torch.tensor([outputs[-1]], device=device).reshape((1, 1))# 预热期:用前缀更新模型状态(不生成新字符,仅让模型"记住"前缀)for y in prefix[1:]:_, state = net(get_input(), state) # 前向传播,更新状态(忽略输出)outputs.append(vocab[y]) # 记录前缀字符的索引# 预测期:生成num_preds个字符for _ in range(num_preds):y, state = net(get_input(), state) # 前向传播,获取输出和新状态# 取概率最大的字符索引(贪婪采样)outputs.append(int(y.argmax(dim=1).reshape(1)))# 将索引转换为字符,拼接成字符串返回return ''.join([vocab.idx_to_token[i] for i in outputs])# -------------------------- 8. 梯度裁剪(防止梯度爆炸) --------------------------
def grad_clipping(net, theta): #@save"""裁剪梯度(将梯度L2范数限制在theta内),防止梯度爆炸参数:net: 模型(自定义模型或nn.Module)theta: 梯度阈值"""# 获取需要梯度更新的参数if isinstance(net, nn.Module):# 若为PyTorch官方Module,直接取parametersparams = [p for p in net.parameters() if p.requires_grad]else:# 若为自定义模型,取params属性params = net.params# 计算所有参数梯度的L2范数norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta: # 若范数超过阈值,按比例裁剪for param in params:param.grad[:] *= theta / norm# -------------------------- 9. 训练函数 --------------------------
def train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter):"""训练一个周期(单轮遍历数据集)参数:net: GRU模型train_iter_fn: 迭代器生成函数(调用后返回新迭代器)loss: 损失函数(如CrossEntropyLoss)updater: 优化器(如SGD)device: 计算设备use_random_iter: 是否使用随机抽样(影响状态处理)返回:ppl: 困惑度(perplexity,衡量语言模型性能,越低越好)speed: 训练速度(词元/秒)"""state, timer = None, d2l.Timer() # 初始化状态和计时器metric = d2l.Accumulator(2) # 累加器:(总损失, 总词元数)batches_processed = 0 # 记录处理的批次数量# 关键修复:每次训练都通过函数生成新的迭代器(避免迭代器被提前消费)train_iter = train_iter_fn()# 遍历批量数据for X, Y in train_iter:batches_processed += 1# 初始化状态:# - 首次迭代时需初始化# - 随机抽样时,每个批次的状态独立,需重新初始化if state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:# 非随机抽样时,分离状态(切断梯度回流到之前的批次,避免梯度计算依赖过长)if isinstance(net, nn.Module) and not isinstance(state, tuple):state.detach_() # 单个状态直接detachelse:for s in state: # 多个状态(如LSTM)逐个detachs.detach_()# 处理标签:# Y.T.reshape(-1):转置后展平为(num_steps*batch_size,)(与输出形状匹配)y = Y.T.reshape(-1)# 将输入和标签移到目标设备X, y = X.to(device), y.to(device)# 前向传播:获取输出和新状态y_hat, state = net(X, state)# 计算损失(mean()是因损失函数可能返回每个样本的损失)l = loss(y_hat, y.long()).mean()# 反向传播与参数更新:if isinstance(updater, torch.optim.Optimizer):# 若为PyTorch优化器(如SGD)updater.zero_grad() # 清零梯度l.backward() # 反向传播grad_clipping(net, 1) # 裁剪梯度(阈值1)updater.step() # 更新参数else:# 若为自定义优化器l.backward()grad_clipping(net, 1)updater(batch_size=1) # 假设批量大小为1的更新# 累加总损失和总词元数(用于计算平均损失)metric.add(l * y.numel(), y.numel())# 检查是否有批次被处理(避免空迭代)if batches_processed == 0:print("警告:没有处理任何训练批次!")return float('inf'), 0# 计算困惑度(perplexity = exp(平均损失))和训练速度(词元/秒)return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()def train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device, use_random_iter=False):"""训练模型(多周期)参数:net: GRU模型train_iter_fn: 迭代器生成函数vocab: 词表lr: 学习率num_epochs: 训练周期数device: 计算设备use_random_iter: 是否使用随机抽样(默认False)"""loss = nn.CrossEntropyLoss() # 交叉熵损失(适用于分类任务,此处为字符预测)# 动画器:可视化训练过程(困惑度随周期变化)animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化优化器:if isinstance(net, nn.Module):# 若为PyTorch Module,使用SGD优化器updater = torch.optim.SGD(net.parameters(), lr)else:# 若为自定义模型,使用d2l的sgd函数updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)# 定义预测函数:根据前缀"time traveller"生成50个字符predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 多周期训练for epoch in range(num_epochs):# 训练一个周期,返回困惑度和速度ppl, speed = train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter)# 每10个周期打印一次预测结果(观察生成文本质量变化)if (epoch + 1) % 10 == 0:print(f"epoch {epoch+1} 预测: {predict('time traveller')}")animator.add(epoch + 1, [ppl]) # 记录困惑度# 训练结束后输出最终结果print(f'最终困惑度 {ppl:.1f}, 速度 {speed:.1f} 词元/秒 {device}')print(f"time traveller 预测: {predict('time traveller')}")print(f"traveller 预测: {predict('traveller')}")# -------------------------- 主程序 --------------------------
if __name__ == '__main__':# 超参数设置batch_size, num_steps = 32, 35 # 批量大小=32,时间步=35# 加载数据:获取迭代器生成函数和词表train_iter, vocab = load_data_time_machine(batch_size, num_steps)# 模型参数vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu() # 词表大小、隐藏层维度、自动选择GPU/CPUnum_epochs, lr = 500, 0.12 # 训练周期=500,学习率=0.12# 初始化GRU模型model = RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)# 开始训练train_ch8(model, train_iter, vocab, lr, num_epochs, device)plt.show(block=True) # 显示训练过程的动画图(阻塞模式,确保图不闪退)
实验结果