当前位置: 首页 > news >正文

智能聊天机器人:使用PyTorch构建多轮对话系统

使用PyTorch构建多轮对话系统的示例代码。这个示例项目包括一个简单的Seq2Seq模型用于对话生成,并使用GRU作为RNN的变体。以下是代码的主要部分,包括数据预处理、模型定义和训练循环。

数据预处理

首先,准备数据并进行预处理。这部分代码假定你有一个对话数据集,格式为成对的问答句子。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random

# Toy corpus: each entry is a [question, answer] pair.
pairs = [
    ["Hi, how are you?", "I'm good, thank you! How about you?"],
    ["What is your name?", "My name is Chatbot."],
    # Add more conversation pairs here.
]

# Vocabulary maps, with the four special tokens pre-registered.
word2index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
index2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
vocab_size = len(word2index)


def tokenize(sentence):
    """Lower-case *sentence* and split it on whitespace."""
    return sentence.lower().split()


def build_vocab(pairs):
    """Register every token of the corpus in the global vocabulary maps.

    New words get consecutive indices starting after the special tokens;
    ``word2index``/``index2word``/``vocab_size`` are updated in place.
    """
    global word2index, index2word, vocab_size
    for qa in pairs:
        for text in qa:
            for tok in tokenize(text):
                if tok not in word2index:
                    word2index[tok] = vocab_size
                    index2word[vocab_size] = tok
                    vocab_size += 1


def sentence_to_tensor(sentence):
    """Encode a sentence as a 1-D LongTensor of indices, terminated by <EOS>.

    Unknown words map to the <UNK> index.
    """
    unk = word2index["<UNK>"]
    indices = [word2index.get(tok, unk) for tok in tokenize(sentence)]
    indices.append(word2index["<EOS>"])
    return torch.tensor(indices, dtype=torch.long)


build_vocab(pairs)

数据集和数据加载

定义一个Dataset类和DataLoader来加载数据。

class ChatDataset(Dataset):
    """Exposes the (question, answer) pairs as (input, target) tensor pairs."""

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        question, answer = self.pairs[idx]
        return sentence_to_tensor(question), sentence_to_tensor(answer)


def collate_fn(batch):
    """Pad a batch of variable-length (input, target) tensors.

    Returns ``(inputs, targets, input_lengths, target_lengths)`` where the
    padded tensors are time-major, i.e. shaped (max_len, batch).
    """
    inputs, targets = zip(*batch)
    input_lengths = [seq.size(0) for seq in inputs]
    target_lengths = [seq.size(0) for seq in targets]
    pad = word2index["<PAD>"]
    padded_inputs = nn.utils.rnn.pad_sequence(inputs, padding_value=pad)
    padded_targets = nn.utils.rnn.pad_sequence(targets, padding_value=pad)
    return padded_inputs, padded_targets, input_lengths, target_lengths


dataset = ChatDataset(pairs)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn, shuffle=True)

模型定义

定义一个简单的Seq2Seq模型,包括编码器和解码器。

class Encoder(nn.Module):
    """GRU encoder over an embedded, packed input batch (time-major)."""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        # Pack so the GRU skips <PAD> positions; enforce_sorted=False lets the
        # batch arrive in any length order.
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, enforce_sorted=False)
        packed_out, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_out)
        return outputs, hidden


class Decoder(nn.Module):
    """Single-step GRU decoder producing log-probabilities over the vocabulary.

    ``encoder_outputs`` is accepted for interface compatibility (e.g. a future
    attention mechanism) but is not used by this simple decoder.
    """

    def __init__(self, output_size, hidden_size, num_layers=1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_step, hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        gru_output, hidden = self.gru(embedded, hidden)
        # Squeeze the time dimension (single step) before projecting.
        output = self.softmax(self.out(gru_output.squeeze(0)))
        return output, hidden


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder with teacher forcing."""

    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, input_tensor, target_tensor, input_lengths, target_lengths, teacher_forcing_ratio=0.5):
        batch_size = input_tensor.size(1)
        max_target_len = max(target_lengths)
        vocab_size = self.decoder.out.out_features
        # Collects one log-probability distribution per decoding step.
        outputs = torch.zeros(max_target_len, batch_size, vocab_size).to(self.device)
        encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)
        # Every sequence in the batch starts from the <SOS> token.
        decoder_input = torch.tensor([[word2index["<SOS>"]] * batch_size]).to(self.device)
        decoder_hidden = encoder_hidden
        for t in range(max_target_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            outputs[t] = decoder_output
            top1 = decoder_output.argmax(1)
            # Teacher forcing: feed the ground-truth token instead of the model's own prediction.
            use_teacher = random.random() < teacher_forcing_ratio
            decoder_input = target_tensor[t].unsqueeze(0) if use_teacher else top1.unsqueeze(0)
        return outputs


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(vocab_size, hidden_size=256).to(device)
decoder = Decoder(vocab_size, hidden_size=256).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)

训练循环

定义训练循环并进行模型训练。

def train(model, dataloader, num_epochs, learning_rate=0.001):
    """Train the seq2seq model and print the mean loss per epoch.

    Args:
        model: a Seq2Seq whose outputs are log-probabilities
               (time, batch, vocab).
        dataloader: yields (inputs, targets, input_lengths, target_lengths).
        num_epochs: number of passes over the data.
        learning_rate: Adam learning rate.

    BUGFIX: the decoder already applies nn.LogSoftmax, so the criterion must
    be NLLLoss — CrossEntropyLoss would apply log-softmax a second time and
    distort the loss/gradients. <PAD> positions are excluded from the loss.
    """
    criterion = nn.NLLLoss(ignore_index=word2index["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for inputs, targets, input_lengths, target_lengths in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs, targets, input_lengths, target_lengths)
            # Flatten (time, batch, vocab) -> (time*batch, vocab) for the loss;
            # deriving vocab from the output avoids relying on the global.
            loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader)}")


train(model, dataloader, num_epochs=10)

测试与推理

定义一个简单的推理函数来进行对话生成。

def evaluate(model, sentence, max_length=10):
    """Greedily decode a reply to *sentence*.

    Decoding stops at the <EOS> token or after *max_length* steps; the
    generated words are joined with spaces and returned as one string.
    """
    model.eval()
    decoded_words = []
    with torch.no_grad():
        input_tensor = sentence_to_tensor(sentence).unsqueeze(1).to(device)
        input_length = [input_tensor.size(0)]
        encoder_outputs, encoder_hidden = model.encoder(input_tensor, input_length)
        # Greedy decoding starts from <SOS> and feeds each argmax back in.
        decoder_input = torch.tensor([[word2index["<SOS>"]]]).to(device)
        decoder_hidden = encoder_hidden
        for _ in range(max_length):
            decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden, encoder_outputs)
            top1 = decoder_output.argmax(1).item()
            if top1 == word2index["<EOS>"]:
                break
            decoded_words.append(index2word[top1])
            decoder_input = torch.tensor([[top1]]).to(device)
    return ' '.join(decoded_words)


print(evaluate(model, "Hi, how are you?"))

总结

这只是一个简单的示例,用于展示如何使用PyTorch构建一个基本的多轮对话系统。实际应用中,可能需要更多的数据预处理、更复杂的模型(如Transformer)、更细致的训练策略和优化技术,以及更丰富的对话数据集。希望这个示例对你有所帮助!

http://www.lryc.cn/news/395489.html

相关文章:

  • 昇思25天学习打卡营第16天 | 文本解码原理-以MindNLP为例
  • Unity之Text组件换行\n没有实现+动态中英互换
  • vue3+ el-tree 展开和折叠,默认展开第一项
  • ProFormList --复杂数据联动ProFormDependency
  • Git、Github、tortoiseGit下载安装调试全套教程
  • 老师怎么快速发布成绩?
  • 央视揭露:上百元的AI填报高考志愿真的靠谱吗?阿里云新增两位AI圈“代言人”!|AI日报
  • TPM管理咨询公司甄选指南
  • 探索 Scikit-Learn:机器学习的强大工具库
  • 音视频质量评判标准
  • 如何在vue3中使用scss
  • Gartner发布采用美国防部模型实施零信任的方法指南:七大支柱落地方法
  • Flutter——最详细(Badge)使用教程
  • SQLServer的系统数据库用别的服务器上的系统数据库替换后做跨服务器连接时出现凭证、非对称金钥或私密金钥的资料无效
  • vue前端面试
  • 【网络安全】Host碰撞漏洞原理+工具+脚本
  • unattended-upgrade进程介绍
  • SpringBoot 中多例模式的神秘世界:用法区别以及应用场景,最后的灵魂拷问会吗?- 第519篇
  • 基于STM32设计的智能婴儿床(ESP8266局域网)_2024升级版_180
  • C++(第四天----拷贝函数、类的组合、类的继承)
  • 第一课:接口配置IP地址:DHCP模式
  • esp32_spfiffs
  • 每日一练全新考试模式解锁|考试升级
  • pyqt5图片分辨率导致的界面过大的问题
  • (三)前端javascript中的数据结构之集合
  • VuePress 的更多配置
  • 问题解决|Python 代码的组织形式与编码规范
  • Flask项目搭建及部署 —— Python
  • 【C++报错已解决】Invalid Use of ‘this’ Pointer
  • 群晖NAS配置WebDav服务结合内网穿透实现跨平台云同步思源笔记