
[NLP Beginner Series, Part 4] An Introductory Text Classification Example


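  • 🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp
  • 🍖 Original author: K同学啊

About the blogger: a hard-working undergraduate from the class of 2022 🌟, exploring AI algorithms, C++, and Go, and searching for light amid uncertainty 🌸
Blog homepage: 羊小猪~~ (CSDN blog)
Summary: an introductory NLP project that uses the AG News dataset as its example.
🌸Motto🌸: Go in search of your ideal "Castle in the Sky"
Previous post: [NLP Beginner Series, Part 3] NLP Text Embeddings (Using Embedding and EmbeddingBag as Examples), CSDN blog
💁💁💁💁: If you set up your environment with conda, note that the core NLP package here is torchtext, and each torchtext release is built against one specific PyTorch version. If you can't get the versions to match, just create a fresh virtual environment (a lesson I learned the hard way). Also note that torchtext is no longer actively developed: 0.18 (April 2024) is its last stable release.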

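Contents

    • 1. Preparation
      • Data loading
      • Building the vocabulary
    • 2. Generating data batches and iterators
    • 3. Defining the model
      • Model definition
      • Creating the model
    • 4. Training and evaluation functions
      • Training function
      • Evaluation function
      • Hyperparameters
    • 5. Training the model
    • 6. Results
    • 7. Prediction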

🤔 Approach

[Figure: overview of the approach]

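1. Preparation

AG News (also called the AG News Dataset) is a text classification dataset that is widely used in natural language processing (NLP). It is derived from AG's corpus of news articles.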


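Basic information:

  • Full name: AG's News Topic Classification Dataset
  • Source: built from AG's corpus of news articles, which was collected by the ComeToMyHead academic news search engine. The 4-class classification version was constructed by Xiang Zhang et al. (2015).
  • Purpose: mainly short-text, multi-class classification
  • Language: English
  • Number of classes: 4 news topics
  • Training samples: 120,000 (30,000 per class)
  • Test samples: 7,600 (1,900 per class)

Class labels (4 in total)

Label | Meaning
1 | World
2 | Sports
3 | Business
4 | Science and Technology (Sci/Tech)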
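In code, the 1-based labels above must become 0-based class indices, because `nn.CrossEntropyLoss` expects targets in the range 0 to 3 here. The small pure-Python sketch below makes that convention explicit. `CLASS_NAMES`, `to_class_index`, and `to_class_name` are illustrative names assumed for this example and are not part of the original code. The post's `label_pipline` and its prediction step follow the same mapping.

```python
# Hypothetical helpers (not part of the original code) for the AG News label convention.
CLASS_NAMES = ["World", "Sports", "Business", "Science and Technology"]


def to_class_index(label: int) -> int:
    """Map a 1-based AG News label (1..4) to a 0-based class index (0..3)."""
    if not 1 <= label <= len(CLASS_NAMES):
        raise ValueError(f"unexpected AG News label: {label}")
    return label - 1


def to_class_name(class_index: int) -> str:
    """Map a 0-based class index (e.g. a model's argmax) back to its topic name."""
    return CLASS_NAMES[class_index]


for label in range(1, 5):
    print(label, "->", to_class_index(label), to_class_name(to_class_index(label)))
```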

Data loading

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Check the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
# Load the local data
train_df = pd.read_csv("./data/train.csv")
test_df = pd.read_csv("./data/test.csv")

# Merge the title and description into a single text field
train_df["text"] = train_df["Title"] + " " + train_df["Description"]
test_df["text"] = test_df["Title"] + " " + test_df["Description"]

# Inspect the data
train_df
index | Class Index | Title | Description | text
0 | 3 | Wall St. Bears Claw Back Into the Black (Reuters) | Reuters - Short-sellers, Wall Street's dwindli... | Wall St. Bears Claw Back Into the Black (Reute...
1 | 3 | Carlyle Looks Toward Commercial Aerospace (Reu... | Reuters - Private investment firm Carlyle Grou... | Carlyle Looks Toward Commercial Aerospace (Reu...
2 | 3 | Oil and Economy Cloud Stocks' Outlook (Reuters) | Reuters - Soaring crude prices plus worries\ab... | Oil and Economy Cloud Stocks' Outlook (Reuters...
3 | 3 | Iraq Halts Oil Exports from Main Southern Pipe... | Reuters - Authorities have halted oil export\f... | Iraq Halts Oil Exports from Main Southern Pipe...
4 | 3 | Oil prices soar to all-time record, posing new... | AFP - Tearaway world oil prices, toppling reco... | Oil prices soar to all-time record, posing new...
... | ... | ... | ... | ...
119995 | 1 | Pakistan's Musharraf Says Won't Quit as Army C... | KARACHI (Reuters) - Pakistani President Perve... | Pakistan's Musharraf Says Won't Quit as Army C...
119996 | 2 | Renteria signing a top-shelf deal | Red Sox general manager Theo Epstein acknowled... | Renteria signing a top-shelf deal Red Sox gene...
119997 | 2 | Saban not going to Dolphins yet | The Miami Dolphins will put their courtship of... | Saban not going to Dolphins yet The Miami Dolp...
119998 | 2 | Today's NFL games | PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ... | Today's NFL games PITTSBURGH at NY GIANTS Time...
119999 | 2 | Nets get Carter from Raptors | INDIANAPOLIS -- All-Star Vince Carter was trad... | Nets get Carter from Raptors INDIANAPOLIS -- A...

120000 rows × 4 columns

Building the vocabulary

# Define the Dataset
class AGNewsDataset(Dataset):
    def __init__(self, dataframe):
        self.labels = dataframe['Class Index'].tolist()  # convert the columns to Python lists
        self.texts = dataframe['text'].tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.labels[idx], self.texts[idx]

# Load the data
train_dataset = AGNewsDataset(train_df)
test_dataset = AGNewsDataset(test_df)

# Build the vocabulary
tokenizer = get_tokenizer("basic_english")  # English data, so use the basic English tokenizer

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)  # feed each token list to the vocabulary builder

# Build the vocabulary; unknown words map to the index of "<unk>"
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

print("Vocab size:", len(vocab))
Vocab size: 95804
# Look up the vocabulary indices of these words
vocab(['here', 'is', 'an', 'example'])  
[475, 21, 30, 5297]
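Conceptually, `build_vocab_from_iterator` counts how often each token occurs and assigns integer indices. It places the special tokens (here `<unk>`) first. `set_default_index` then makes any unseen token fall back to the `<unk>` index. The dictionary-based sketch below illustrates that idea without torchtext. `build_vocab` and `lookup` are hypothetical helpers. Breaking frequency ties alphabetically is a choice made for this sketch, not a documented torchtext guarantee.

```python
from collections import Counter
from typing import Iterable


def build_vocab(token_lists: Iterable[list[str]], specials=("<unk>",)) -> dict[str, int]:
    """Assign indices: special tokens first, then tokens by descending frequency
    (ties broken alphabetically in this sketch)."""
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok, _ in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def lookup(vocab: dict[str, int], tokens: list[str], default: str = "<unk>") -> list[int]:
    """Like set_default_index: tokens missing from the vocab map to the index of `default`."""
    unk = vocab[default]
    return [vocab.get(tok, unk) for tok in tokens]


corpus = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "end"]]
v = build_vocab(corpus)
print(v)
print(lookup(v, ["the", "sat", "zebra"]))
```

Here `zebra` never appears in the corpus, so it maps to index 0. In the same way, unknown words map to `<unk>` in the torchtext vocabulary above.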
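'''
Labels: the 'Class Index' values are the integers 1 to 4, but nn.CrossEntropyLoss expects 0-based class indices, so subtract 1.
Text: tokenize, then map each token to its vocabulary index (vocab).
'''
text_pipline = lambda x: vocab(tokenizer(x))  # tokenize first, then convert tokens to indices
label_pipline = lambda x: int(x) - 1          # map labels 1..4 to class indices 0..3

# Example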
text_pipline('here is the an example')
[475, 21, 2, 30, 5297]
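The first step of `text_pipline` is the `basic_english` tokenizer, which lowercases the text and splits punctuation off into separate tokens. The regex sketch below approximates that behavior so the tokenization step can be seen on its own. `basic_english_like` is a hypothetical name. torchtext's actual rules differ in details (for example in how apostrophes are handled), so treat this as an illustration rather than a drop-in replacement.

```python
import re

# Approximation of a "basic_english"-style tokenizer (NOT torchtext's exact rules).
_PUNCT = re.compile(r"([.,!?;:()])")


def basic_english_like(text: str) -> list[str]:
    """Lowercase, isolate common punctuation marks as tokens, then split on whitespace."""
    text = text.lower()
    text = text.replace('"', "")       # drop double quotes
    text = _PUNCT.sub(r" \1 ", text)   # surround each punctuation mark with spaces
    return text.split()                # split on any run of whitespace


print(basic_english_like("Wall St. Bears Claw Back Into the Black (Reuters)"))
```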

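2. Generating data batches and iterators

# EmbeddingBag consumes a batch as one flat tensor of token ids plus offsets,
# so collate_batch builds the labels, the concatenated text, and the offsets
'''
Batch format:
text:    1-D tensor, the token ids of all samples in the batch concatenated
labels:  1-D tensor, shape (batch_size,)
offsets: 1-D tensor, shape (batch_size,), the start position of each sample in text
'''
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        # Label list: convert each label to a 0-based class index
        label_list.append(label_pipline(_label))
        # Text list: convert each sample's token ids to a tensor
        temp_text = torch.tensor(text_pipline(_text), dtype=torch.int64)
        text_list.append(temp_text)
        # Record the sample length; turned into offsets below
        offsets.append(temp_text.size(0))
    # Convert everything to tensors
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    # Cumulative sum of [0, len_1, ..., len_{n-1}] gives each sample's start index
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    return label_list.to(device), text_list.to(device), offsets.to(device)

# Data loading
batch_size = 16
train_dl = DataLoader(train_dataset,
                      batch_size=batch_size,
                      shuffle=False,  # shuffle=True is the usual choice for training; the run logged below used False
                      collate_fn=collate_batch)
test_dl = DataLoader(test_dataset,
                     batch_size=batch_size,
                     shuffle=False,
                     collate_fn=collate_batch)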
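The offset logic in `collate_batch` is easiest to see on plain lists. The samples are concatenated into one flat sequence, and each offset is the total length of all samples before it. The sketch below computes the same flat list and offsets without PyTorch. The `pack_batch` name is hypothetical.

```python
from itertools import accumulate


def pack_batch(token_id_lists: list[list[int]]) -> tuple[list[int], list[int]]:
    """Concatenate variable-length samples and compute EmbeddingBag-style offsets.

    offsets[i] is where sample i starts in the flat list, i.e. the running sum
    of the lengths of all previous samples.
    """
    flat = [tok for ids in token_id_lists for tok in ids]
    lengths = [len(ids) for ids in token_id_lists]
    offsets = [0] + list(accumulate(lengths))[:-1]
    return flat, offsets


flat, offsets = pack_batch([[5, 8, 2], [7, 1], [3, 3, 9, 4]])
print(flat)
print(offsets)
```

With lengths 3, 2 and 4, the offsets are 0, 3 and 5. This is what `torch.tensor(offsets[:-1]).cumsum(dim=0)` produces from the list `[0, 3, 2, 4]` built in `collate_batch`.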

3. Defining the model

Model definition

class TextModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embeddingBag = nn.EmbeddingBag(vocab_size,  # vocabulary size
                                            embed_dim,   # embedding dimension
                                            sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    # Initialize the weights
    def init_weights(self):
        initrange = 0.5
        self.embeddingBag.weight.data.uniform_(-initrange, initrange)  # uniform init in [-0.5, 0.5]
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()  # set the bias to 0

    def forward(self, text, offsets):
        # EmbeddingBag (default mode='mean') averages the embeddings of each sample's tokens
        embedding = self.embeddingBag(text, offsets)
        return self.fc(embedding)
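Here is what `forward` does concretely. With its default `mode='mean'`, `nn.EmbeddingBag` looks up the embedding row for every token of a sample and averages those rows. The linear layer then maps that average to one score per class. The pure-Python sketch below mirrors those two steps on a toy table with 4 tokens and 2 dimensions. The function names and numbers are made up for illustration, and the real layers operate on batched tensors.

```python
def embedding_bag_mean(weight, flat_ids, offsets):
    """Pure-Python analogue of nn.EmbeddingBag(mode="mean").

    Bag i covers flat_ids[offsets[i]:offsets[i + 1]] (the last bag runs to the end);
    its output is the element-wise mean of the matching rows of `weight`.
    Empty bags produce a zero vector.
    """
    dim = len(weight[0])
    bounds = list(offsets) + [len(flat_ids)]
    pooled = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [weight[t] for t in flat_ids[start:end]]
        if not rows:
            pooled.append([0.0] * dim)
            continue
        pooled.append([sum(r[j] for r in rows) / len(rows) for j in range(dim)])
    return pooled


def linear(x, W, b):
    """y = W x + b, with W stored as (out_features, in_features) like nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


# Toy embedding table: 4 tokens, embed_dim = 2 (values chosen for illustration)
weight = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
bags = embedding_bag_mean(weight, flat_ids=[1, 2, 3], offsets=[0, 2])
print(bags)
```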

Creating the model

# Check the number of samples per class
train_df.groupby('Class Index').count()
Class Index | Title | Description | text
1 | 30000 | 30000 | 30000
2 | 30000 | 30000 | 30000
3 | 30000 | 30000 | 30000
4 | 30000 | 30000 | 30000

class_num = 4
vocab_len = len(vocab)
embed_dim = 64  # embed each sample into a 64-dimensional vector
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=class_num).to(device=device)
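Almost all of this model's parameters are in the embedding table. Assuming the vocabulary size of 95,804 printed earlier, the count works out as follows. The helper name is hypothetical.

```python
def count_text_model_params(vocab_size: int, embed_dim: int, num_class: int) -> int:
    """Trainable parameters of TextModel: EmbeddingBag weight + Linear weight + Linear bias."""
    embedding = vocab_size * embed_dim        # one embed_dim vector per vocabulary entry
    fc = embed_dim * num_class + num_class    # Linear weight matrix plus bias
    return embedding + fc


print(count_text_model_params(95804, 64, 4))
```

That is 6,131,456 embedding weights plus 256 + 4 = 260 parameters in the linear layer, for 6,131,716 in total. On the real model, `sum(p.numel() for p in model.parameters())` should report the same total.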

4. Training and evaluation functions

Training function

def train(model, dataloader, optimizer, loss_fn):
    size = len(dataloader.dataset)  # number of samples
    num_batch = len(dataloader)     # number of batches
    train_acc = 0
    train_loss = 0
    for label, text, offset in dataloader:
        label, text, offset = label.to(device), text.to(device), offset.to(device)
        predict_label = model(text, offset)
        loss = loss_fn(predict_label, label)
        # Backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_acc += (predict_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batch
    return train_acc, train_loss
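For each batch, `train` adds the number of correct `argmax` predictions and the batch's loss value to running totals. `nn.CrossEntropyLoss` (default `reduction='mean'`) computes that loss as the average of -log softmax(logits)[target] over the batch. The pure-Python sketch below reproduces both quantities for a tiny hand-made batch. `cross_entropy` and `batch_metrics` are illustrative names, not PyTorch functions.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def batch_metrics(batch_logits: list[list[float]], targets: list[int]) -> tuple[int, float]:
    """Return (number of correct argmax predictions, mean cross-entropy loss),
    the two per-batch quantities that train() accumulates."""
    correct = sum(
        max(range(len(z)), key=z.__getitem__) == t for z, t in zip(batch_logits, targets)
    )
    mean_loss = sum(cross_entropy(z, t) for z, t in zip(batch_logits, targets)) / len(targets)
    return correct, mean_loss


correct, mean_loss = batch_metrics([[2.0, 0.0, 0.0, 0.0], [0.0, 3.0, 0.0, 0.0]], targets=[0, 2])
print(correct, round(mean_loss, 3))
```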

Evaluation function

def test(model, dataloader, loss_fn):
    size = len(dataloader.dataset)  # number of samples
    num_batch = len(dataloader)     # number of batches
    test_acc, test_loss = 0, 0
    with torch.no_grad():  # no gradients are needed during evaluation
        for label, text, offset in dataloader:
            label, text, offset = label.to(device), text.to(device), offset.to(device)
            predict = model(text, offset)
            loss = loss_fn(predict, label)
            test_acc += (predict.argmax(1) == label).sum().item()
            test_loss += loss.item()
    test_acc /= size
    test_loss /= num_batch
    return test_acc, test_loss
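Both `train` and `test` divide the summed correct predictions by the number of samples, but divide the summed batch losses by the number of batches. The two loss averages are equal only when every batch has the same size. Here 120,000 / 16 = 7,500 and 7,600 / 16 = 475, so every batch is full. The sketch below shows both averages side by side. The `aggregate` helper is hypothetical.

```python
def aggregate(batch_results: list[tuple[int, float, int]]) -> tuple[float, float, float]:
    """batch_results holds (num_correct, mean_batch_loss, batch_size) for each batch.

    Returns (accuracy, loss averaged over batches as in train()/test(),
    loss averaged over individual samples).
    """
    n_samples = sum(size for _, _, size in batch_results)
    accuracy = sum(c for c, _, _ in batch_results) / n_samples
    loss_per_batch = sum(l for _, l, _ in batch_results) / len(batch_results)
    loss_per_sample = sum(l * size for _, l, size in batch_results) / n_samples
    return accuracy, loss_per_batch, loss_per_sample


print(aggregate([(14, 0.3, 16), (12, 0.5, 16)]))  # equal batch sizes
print(aggregate([(14, 0.3, 16), (1, 0.9, 2)]))    # short final batch
```

With a short final batch, the per-batch average gives that batch as much weight as a full one, so it can differ from the per-sample average.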

Hyperparameters

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01)  # multiply the LR by 0.01 every 5 calls to scheduler.step()
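`StepLR` multiplies the learning rate by `gamma` once every `step_size` calls to `scheduler.step()`. After n calls this gives the closed form `base_lr * gamma ** (n // step_size)`. The sketch below evaluates that formula for this post's settings (base LR 0.5, `step_size=5`, `gamma=0.01`). `step_lr` is a hypothetical helper, not part of PyTorch.

```python
def step_lr(base_lr: float, num_steps: int, step_size: int = 5, gamma: float = 0.01) -> float:
    """Learning rate after `num_steps` calls to scheduler.step() under StepLR."""
    return base_lr * gamma ** (num_steps // step_size)


for n in (0, 4, 5, 9, 10):
    print(n, f"{step_lr(0.5, n):.2E}")
```

So the LR stays at 0.5 for the first four steps, drops to 0.005 at the fifth, and only reaches 5e-5 at the tenth.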

5. Training the model

import copy

epochs = 10
train_acc, train_loss, test_acc, test_loss = [], [], [], []
best_acc = 0

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    if epoch_test_acc > best_acc:
        # Adjust the learning rate. Note: the scheduler only steps when test accuracy
        # improves, so the LR drops after the 5th improvement, not necessarily the 5th epoch
        scheduler.step()
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)  # keep a copy of the best model so far

    # Current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# Save the best model's parameters to a file
path = './best_model.pth'
torch.save(best_model.state_dict(), path)  # save the model parameters (state_dict)
Epoch: 1, Train_acc:79.9%, Train_loss:0.562, Test_acc:86.9%, Test_loss:0.392, Lr:5.00E-01
Epoch: 2, Train_acc:89.7%, Train_loss:0.313, Test_acc:88.9%, Test_loss:0.346, Lr:5.00E-01
Epoch: 3, Train_acc:91.2%, Train_loss:0.269, Test_acc:89.6%, Test_loss:0.329, Lr:5.00E-01
Epoch: 4, Train_acc:92.0%, Train_loss:0.243, Test_acc:89.8%, Test_loss:0.319, Lr:5.00E-01
Epoch: 5, Train_acc:92.6%, Train_loss:0.224, Test_acc:90.2%, Test_loss:0.315, Lr:5.00E-03
Epoch: 6, Train_acc:93.3%, Train_loss:0.207, Test_acc:90.6%, Test_loss:0.297, Lr:5.00E-03
Epoch: 7, Train_acc:93.4%, Train_loss:0.204, Test_acc:90.7%, Test_loss:0.295, Lr:5.00E-03
Epoch: 8, Train_acc:93.4%, Train_loss:0.203, Test_acc:90.7%, Test_loss:0.294, Lr:5.00E-03
Epoch: 9, Train_acc:93.4%, Train_loss:0.202, Test_acc:90.8%, Test_loss:0.293, Lr:5.00E-03
Epoch:10, Train_acc:93.4%, Train_loss:0.201, Test_acc:90.7%, Test_loss:0.293, Lr:5.00E-03
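Because `scheduler.step()` sits inside the `if epoch_test_acc > best_acc:` branch, the scheduler counts improvements rather than epochs. In the log, the LR drops at epoch 5 because test accuracy improved in each of the first five epochs. It is still 5.00E-03 at epoch 10 because fewer than ten improvements occurred (test accuracy fell from 90.8% to 90.7% in that epoch). The sketch below replays that logic. `simulate_lr` is a hypothetical helper, and the accuracy lists are invented inputs, not the logged values.

```python
def simulate_lr(test_accs: list[float], base_lr: float = 0.5,
                step_size: int = 5, gamma: float = 0.01) -> list[float]:
    """Replay the training loop's logic: StepLR's step() is called only when the
    epoch's test accuracy beats the best so far. Returns the LR printed each epoch."""
    best_acc, steps, printed = 0.0, 0, []
    for acc in test_accs:
        if acc > best_acc:
            steps += 1          # scheduler.step() only on improvement
            best_acc = acc
        printed.append(base_lr * gamma ** (steps // step_size))
    return printed


print(simulate_lr([0.80, 0.85, 0.86, 0.87, 0.88, 0.88, 0.89]))  # steady improvement
print(simulate_lr([0.80] * 6))                                  # accuracy stalls after epoch 1
```

This is the opposite of the usual "reduce on plateau" heuristic. Under this condition the LR decays fastest while the model is still improving, and it never decays once accuracy stalls. For comparison, PyTorch's text classification tutorial calls `scheduler.step()` when validation accuracy does not improve, and `torch.optim.lr_scheduler.ReduceLROnPlateau` implements a plateau-based schedule directly.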

6. Results

import matplotlib.pyplot as plt
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")               # ignore warning messages
plt.rcParams['font.sans-serif']    = ['SimHei'] # font that can render Chinese labels
plt.rcParams['axes.unicode_minus'] = False      # render minus signs correctly with that font
plt.rcParams['figure.dpi']         = 100        # resolution

epoch_length = range(1, epochs + 1)  # epochs numbered 1..10, matching the training log

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuracy')
plt.plot(epoch_length, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.show()


[Figure: training and test accuracy (left) and loss (right) per epoch]

7. Prediction

model.load_state_dict(torch.load("./best_model.pth", map_location=device))
model.eval()  # switch to evaluation mode

# Test sentence
test_sentence = "This is a news about Technology"

# Convert to token ids
token_ids = vocab(tokenizer(test_sentence))  # tokenize, then map tokens to vocabulary indices
text = torch.tensor(token_ids, dtype=torch.long).to(device)  # convert to a tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)     # a single sample starting at position 0

# Inference: no gradients are needed
with torch.no_grad():
    output = model(text, offsets)
    predicted_label = output.argmax(1).item()

# Print the result
class_names = ["World", "Sports", "Business", "Science and Technology"]
print(f"Predicted class: {class_names[predicted_label]}")
Predicted class: Science and Technology
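The prediction above keeps only the `argmax`. To also see how confident the model is, the logits can be turned into probabilities with softmax (in PyTorch, `torch.softmax(output, dim=1)`). The pure-Python sketch below shows the computation on a made-up logit vector. The numbers are hypothetical, not output from the trained model.

```python
import math

CLASS_NAMES = ["World", "Sports", "Business", "Science and Technology"]


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: exponentiate shifted logits and normalize to sum to 1."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.1, -1.2, 0.4, 2.3]  # hypothetical model output for one sentence
probs = softmax(logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(CLASS_NAMES[best], round(probs[best], 3))
```

Softmax never changes which class has the highest score, so the predicted label is the same as with `argmax` on the raw logits.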