当前位置: 首页 > news >正文

详解从零开始实现循环神经网络(RNN)

目录

循环神经网络(RNN)的工作原理

1. 单步迭代过程

2. 隐状态的作用

无隐状态的神经网络(Feedforward Neural Networks)

1. 结构与局限

2. 为什么不适合序列数据?

有隐状态的循环神经网络(Recurrent Neural Networks, RNN)

1. 核心思想:隐状态作为 “记忆”

2. 结构特点

有无隐状态的神经网络对比

基于 RNN 的字符级语言模型

1. 核心目标

2. 如何用 RNN 实现?

(1)数据准备:序列切割

(2)RNN 的预测过程

(3)训练目标

困惑度(Perplexity):语言模型的评估指标

1. 定义与公式

2. 直观意义

3. 示例

什么是独热编码(One-Hot Encoding)

1.核心特点

token 值的设定问题

1. 为什么 token 可以等于 'char' 或 'word'?

2. 为什么需要区分这两种方式?

3. 本质:参数值是 “功能开关”

为什么需要生成函数?

1. 关键区别

方式一:返回迭代器生成函数

方式二:直接返回迭代器

2. 实际应用对比

使用生成函数

直接使用迭代器

3. 总结

 完整代码

实验结果

当学习率为0.1时

当学习率为0.0005时


循环神经网络(RNN)的工作原理

RNN 的 “循环” 体现在对序列的逐时间步迭代中,以下以 “文本序列” 为例说明其工作流程:

1. 单步迭代过程

假设处理文本序列x_1, x_2, ..., x_T(如字符 “a, b, c, ...”),每个x_t是t时刻的输入(如一个字符的嵌入向量)。

RNN 的迭代步骤为:

  1. 初始化隐状态h_0(通常为全 0 向量,代表 “初始记忆”);
  2. 第 1 步(t=1)
    • 输入x_1,结合初始隐状态h_0,计算新隐状态:h_1 = f(x_1, h_0)
    • 输出y_1 = g(h_1)(如预测下一个字符的概率分布);
  3. 第 2 步(t=2)
    • 输入x_2,结合上一步的隐状态h_1,计算\h_2 = f(x_2, h_1)
    • 输出y_2 = g(h_2)
  4. 以此类推,直到序列结束(t=T)。
2. 隐状态的作用

隐状态h_t是 “压缩的历史信息”—— 它整合了从x_1x_t的所有输入信息。例如:

 
  • 处理文本 “我在吃____” 时,h_t会 “记住”“我在吃” 的语义,从而帮助预测下一个词(如 “饭”“苹果”)。

无隐状态的神经网络(Feedforward Neural Networks)

无隐状态的神经网络(如多层感知机 MLP)是不具备 “记忆” 能力的网络,其核心特点是:输入与输出之间是 “瞬时映射”即当前输出仅依赖于当前输入,与历史输入无关

1. 结构与局限
  • 结构:输入层→隐藏层(可选)→输出层,层与层之间仅存在前向连接,无循环或反馈连接。
  • 局限:无法处理序列数据(如文本、时间序列)。例如:
    • 预测句子 “我吃了____” 的下一个词(如 “饭”)时,需要依赖前文 “我吃了” 的语义,但无隐状态的网络无法 “记住” 前文信息,只能孤立处理每个输入。
    • 处理时间序列(如股票价格)时,无法利用过去的价格波动规律预测未来,因为每个时间步的输入被视为独立样本。
2. 为什么不适合序列数据?

序列数据的核心是时序依赖(当前数据与历史数据相关),而无隐状态的网络会 “遗忘” 所有历史输入,因此无法捕捉这种依赖关系。

有隐状态的循环神经网络(Recurrent Neural Networks, RNN)

为解决序列数据的时序依赖问题,循环神经网络(RNN)引入了隐状态(Hidden State),使其具备 “记忆” 能力 —— 当前输出不仅依赖于当前输入,还依赖于历史输入的 “记忆”(即隐状态)

1. 核心思想:隐状态作为 “记忆”

隐状态(通常用h_t表示)是 RNN 的 “记忆载体”,其更新规则为: h_t = f(x_t, h_{t-1})其中:

 
  • x_t是t时刻的输入;
  • h_{t-1}是t-1时刻的隐状态(历史 “记忆”);
  • f是激活函数(如 tanh、ReLU),用于融合当前输入与历史记忆。

输出y_t则依赖于当前隐状态: y_t = g(h_t)(g是输出层函数,如 softmax 用于分类)

2. 结构特点
  • 循环连接:隐藏层的输出会反馈到自身(即h_{t}依赖h_{t-1}),形成 “循环” 结构(见下图简化示意)。
    输入x₁ → 隐状态h₁ → 输出y₁  ↓  
    输入x₂ → 隐状态h₂ → 输出y₂  ↓  
    输入x₃ → 隐状态h₃ → 输出y₃  
    ...
    

  • 动态记忆:隐状态h_t随时间步更新,不断 “吸收” 新输入并 “保留” 关键历史信息(如文本中的上下文语义)。

有无隐状态的神经网络对比

维度无隐状态的神经网络有隐状态的 RNN
记忆能力无记忆(输入独立)有记忆(依赖历史隐状态)
序列数据处理能力无法处理时序依赖可捕捉时序依赖
核心变量仅依赖当前输入x_t依赖x_t和历史隐状态h_{t-1}

基于 RNN 的字符级语言模型

语言模型的核心任务是:给定一段序列(如前文),预测下一个元素(如词、字符)的概率。字符级语言模型以 “字符” 为基本单位(而非词),即通过前文的字符预测下一个字符。

1. 核心目标

例如,给定字符序列 “hell”,模型需输出下一个字符为 “o” 的概率(P(o|h,e,l,l)),且该概率应高于其他字符(如 “a”“b”)。

2. 如何用 RNN 实现?
(1)数据准备:序列切割

将原始文本转换为输入 - 标签对。例如,对文本 “hello”,若设定序列长度为 4(即通过前 4 个字符预测第 5 个):

  • 输入X:“h,e,l,l”
  • 标签Y:“e,l,l,o”

(实际中会用更长的序列长度,通过滑动窗口生成大量样本,如你提供的seq_data_iter_random函数就是用于生成此类样本)

(2)RNN 的预测过程
  • 输入X的每个字符x_t(如 “h”“e”“l”“l”)依次传入 RNN;
  • RNN 通过隐状态h_t“记住” 前文信息(如 “h→e→l→l” 的序列特征);
  • 最后一个隐状态h_4包含 “hell” 的完整信息,基于h_4输出下一个字符 “o” 的概率分布。
(3)训练目标

通过交叉熵损失函数优化 RNN 参数,使预测的字符概率分布尽可能接近真实标签(即让P(o|h,e,l,l)尽可能大)。

困惑度(Perplexity):语言模型的评估指标

困惑度(Perplexity,简称 PPL)是衡量语言模型预测能力的核心指标,它本质上是 “模型对序列的平均惊讶度”—— 值越小,模型预测越准确。

1. 定义与公式

对于一个长度为T的序列x_1, x_2, ..., x_T,模型对该序列的困惑度定义为:\text{Perplexity} = \exp\left(-\frac{1}{T} \sum_{t=1}^T \log P(x_t | x_1, ..., x_{t-1})\right)

其中:

  • P(x_t | x_1, ..., x_{t-1})是模型预测第t个字符的概率(基于前文);
  • 对数通常取自然对数(\exp为指数函数);
  • 公式可简化为:困惑度 = 交叉熵损失的指数(因为交叉熵H = -\frac{1}{T}\sum \log P(x_t|...),故\text{PPL} = \exp(H)
2. 直观意义
  • 完美模型:若模型能 100% 准确预测每个字符(P(x_t|...) = 1),则\log P = 0,困惑度= \exp(0) = 1(最优)。
  • 随机猜测:若词汇表大小为V(如 26 个英文字母 + 标点),随机猜测时每个字符的概率为1/V,则困惑度= V(最差)。
  • 实际模型:困惑度介于 1 和V之间,值越小说明模型对序列的 “预测信心” 越高。
3. 示例
  • 模型 A 对 “hello” 的困惑度为 2.1,模型 B 为 5.3 → 模型 A 更好(更 “不困惑”)。
  • 训练过程中,若困惑度持续下降,说明模型在学习序列规律;若停滞,则可能过拟合或未充分训练。

什么是独热编码(One-Hot Encoding)

独热编码是一种将离散型特征(如类别、标签、词语等)转换为二进制向量的编码方式,其核心思想是用一个只有一个元素为 1、其余元素为 0 的向量来表示一个离散值。

1.核心特点
  • 向量长度等于离散特征的类别总数(或词汇表大小)。
  • 每个离散值对应向量中的一个唯一位置,该位置值为 1,其他位置值为 0。
  • 不同离散值的编码向量之间相互正交(内积为 0),避免了人为赋予的数值大小带来的偏见。

示例

假设我们有一个离散特征 “颜色”,包含 3 个类别:红色、蓝色、绿色。

  • 红色 → [1, 0, 0]
  • 蓝色 → [0, 1, 0]
  • 绿色 → [0, 0, 1]

再比如自然语言处理中,若词汇表为 ["我", "爱", "机器学习"],则:

  • “我” → [1, 0, 0]
  • “爱” → [0, 1, 0]
  • “机器学习” → [0, 0, 1]

token 值的设定问题

1. 为什么 token 可以等于 'char' 或 'word'
  • 人为约定的参数值
    这两个字符串是开发者为了区分功能而定义的 “标识符”。就像函数参数 mode='train' 表示训练模式、mode='test' 表示测试模式一样,token='char' 和 token='word' 只是告诉函数 “按字符拆分” 或 “按单词拆分”。
    你也可以换成其他字符串(如 'character' 或 'term'),只要函数内部逻辑对应即可,但 'char' 和 'word' 是行业通用的简洁命名。

  • 两种最基础的词元划分方式
    在自然语言处理中,词元(token)是文本的最小处理单位,但 “最小” 的定义取决于任务:

    • 字符级('char':将每个字符作为词元(如中文的每个汉字、英文的每个字母)。
      例:"hello" → ['h', 'e', 'l', 'l', 'o']
    • 单词级('word':将空格分隔的单词作为词元(适用于英文等有天然空格分隔的语言)。
      例:"hello world" → ['hello', 'world']
2. 为什么需要区分这两种方式?

不同任务对词元的粒度要求不同:

  • 字符级('char')适用场景

    • 处理拼写纠错(需关注单个字母的错误)。
    • 生成类任务(如文本生成、OCR 识别),需要精确到字符的预测。
    • 语言结构简单或无空格分隔的语言(如中文、日文)。
      优点:词表规模小(英文仅 26 个字母 + 符号),不存在未登录词(OOV)问题;
      缺点:忽略单词语义,需要更长的序列建模上下文。
  • 单词级('word')适用场景

    • 文本分类、情感分析(需基于单词语义)。
    • 机器翻译(以单词为基本单位映射)。
      优点:直接对应语义单元,建模效率高;
      缺点:词表规模大(可能包含数万甚至数十万单词),存在大量未登录词(如专业术语、新词)。
3. 本质:参数值是 “功能开关”

token 参数的取值本身没有特殊含义,它的作用是告诉函数应该执行哪种逻辑。例如:

  • 当 token='char' 时,函数执行 list(line)(将字符串拆分为字符列表);
  • 当 token='word' 时,函数执行 line.split()(按空格拆分单词)。

这种设计让一个函数能兼容两种常用的词元化逻辑,避免重复编写相似代码。如果需要,你还可以扩展参数值(如 'subword' 表示子词级拆分),只需在函数中增加对应的处理逻辑即可。

为什么需要生成函数?

# 定义迭代器生成函数(每次调用生成新的迭代器,避免被提前消费)def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)# 生成随机抽样的迭代器train_iter = seq_data_iter_random(corpus, batch_size, num_steps)   区别在哪

1. 关键区别

方式一:返回迭代器生成函数
def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)
 
  • 特性data_iter 是一个函数,每次调用它都会生成一个全新的迭代器
  • 优势:可重复使用。训练时每次需要数据(如每个 epoch),都调用 data_iter() 生成新迭代器,避免数据被提前消费。
  • 应用场景:适用于需要多次遍历数据的场景(如训练多轮)。
方式二:直接返回迭代器
train_iter = seq_data_iter_random(corpus, batch_size, num_steps)
  • 问题:若在训练前用 next(iter(train_iter)) 验证数据,会消耗迭代器中的第一个批次,导致训练时少一个批次;若多次验证或多轮训练,可能导致数据耗尽,出现 警告:没有处理任何训练批次!
  • 应用场景:仅适用于一次性遍历数据的场景(如只需读取一次数据)。

2. 实际应用对比

使用生成函数
# 定义生成函数
def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)# 训练前验证数据(使用临时迭代器)
temp_iter = data_iter()
X, Y = next(iter(temp_iter))  # 验证不会影响后续训练# 训练时每个 epoch 都调用生成函数获取新迭代器
for epoch in range(num_epochs):train_iter = data_iter()  # 每次都生成全新的迭代器for X, Y in train_iter:# 训练逻辑pass
直接使用迭代器
# 直接生成迭代器
train_iter = seq_data_iter_random(corpus, batch_size, num_steps)# 训练前验证数据(消耗了第一个批次)
X, Y = next(iter(train_iter))  # 这一步后,train_iter 已少了一个批次# 训练时(第一个批次已被验证消耗)
for X, Y in train_iter:# 实际训练会从第二个批次开始pass

3. 总结

方式是否可重复使用适合场景问题
返回迭代器生成函数✅ 是(每次生成新迭代器)多轮训练、多次验证
直接返回迭代器❌ 否(只能用一次)一次性数据读取数据提前消费,训练时批次丢失

 完整代码

"""
文件名: 8.5
作者: 墨尘
日期: 2025/7/14
项目名: dl_env
备注: 修复迭代器被提前消费的问题,确保训练时能获取批次数据
"""import random
import collections
import re
import requests
from pathlib import Path
from d2l import torch as d2l
import math
import torch
from torch import nn
from torch.nn import functional as F
# 手动显示图像相关库
import matplotlib.pyplot as plt  # 绘图库
import matplotlib.text as text  # 用于修改文本绘制(解决符号显示问题)# -------------------------- 核心解决方案:解决文本显示问题 --------------------------
def replace_minus(s):"""解决Matplotlib中Unicode减号(U+2212)显示异常的问题参数:s: 待处理的字符串或其他对象返回:处理后的字符串(替换减号)或原始对象"""if isinstance(s, str):  # 仅处理字符串类型return s.replace('\u2212', '-')  # 将特殊减号替换为普通减号return s  # 非字符串直接返回# 重写matplotlib的Text类的set_text方法,全局解决减号显示问题
original_set_text = text.Text.set_text  # 保存原始方法
def new_set_text(self, s):s = replace_minus(s)  # 先处理减号return original_set_text(self, s)  # 调用原始方法设置文本
text.Text.set_text = new_set_text  # 应用重写后的方法# -------------------------- 字体配置(确保中文和数学符号正常显示)--------------------------
plt.rcParams["font.family"] = ["SimHei"]  # 设置中文字体(支持中文显示)
plt.rcParams["text.usetex"] = True  # 使用LaTeX渲染文本(提升数学符号美观度)
plt.rcParams["axes.unicode_minus"] = True  # 确保负号正确显示(避免方块)
plt.rcParams["mathtext.fontset"] = "cm"  # 数学符号使用Computer Modern字体
d2l.plt.rcParams.update(plt.rcParams)  # 让d2l库的绘图工具继承配置# -------------------------- 1. 读取数据 --------------------------
def read_time_machine():"""下载并读取《时间机器》数据集"""data_dir = Path('./data')data_dir.mkdir(exist_ok=True)file_path = data_dir / 'timemachine.txt'# 检查文件是否存在if not file_path.exists():print("开始下载时间机器数据集...")response = requests.get('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')with open(file_path, 'w', encoding='utf-8') as f:f.write(response.text)print(f"数据集下载完成,保存至: {file_path}")# 验证文件内容with open(file_path, 'r', encoding='utf-8') as f:lines = f.readlines()print(f"文件读取成功,总行数: {len(lines)}")if len(lines) > 0:print(f"第一行内容: {lines[0].strip()}")# 清洗文本:保留字母、转小写、去空行cleaned_lines = [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines if line.strip()]print(f"清洗后有效行数: {len(cleaned_lines)}")return cleaned_lines# -------------------------- 2. 词元化与词表构建 --------------------------
def tokenize(lines, token='char'):"""将文本行转换为词元列表(词元是文本的最小处理单位)参数:lines: 清洗后的文本行列表token: 词元类型('char'字符级/'word'单词级)"""if token == 'char':return [list(line) for line in lines]  # 字符级词元化(每个字符为词元)elif token == 'word':return [line.split() for line in lines]  # 单词级词元化(按空格拆分)else:raise ValueError('未知词元类型:' + token)class Vocab:"""词表类:映射词元与索引"""def __init__(self, tokens, min_freq=0, reserved_tokens=None):"""构建词表:过滤低频词,保留特殊词元,建立双向映射参数:tokens: 词元列表(可嵌套)min_freq: 最低词频阈值reserved_tokens: 预留特殊词元(如<unk>)"""if reserved_tokens is None:reserved_tokens = []# 统计词频(展平列表)counter = collections.Counter([token for line in tokens for token in line])self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)  # 按频率排序# 初始化词表:<unk>在索引0, followed by预留词元self.idx_to_token = ['<unk>'] + reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 按词频添加词元(过滤低频词)for token, freq in self.token_freqs:if freq < min_freq:break  # 低频词不加入词表if token not in self.token_to_idx:  # 避免重复添加预留词元self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1def __len__(self):return len(self.idx_to_token)  # 词表大小def __getitem__(self, tokens):"""词元→索引(支持单个/列表)"""if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)  # 未知词元返回<unk>的索引return [self.__getitem__(token) for token in tokens]def to_tokens(self, indices):"""索引→词元(支持单个/列表)"""if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0  # <unk>的固定索引# -------------------------- 3. 数据迭代器(随机抽样) --------------------------
def seq_data_iter_random(corpus, batch_size, num_steps):"""随机抽样生成批量子序列(生成器)注意:生成器是一次性的,遍历后即耗尽参数:corpus: 词元索引序列(1D列表)batch_size: 批量大小num_steps: 子序列长度(时间步)"""# 检查数据是否足够if len(corpus) < num_steps + 1:raise ValueError(f"语料库长度({len(corpus)})不足,需至少{num_steps+1}")# 随机偏移起始位置(增加随机性)corpus = corpus[random.randint(0, num_steps - 1):]# 计算可生成的子序列数量(-1是因为标签需右移1位)num_subseqs = (len(corpus) - 1) // num_stepsif num_subseqs < 1:raise ValueError(f"无法生成子序列(语料库长度不足)")# 生成子序列起始索引(间隔num_steps)initial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)  # 打乱索引,随机抽样# 计算批次数num_batches = num_subseqs // batch_sizeif num_batches < 1:raise ValueError(f"子序列数量({num_subseqs})不足,需至少{batch_size}个")# print(f"数据迭代器: 语料库长度={len(corpus)}, 子序列数={num_subseqs}, 批次数={num_batches}")# 生成批量数据for i in range(0, batch_size * num_batches, batch_size):indices = initial_indices[i: i + batch_size]  # 当前批次的起始索引X = [corpus[j: j + num_steps] for j in indices]  # 输入序列Y = [corpus[j + 1: j + num_steps + 1] for j in indices]  # 标签(右移1位)yield torch.tensor(X), torch.tensor(Y)  # 返回张量# -------------------------- 4. 数据加载函数(关键修复:返回可重置的迭代器) --------------------------
def load_data_time_machine(batch_size, num_steps):"""加载数据并返回迭代器和词表修复点:返回迭代器生成函数,而非一次性迭代器,确保训练时可重新生成数据"""lines = read_time_machine()  # 读取清洗后的文本tokens = tokenize(lines, token='char')  # 字符级词元化vocab = Vocab(tokens)  # 构建词表# 转换为词元索引序列(展平)corpus = [vocab[token] for line in tokens for token in line]print(f"语料库长度: {len(corpus)}(词元索引总数)")# 定义迭代器生成函数(每次调用生成新的迭代器,避免被提前消费)def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)# 生成随机抽样的迭代器#train_iter = seq_data_iter_random(corpus, batch_size, num_steps)return data_iter, vocab  # 返回生成函数而非迭代器# -------------------------- 5. RNN模型参数与前向传播 --------------------------
def get_params(vocab_size, num_hiddens, device):"""初始化RNN参数(输入→隐藏/隐藏→隐藏/隐藏→输出)"""num_inputs = num_outputs = vocab_size  # 输入/输出维度=词表大小def normal(shape):"""正态分布初始化参数(均值0,标准差0.01)"""return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))  # 输入→隐藏权重 (V, H)W_hh = normal((num_hiddens, num_hiddens))  # 隐藏→隐藏权重 (H, H)b_h = torch.zeros(num_hiddens, device=device)  # 隐藏层偏置 (H,)# 输出层参数W_hq = normal((num_hiddens, num_outputs))  # 隐藏→输出权重 (H, V)b_q = torch.zeros(num_outputs, device=device)  # 输出层偏置 (V,)# 启用梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_rnn_state(batch_size, num_hiddens, device):"""初始化隐藏状态(全零向量,返回元组便于扩展)"""return (torch.zeros((batch_size, num_hiddens), device=device), )def rnn(inputs, state, params):"""RNN前向传播(逐时间步计算)参数:inputs: 输入序列 (num_steps, batch_size, vocab_size)state: 初始隐藏状态 (batch_size, num_hiddens)params: 模型参数返回:outputs: 所有时间步的输出 (num_steps*batch_size, vocab_size)state: 最终隐藏状态 (batch_size, num_hiddens)"""W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state  # 初始隐藏状态 (B, H)outputs = []  # 存储每个时间步的输出for X in inputs:  # X形状 (B, V)# 计算新隐藏状态:H_t = tanh(X_t·W_xh + H_{t-1}·W_hh + b_h)H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)# 计算输出:Y_t = H_t·W_hq + b_qY = torch.mm(H, W_hq) + b_qoutputs.append(Y)  # 保存当前输出# 拼接所有时间步的输出,返回输出和最终状态return torch.cat(outputs, dim=0), (H,)# -------------------------- 7. RNN模型包装类 --------------------------
class RNNModelScratch:  #@save"""从零实现的RNN模型包装类"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):"""参数:vocab_size: 词表大小num_hiddens: 隐藏层大小device: 计算设备get_params: 参数初始化函数init_state: 状态初始化函数forward_fn: 前向传播函数"""self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)  # 模型参数self.init_state, self.forward_fn = init_state, forward_fn  # 状态和前向传播函数def __call__(self, X, state):"""模型调用(前向传播入口)参数:X: 输入序列 (batch_size, num_steps)state: 初始隐藏状态返回:y_hat: 输出 (num_steps*batch_size, vocab_size)state: 最终隐藏状态"""# 处理输入:转置→one-hot编码→转float32X = F.one_hot(X.T, self.vocab_size).type(torch.float32)  # (num_steps, batch_size, vocab_size)return self.forward_fn(X, state, self.params)  # 调用前向传播def begin_state(self, batch_size, device):"""获取初始隐藏状态"""return self.init_state(batch_size, self.num_hiddens, device)# -------------------------- 8. 预测函数 --------------------------
def predict_ch8(prefix, num_preds, net, vocab, device):  #@save"""根据前缀生成后续字符(文本生成)参数:prefix: 前缀字符串num_preds: 生成的字符数net: RNN模型vocab: 词表device: 设备返回:生成的字符串(前缀+预测字符)"""state = net.begin_state(batch_size=1, device=device)  # 初始化状态(批量大小1)outputs = [vocab[prefix[0]]]  # 记录输出索引(初始为前缀首字符)def get_input():"""获取当前输入(最后一个输出的索引,形状(1,1))"""return torch.tensor([outputs[-1]], device=device).reshape((1, 1))# 预热期:用前缀更新状态(不生成新字符)for y in prefix[1:]:_, state = net(get_input(), state)  # 更新状态outputs.append(vocab[y])  # 记录前缀字符# 预测期:生成num_preds个字符for _ in range(num_preds):y, state = net(get_input(), state)  # 前向传播outputs.append(int(y.argmax(dim=1).reshape(1)))  # 取概率最大的字符return ''.join([vocab.idx_to_token[i] for i in outputs])  # 索引→字符# -------------------------- 9. 梯度裁剪(防止梯度爆炸) --------------------------
def grad_clipping(net, theta):  #@save"""裁剪梯度(将梯度 norm 限制在theta内)参数:net: 模型theta: 梯度阈值"""# 获取需要梯度的参数if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.params  # 自定义模型# 计算梯度L2 normnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:  # 若超过阈值,按比例裁剪for param in params:param.grad[:] *= theta / norm# -------------------------- 10. 训练函数(核心修复:每次迭代重新生成迭代器) --------------------------
def train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter):"""训练一个周期(修复:接收迭代器生成函数,每次重新生成迭代器)参数:train_iter_fn: 迭代器生成函数(调用后返回新迭代器)其他参数同上"""state, timer = None, d2l.Timer()metric = d2l.Accumulator(2)  # (总损失, 总词元数)batches_processed = 0  # 记录处理的批次# 关键:每次训练都通过函数生成新的迭代器(避免迭代器被提前消费)train_iter = train_iter_fn()for X, Y in train_iter:batches_processed += 1# 初始化状态(首次迭代或随机抽样时)if state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:# 分离状态,避免梯度回流到之前的批次if isinstance(net, nn.Module) and not isinstance(state, tuple):state.detach_()else:for s in state:s.detach_()# 处理标签:转置后展平(与输出形状匹配)y = Y.T.reshape(-1)X, y = X.to(device), y.to(device)# 前向传播y_hat, state = net(X, state)l = loss(y_hat, y.long()).mean()  # 计算损失# 反向传播与更新if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)  # 裁剪梯度updater.step()else:l.backward()grad_clipping(net, 1)updater(batch_size=1)# 累加损失和词元数metric.add(l * y.numel(), y.numel())# 检查是否有批次被处理if batches_processed == 0:print("警告:没有处理任何训练批次!")return float('inf'), 0# 返回困惑度和速度return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()def train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device, use_random_iter=False):"""训练模型(多个周期)参数:train_iter_fn: 迭代器生成函数(每次调用生成新迭代器)"""loss = nn.CrossEntropyLoss()  # 交叉熵损失(分类任务)animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化优化器if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)# 预测函数(生成50个字符)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 训练多个周期for epoch in range(num_epochs):ppl, speed = train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter)# 每10个周期打印预测结果if (epoch + 1) % 10 == 0:print(f"epoch {epoch+1} 预测: {predict('time traveller')}")animator.add(epoch + 1, [ppl])# 训练结束后输出结果print(f'最终困惑度 {ppl:.1f}, 速度 {speed:.1f} 词元/秒 {device}')print(f"time traveller 预测: {predict('time traveller')}")print(f"traveller 预测: {predict('traveller')}")# -------------------------- 11. 测试代码(核心修复:使用迭代器生成函数) --------------------------
if __name__ == '__main__':# 超参数batch_size, num_steps = 32, 35print(f"配置: batch_size={batch_size}, num_steps={num_steps}")# 加载数据(返回迭代器生成函数,而非直接返回迭代器)train_iter_fn, vocab = load_data_time_machine(batch_size, num_steps)# 验证数据(使用新生成的迭代器,避免消费训练数据)try:temp_iter = train_iter_fn()  # 生成临时迭代器用于验证X, Y = next(iter(temp_iter))print(f"数据验证: X形状={X.shape}, Y形状={Y.shape}")print(f"第一个样本X(前10词元): {X[0][:10]}")print(f"对应的Y: {Y[0][:10]}")except StopIteration:print("错误: 迭代器为空!")exit(1)# 初始化模型num_hiddens = 512device = d2l.try_gpu()net = RNNModelScratch(len(vocab),num_hiddens,device,get_params,init_rnn_state,rnn)# 测试模型输出形状state = net.begin_state(X.shape[0], device)Y_hat, new_state = net(X.to(device), state)print(f"输出Y形状: {Y_hat.shape} (batch_size*num_steps, vocab_size)")print(f"新状态形状: {new_state[0].shape} (batch_size, num_hiddens)")# 训练前预测(结果可能无意义)print("\n训练前预测:", predict_ch8('time traveller ', 10, net, vocab, device))# 训练模型(使用迭代器生成函数,每次 epoch 生成新迭代器)num_epochs, lr = 500, 0.1  #5e-4 print(f"\n开始训练: 迭代{num_epochs}次,学习率{lr}")train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device)plt.show(block=True)# 使用随机迭代再次训练(对比效果)print("\n使用随机迭代训练:")net = RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_rnn_state, rnn)train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device, use_random_iter=True)plt.show(block=True)

实验结果

当学习率为0.1时

训练模型

随机迭代训练 

当学习率为0.0005时

训练模型

随机迭代训练 

 

 

http://www.lryc.cn/news/587613.html

相关文章:

  • 使用 keytool 在服务器上导入证书操作指南(SSL 证书验证错误处理)
  • kafka的部署
  • Android系统的问题分析笔记 - Android上的调试方式 bugreport
  • 论文阅读:WildGS-SLAM:Monocular Gaussian Splatting SLAM in Dynamic Environments
  • 深入浅出Kafka Consumer源码解析:设计哲学与实现艺术
  • Angular 框架下 AI 驱动的企业级大前端应用开
  • Kafka 时间轮深度解析:如何O(1)处理定时任务
  • 【Python】-实用技巧5- 如何使用Python处理文件和目录
  • 计算机网络通信的相关知识总结
  • 基于GA遗传优化的多边形拟合算法matlab仿真
  • vscode/cursor怎么自定义文字、行高、颜色
  • PHP password_hash() 函数
  • 仓储智能穿梭车:提升仓库效率50%的自动化核心设备
  • Ubuntu系统下Conda的详细安装教程与Python多版本管理指南
  • 【软件架构】软件体系结构风格实现
  • I2C设备寄存器读取调试方法
  • 卷绕/叠片工艺
  • React源码3:update、fiber.updateQueue对象数据结构和updateContainer()中enqueueUpdate()阶段
  • 新手向:Python自动化办公批量重命名与整理文件系统
  • 理解:进程、线程、协程
  • LLM表征工程还有哪些值得做的地方
  • python的小学课外综合管理系统
  • 我对muduo的梳理以及AI的更改
  • MFC UI表格制作从专家到入门
  • LeetCode经典题解:206、两数之和(Two Sum)
  • 018 进程控制 —— 进程等待
  • 算法训练营day18 530.二叉搜索树的最小绝对差、501.二叉搜索树中的众数、236. 二叉树的最近公共祖先
  • B站自动回复工具(破解)
  • 项目一第一天
  • 苍穹外卖学习指南(java的一个项目)(老师能运行,但你不行,看这里!!)