PyTorch LSTM Text Generation
1. Environment Setup and Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import Counter
import string
import re

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
2. Data Preparation (Using a Simplified Text Corpus)
# Use a Shakespeare excerpt as the example corpus (can be replaced with WikiText-2 or another dataset)
sample_text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die—to sleep,
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to: 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep, perchance to dream—ay, there's the rub:
For in that sleep of death what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause—there's the respect
That makes calamity of so long life.
"""class TextDataset(Dataset):def __init__(self, text, seq_length=40):"""文本数据集类Args:text: 输入文本seq_length: 序列长度"""self.seq_length = seq_lengthself.text = self.preprocess_text(text)# 创建字符到索引的映射self.chars = sorted(list(set(self.text)))self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}self.vocab_size = len(self.chars)print(f"文本长度: {len(self.text)}")print(f"词汇表大小: {self.vocab_size}")print(f"示例字符: {self.chars[:20]}")# 准备训练数据self.prepare_data()def preprocess_text(self, text):"""预处理文本"""# 转换为小写并保留基本标点text = text.lower().strip()# 移除多余空格text = re.sub(r'\s+', ' ', text)return textdef prepare_data(self):"""准备输入输出序列"""self.inputs = []self.targets = []for i in range(len(self.text) - self.seq_length):input_seq = self.text[i:i + self.seq_length]target_seq = self.text[i + 1:i + self.seq_length + 1]self.inputs.append([self.char_to_idx[ch] for ch in input_seq])self.targets.append([self.char_to_idx[ch] for ch in target_seq])def __len__(self):return len(self.inputs)def __getitem__(self, idx):return (torch.tensor(self.inputs[idx], dtype=torch.long),torch.tensor(self.targets[idx], dtype=torch.long))
3. LSTM Model Definition
class LSTMGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256,
                 num_layers=2, dropout=0.2):
        """LSTM text generation model.

        Args:
            vocab_size: vocabulary size
            embedding_dim: embedding dimension
            hidden_dim: hidden state dimension
            num_layers: number of LSTM layers
            dropout: dropout rate
        """
        super(LSTMGenerator, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM layers
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                            batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Output projection
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Forward pass.

        Args:
            x: input sequence [batch_size, seq_length]
            hidden: hidden state tuple (h, c)
        """
        batch_size = x.size(0)

        # Embedding
        embedded = self.embedding(x)  # [batch_size, seq_length, embedding_dim]
        embedded = self.dropout(embedded)

        # LSTM
        if hidden is None:
            hidden = self.init_hidden(batch_size, x.device)
        lstm_out, hidden = self.lstm(embedded, hidden)
        lstm_out = self.dropout(lstm_out)

        # Output layer
        output = self.fc(lstm_out)  # [batch_size, seq_length, vocab_size]

        return output, hidden

    def init_hidden(self, batch_size, device):
        """Initialize hidden and cell states with zeros."""
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return (h0, c0)
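A quick shape sanity check (not part of the original tutorial; the sizes are illustrative) confirms how the forward pass maps a batch of token indices to per-step logits and a hidden-state tuple:

# Illustrative shape check with a dummy batch
_vocab = 30                                                  # assumed vocabulary size
_model = LSTMGenerator(vocab_size=_vocab).to(device)
_dummy = torch.randint(0, _vocab, (8, 40), device=device)    # [batch=8, seq_len=40]
_logits, (_h, _c) = _model(_dummy)
print(_logits.shape)   # torch.Size([8, 40, 30])  -> [batch, seq_len, vocab_size]
print(_h.shape)        # torch.Size([2, 8, 256])  -> [num_layers, batch, hidden_dim]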
4. Training Function
def train_model(model, dataset, epochs=100, batch_size=64, lr=0.001):
    """Train the model."""
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=5)

    model.train()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        batch_count = 0

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            optimizer.zero_grad()
            output, _ = model(inputs)

            # Compute loss over all time steps
            loss = criterion(output.reshape(-1, model.vocab_size),
                             targets.reshape(-1))

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  # gradient clipping
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        avg_loss = epoch_loss / batch_count
        losses.append(avg_loss)
        scheduler.step(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')

    return losses
5. Text Generation Function
def generate_text(model, dataset, seed_text, length=200, temperature=1.0):
    """Generate text.

    Args:
        model: trained model
        dataset: dataset (provides the character mappings)
        seed_text: seed text
        length: number of characters to generate
        temperature: temperature parameter (controls randomness)
    """
    model.eval()

    # Preprocess the seed text
    seed_text = seed_text.lower()

    # Convert to indices
    input_seq = [dataset.char_to_idx.get(ch, 0) for ch in seed_text]
    generated_text = seed_text

    with torch.no_grad():
        for _ in range(length):
            # Prepare the input (keep only the last seq_length characters)
            if len(input_seq) > dataset.seq_length:
                input_seq = input_seq[-dataset.seq_length:]
            x = torch.tensor([input_seq], dtype=torch.long).to(device)

            # Predict
            output, _ = model(x)
            output = output[0, -1, :] / temperature

            # Sample the next character
            probabilities = F.softmax(output, dim=0)
            next_idx = torch.multinomial(probabilities, 1).item()

            # Append to the sequence
            input_seq.append(next_idx)
            generated_text += dataset.idx_to_char[next_idx]

    return generated_text
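To make the temperature parameter concrete, the short illustration below (not part of the original code) shows how dividing fixed logits by a temperature below 1 sharpens the softmax distribution, while a temperature above 1 flattens it toward uniform:

# Effect of temperature scaling on a fixed logit vector (illustrative)
logits = torch.tensor([2.0, 1.0, 0.1])
for t in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / t, dim=0)
    print(f"temperature={t}: {[round(p, 3) for p in probs.tolist()]}")
# Lower temperature concentrates probability on the largest logit;
# higher temperature spreads it out, producing more random samples.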
6. Main Training Pipeline
def main():
    # Create the dataset
    dataset = TextDataset(sample_text, seq_length=40)

    # Create the model
    model = LSTMGenerator(
        vocab_size=dataset.vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.2
    ).to(device)

    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train the model
    print("\nStarting training...")
    losses = train_model(model, dataset, epochs=100, batch_size=32, lr=0.001)

    # Text generation examples
    print("\nGenerated text samples:")
    print("-" * 50)

    # Generate with different temperature settings
    temperatures = [0.5, 0.8, 1.0, 1.2]
    seed_texts = ["to be", "the heart", "sleep"]

    for seed in seed_texts:
        print(f"\nSeed text: '{seed}'")
        for temp in temperatures:
            generated = generate_text(model, dataset, seed, length=100, temperature=temp)
            print(f"Temperature {temp}: {generated}")
        print()

    # Plot the loss curve
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()

    return model, dataset

if __name__ == "__main__":
    model, dataset = main()
7. Advanced Feature: Conditional Text Generation
class ConditionalLSTM(nn.Module):
    """Conditional LSTM generator (e.g., conditioned on sentiment or style)."""

    def __init__(self, vocab_size, num_conditions, embedding_dim=128,
                 hidden_dim=256, condition_dim=32):
        super(ConditionalLSTM, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.condition_embedding = nn.Embedding(num_conditions, condition_dim)

        # The LSTM input combines the text embedding and the condition embedding
        self.lstm = nn.LSTM(
            embedding_dim + condition_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.2
        )

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, condition):
        # Embeddings
        text_embedded = self.embedding(x)
        cond_embedded = self.condition_embedding(condition)

        # Expand the condition embedding to match the sequence length
        cond_embedded = cond_embedded.unsqueeze(1).expand(-1, text_embedded.size(1), -1)

        # Concatenate the embeddings
        combined = torch.cat([text_embedded, cond_embedded], dim=-1)

        # LSTM and output projection
        lstm_out, _ = self.lstm(combined)
        output = self.fc(lstm_out)

        return output
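A minimal usage sketch for ConditionalLSTM with dummy data (the sizes and the two-class sentiment condition are illustrative assumptions): each sequence in the batch is paired with one condition id.

# Minimal usage sketch with dummy data (illustrative sizes)
num_conditions = 2                                   # e.g. 0 = negative, 1 = positive sentiment
cond_model = ConditionalLSTM(vocab_size=30, num_conditions=num_conditions).to(device)

tokens = torch.randint(0, 30, (4, 40), device=device)               # [batch=4, seq_len=40]
conditions = torch.randint(0, num_conditions, (4,), device=device)  # one condition id per sequence

logits = cond_model(tokens, conditions)
print(logits.shape)  # torch.Size([4, 40, 30]) -> [batch, seq_len, vocab_size]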
8. Evaluation and Visualization
def evaluate_model(model, dataset, num_samples=5):
    """Evaluate generation quality."""
    model.eval()

    # Compute perplexity over the dataset
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    total_count = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            output, _ = model(inputs)
            loss = criterion(output.reshape(-1, model.vocab_size),
                             targets.reshape(-1))
            total_loss += loss.item() * inputs.size(0)
            total_count += inputs.size(0)

    perplexity = np.exp(total_loss / total_count)
    print(f"Perplexity: {perplexity:.2f}")

    # Diversity evaluation on generated samples
    generated_samples = []
    for _ in range(num_samples):
        seed = random.choice(["to ", "the ", "and "])
        text = generate_text(model, dataset, seed, length=100, temperature=0.8)
        generated_samples.append(text)

    # Count unique n-grams
    def get_ngrams(text, n):
        return set([text[i:i+n] for i in range(len(text)-n+1)])

    all_bigrams = set()
    all_trigrams = set()
    for text in generated_samples:
        all_bigrams.update(get_ngrams(text, 2))
        all_trigrams.update(get_ngrams(text, 3))

    print(f"Unique 2-grams: {len(all_bigrams)}")
    print(f"Unique 3-grams: {len(all_trigrams)}")

    return perplexity, generated_samples
Usage Notes
- Data preparation: the code uses a short Shakespeare excerpt, which can be replaced with:
  - WikiText-2 / WikiText-103
  - Penn Treebank
  - any plain-text file
- Model configuration:
  - Adjust embedding_dim and hidden_dim to control model capacity
  - Increase num_layers to increase model complexity
  - Adjust temperature to control generation randomness
- Training tips:
  - Use gradient clipping to prevent exploding gradients
  - Use a learning-rate scheduler to adapt the learning rate
  - Apply appropriate dropout to prevent overfitting
- Generation strategies:
  - Temperature sampling: controls how sharp the output distribution is
  - Top-k sampling: sample only from the k most probable characters (see the sketch after this list)
  - Beam search: keep several candidate sequences and select the best one
- Use pretrained models: e.g. GPT-2, BERT
- Add an attention mechanism: improves long-sequence modeling
- Implement a GAN architecture: an adversarial setup to improve generation quality
- Multi-task learning: train several text-generation tasks jointly
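Top-k sampling is mentioned above but not implemented in the tutorial code. Below is a minimal sketch of a top-k helper, written as a drop-in replacement for the softmax/multinomial step inside generate_text (the function name and the value of k are assumptions for illustration):

def sample_top_k(logits, k=5):
    """Sample one index from the k highest-scoring logits.
    Assumes the logits are already temperature-scaled, as in generate_text."""
    top_values, top_indices = torch.topk(logits, k)
    probs = F.softmax(top_values, dim=0)
    choice = torch.multinomial(probs, 1).item()
    return top_indices[choice].item()

# Inside the generation loop of generate_text, the softmax/multinomial pair
# could be replaced with:
#   next_idx = sample_top_k(output, k=5)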
Core Features
Character-level LSTM model: works with arbitrary text data
Temperature-controlled sampling: tunes the randomness of the generated text
Conditional generation: extensible to conditions such as sentiment or style
Complete training pipeline: includes optimizer and learning-rate scheduler
Technical Highlights
Gradient clipping: prevents exploding gradients
Dropout regularization: prevents overfitting
Perplexity evaluation: quantifies generation quality
Diversity analysis: unique n-gram statistics
Extensibility
Can swap in public datasets such as WikiText-2 or Penn Treebank
Can integrate attention mechanisms or Transformer architectures
Supports advanced generation strategies such as beam search and top-k sampling