
PyTorch LSTM Text Generation

1. Environment Setup and Imports

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import Counter
import string
import re

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

2. Data Preparation (Using Simplified Text Data)

# Use Shakespeare text as the example (it can be replaced with WikiText-2 or another dataset)
sample_text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die—to sleep,
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to: 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep, perchance to dream—ay, there's the rub:
For in that sleep of death what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause—there's the respect
That makes calamity of so long life.
"""class TextDataset(Dataset):def __init__(self, text, seq_length=40):"""文本数据集类Args:text: 输入文本seq_length: 序列长度"""self.seq_length = seq_lengthself.text = self.preprocess_text(text)# 创建字符到索引的映射self.chars = sorted(list(set(self.text)))self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}self.vocab_size = len(self.chars)print(f"文本长度: {len(self.text)}")print(f"词汇表大小: {self.vocab_size}")print(f"示例字符: {self.chars[:20]}")# 准备训练数据self.prepare_data()def preprocess_text(self, text):"""预处理文本"""# 转换为小写并保留基本标点text = text.lower().strip()# 移除多余空格text = re.sub(r'\s+', ' ', text)return textdef prepare_data(self):"""准备输入输出序列"""self.inputs = []self.targets = []for i in range(len(self.text) - self.seq_length):input_seq = self.text[i:i + self.seq_length]target_seq = self.text[i + 1:i + self.seq_length + 1]self.inputs.append([self.char_to_idx[ch] for ch in input_seq])self.targets.append([self.char_to_idx[ch] for ch in target_seq])def __len__(self):return len(self.inputs)def __getitem__(self, idx):return (torch.tensor(self.inputs[idx], dtype=torch.long),torch.tensor(self.targets[idx], dtype=torch.long))

3. LSTM Model Definition

class LSTMGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, dropout=0.2):
        """LSTM text generation model.

        Args:
            vocab_size: vocabulary size
            embedding_dim: embedding dimension
            hidden_dim: hidden state dimension
            num_layers: number of LSTM layers
            dropout: dropout rate
        """
        super(LSTMGenerator, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM layers
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Output layer
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Forward pass.

        Args:
            x: input sequence [batch_size, seq_length]
            hidden: hidden state
        """
        batch_size = x.size(0)

        # Embedding
        embedded = self.embedding(x)  # [batch_size, seq_length, embedding_dim]
        embedded = self.dropout(embedded)

        # LSTM
        if hidden is None:
            hidden = self.init_hidden(batch_size, x.device)
        lstm_out, hidden = self.lstm(embedded, hidden)
        lstm_out = self.dropout(lstm_out)

        # Output layer
        output = self.fc(lstm_out)  # [batch_size, seq_length, vocab_size]

        return output, hidden

    def init_hidden(self, batch_size, device):
        """Initialize the hidden state."""
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return (h0, c0)
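
A small, hypothetical shape check (not in the original article) can confirm the tensor dimensions the forward pass produces; the dummy vocabulary size and batch shape below are illustrative assumptions only:

# Hypothetical shape check with dummy data (illustrative only).
dummy_model = LSTMGenerator(vocab_size=30).to(device)
dummy_input = torch.randint(0, 30, (8, 40)).to(device)  # [batch_size=8, seq_length=40]
logits, (h, c) = dummy_model(dummy_input)
print(logits.shape)  # torch.Size([8, 40, 30]): one score per vocabulary entry at each position
print(h.shape)       # torch.Size([2, 8, 256]): [num_layers, batch_size, hidden_dim]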

4. Training Function

def train_model(model, dataset, epochs=100, batch_size=64, lr=0.001):
    """Train the model."""
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    model.train()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        batch_count = 0

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            optimizer.zero_grad()
            output, _ = model(inputs)

            # Compute the loss
            loss = criterion(
                output.reshape(-1, model.vocab_size),
                targets.reshape(-1)
            )

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  # gradient clipping
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        avg_loss = epoch_loss / batch_count
        losses.append(avg_loss)
        scheduler.step(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')

    return losses
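
The training loop above does not persist the model. If checkpointing is wanted, a minimal sketch using the standard torch.save / load_state_dict pattern (an addition, not part of the original code) could look like this:

# Minimal checkpointing sketch (assumed addition; the original training loop saves nothing).
def save_checkpoint(model, optimizer, epoch, path='lstm_generator.pt'):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path='lstm_generator.pt'):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']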

5. Text Generation Function

def generate_text(model, dataset, seed_text, length=200, temperature=1.0):
    """Generate text.

    Args:
        model: trained model
        dataset: dataset (used for the character mappings)
        seed_text: seed text
        length: number of characters to generate
        temperature: temperature parameter (controls randomness)
    """
    model.eval()

    # Preprocess the seed text
    seed_text = seed_text.lower()

    # Convert to indices
    input_seq = [dataset.char_to_idx.get(ch, 0) for ch in seed_text]
    generated_text = seed_text

    with torch.no_grad():
        for _ in range(length):
            # Prepare the input
            if len(input_seq) > dataset.seq_length:
                input_seq = input_seq[-dataset.seq_length:]
            x = torch.tensor([input_seq], dtype=torch.long).to(device)

            # Predict
            output, _ = model(x)
            output = output[0, -1, :] / temperature

            # Sample
            probabilities = F.softmax(output, dim=0)
            next_idx = torch.multinomial(probabilities, 1).item()

            # Append to the sequence
            input_seq.append(next_idx)
            generated_text += dataset.idx_to_char[next_idx]

    return generated_text
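
The Usage Notes below mention top-k sampling as an alternative decoding strategy. A minimal sketch of how the sampling step inside generate_text could be restricted to the k most probable characters (an assumed variant, not part of the original function) is:

# Top-k sampling sketch (assumed variant of the sampling step in generate_text).
def sample_top_k(logits, k=5, temperature=1.0):
    """Sample the next character index from the k most probable entries of `logits`."""
    logits = logits / temperature
    top_values, top_indices = torch.topk(logits, k)
    probabilities = F.softmax(top_values, dim=0)
    choice = torch.multinomial(probabilities, 1).item()
    return top_indices[choice].item()

# Inside the generation loop, the temperature division and multinomial sampling
# could be replaced with:
#     next_idx = sample_top_k(output[0, -1, :], k=5, temperature=temperature)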

6. Main Training Pipeline

def main():
    # Build the dataset
    dataset = TextDataset(sample_text, seq_length=40)

    # Build the model
    model = LSTMGenerator(
        vocab_size=dataset.vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.2
    ).to(device)

    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train the model
    print("\nStarting training...")
    losses = train_model(model, dataset, epochs=100, batch_size=32, lr=0.001)

    # Text generation examples
    print("\nGenerated text samples:")
    print("-" * 50)

    # Generate with different temperature settings
    temperatures = [0.5, 0.8, 1.0, 1.2]
    seed_texts = ["to be", "the heart", "sleep"]

    for seed in seed_texts:
        print(f"\nSeed text: '{seed}'")
        for temp in temperatures:
            generated = generate_text(model, dataset, seed, length=100, temperature=temp)
            print(f"Temperature {temp}: {generated}")
            print()

    # Plot the loss curve
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()

    return model, dataset

if __name__ == "__main__":
    model, dataset = main()

7. Advanced Feature: Conditional Text Generation

class ConditionalLSTM(nn.Module):
    """Conditional LSTM generator (conditioned on attributes such as sentiment or style)."""

    def __init__(self, vocab_size, num_conditions, embedding_dim=128, hidden_dim=256, condition_dim=32):
        super(ConditionalLSTM, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.condition_embedding = nn.Embedding(num_conditions, condition_dim)

        # The LSTM input concatenates the text embedding and the condition embedding
        self.lstm = nn.LSTM(
            embedding_dim + condition_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.2
        )

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, condition):
        # Look up embeddings
        text_embedded = self.embedding(x)
        cond_embedded = self.condition_embedding(condition)

        # Expand the condition embedding to match the sequence length
        cond_embedded = cond_embedded.unsqueeze(1).expand(-1, text_embedded.size(1), -1)

        # Concatenate the embeddings
        combined = torch.cat([text_embedded, cond_embedded], dim=-1)

        # LSTM and output layer
        lstm_out, _ = self.lstm(combined)
        output = self.fc(lstm_out)

        return output
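
The class above is defined but never exercised in the article. A hypothetical forward pass with dummy tensors (the vocabulary size, number of conditions and shapes are assumptions chosen purely for illustration) would look like this:

# Hypothetical usage sketch for ConditionalLSTM (dummy sizes, not from the original article).
cond_model = ConditionalLSTM(vocab_size=30, num_conditions=3).to(device)
tokens = torch.randint(0, 30, (4, 40)).to(device)   # [batch_size=4, seq_length=40]
condition = torch.tensor([0, 1, 2, 1]).to(device)    # one condition id (e.g. a style label) per sequence
logits = cond_model(tokens, condition)
print(logits.shape)  # torch.Size([4, 40, 30])
# Training would mirror train_model, with the condition ids passed alongside each batch.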

8. Evaluation and Visualization

def evaluate_model(model, dataset, num_samples=5):
    """Evaluate generation quality."""
    model.eval()

    # Compute perplexity
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
    criterion = nn.CrossEntropyLoss()

    total_loss = 0
    total_count = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            output, _ = model(inputs)
            loss = criterion(
                output.reshape(-1, model.vocab_size),
                targets.reshape(-1)
            )
            total_loss += loss.item() * inputs.size(0)
            total_count += inputs.size(0)

    perplexity = np.exp(total_loss / total_count)
    print(f"Perplexity: {perplexity:.2f}")

    # Diversity of generated samples
    generated_samples = []
    for _ in range(num_samples):
        seed = random.choice(["to ", "the ", "and "])
        text = generate_text(model, dataset, seed, length=100, temperature=0.8)
        generated_samples.append(text)

    # Count unique n-grams
    def get_ngrams(text, n):
        return set([text[i:i+n] for i in range(len(text)-n+1)])

    all_bigrams = set()
    all_trigrams = set()
    for text in generated_samples:
        all_bigrams.update(get_ngrams(text, 2))
        all_trigrams.update(get_ngrams(text, 3))

    print(f"Unique 2-grams: {len(all_bigrams)}")
    print(f"Unique 3-grams: {len(all_trigrams)}")

    return perplexity, generated_samples

Usage Notes

  1. Data preparation: the code uses a simplified Shakespeare text, which can be replaced with:

    • WikiText-2 / WikiText-103
    • Penn Treebank
    • Any text file
  2. Model configuration

    • Adjust embedding_dim and hidden_dim to control model capacity
    • Increase num_layers for a more complex model
    • Adjust temperature to control generation randomness
  3. Training tips

    • Use gradient clipping to prevent exploding gradients
    • Use a learning-rate scheduler to adapt the learning rate automatically
    • Apply appropriate dropout to prevent overfitting
  4. Generation strategies (a small numeric sketch of temperature scaling follows this list)

    • Temperature sampling: controls how sharp the output distribution is
    • Top-k sampling: sample only from the k most probable characters
    • Beam search: generate multiple candidate sequences and keep the best one
  5. Use pretrained models: e.g. GPT-2, BERT

  6. Add an attention mechanism: improves modeling of long sequences

  7. Implement a GAN architecture: use a generative adversarial network to improve generation quality

  8. Multi-task learning: train several text-generation tasks jointly
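
To make the temperature effect described in item 4 concrete, here is a small numeric sketch (toy logit values, not taken from the model) showing how dividing the logits by the temperature sharpens or flattens the softmax distribution that the sampler draws from:

# Illustration of temperature scaling on toy logits (illustrative values only).
logits = torch.tensor([2.0, 1.0, 0.1])
for t in [0.5, 1.0, 1.5]:
    probs = F.softmax(logits / t, dim=0)
    print(t, probs.tolist())
# Lower temperature (t < 1) concentrates probability on the most likely character (more conservative text);
# higher temperature (t > 1) flattens the distribution (more diverse, but noisier, text).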

Core Features

Character-level LSTM model: works with arbitrary text data
Temperature-controlled sampling: tunes the randomness of the generated text
Conditional generation: extensible to generation conditioned on attributes such as sentiment or style
Complete training pipeline: includes an optimizer and a learning-rate scheduler

Technical Highlights

Gradient clipping: prevents exploding gradients
Dropout regularization: prevents overfitting
Perplexity evaluation: quantifies generation quality
Diversity analysis: n-gram statistics

Extensibility

Can swap in public datasets such as WikiText-2 or Penn Treebank
Can integrate attention mechanisms or a Transformer architecture
Supports advanced generation strategies such as beam search and top-k sampling
