
【Text Classification】BERT Binary Classification

A complete fine-tuning walkthrough for binary text classification with BERT: a custom Dataset, training and evaluation loops, saving and reloading the model, and inference on new texts.

import os
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed in recent versions; use PyTorch's
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
# Custom dataset
class CustomDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }
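As a quick sanity check, each dataset item is a dict of fixed-length tensors. A minimal sketch of inspecting one item; the "bert-base-uncased" tokenizer here is only an assumed stand-in, and any BERT-compatible tokenizer behaves the same way:

from transformers import BertTokenizer

check_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed stand-in checkpoint
check_ds = CustomDataset(["hello world"], [1], check_tokenizer, max_length=128)
item = check_ds[0]
print(item["input_ids"].shape)       # torch.Size([128]) -- padded/truncated to max_length
print(item["attention_mask"].shape)  # torch.Size([128])
print(item["label"])                 # tensor(1)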
# Training function
def train_model(model, train_loader, optimizer, device, num_epochs=3):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1} Loss: {total_loss / len(train_loader)}")
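For longer runs, BERT fine-tuning is commonly stabilized with a linear warmup/decay schedule and gradient clipping. A hedged sketch of how the loop above could be extended; the 10% warmup and max_norm=1.0 are conventional choices, not part of the original script:

import torch
from transformers import get_linear_schedule_with_warmup

def train_model_scheduled(model, train_loader, optimizer, device, num_epochs=3):
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),  # assumed 10% warmup
        num_training_steps=total_steps,
    )
    model.train()
    for epoch in range(num_epochs):
        for batch in train_loader:
            outputs = model(
                batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["label"].to(device),
            )
            optimizer.zero_grad()
            outputs.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip exploding gradients
            optimizer.step()
            scheduler.step()  # update the learning rate after every optimizer step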
# Evaluation function
def evaluate_model(model, val_loader, device):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    report = classification_report(true_labels, predictions)
    print(f"Validation Accuracy: {accuracy}")
    print("Classification Report:")
    print(report)
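When per-class errors matter more than overall accuracy, a confusion matrix is a useful complement to the classification report. A small optional sketch using the same scikit-learn package already imported above:

from sklearn.metrics import confusion_matrix

def print_confusion(true_labels, predictions):
    # Rows are true labels, columns are predictions;
    # for binary 0/1 labels the layout is [[tn, fp], [fn, tp]].
    print(confusion_matrix(true_labels, predictions))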
# Model saving function
def save_model(model, tokenizer, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")
# Model loading function
def load_model(output_dir, device):
    tokenizer = BertTokenizer.from_pretrained(output_dir)
    model = BertForSequenceClassification.from_pretrained(output_dir)
    model.to(device)
    print(f"Model loaded from {output_dir}")
    return model, tokenizer
# Inference/prediction function
def predict(texts, model, tokenizer, device, max_length=128):
    model.eval()
    encodings = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()
        predictions = torch.argmax(logits, dim=1).cpu().numpy()
    return predictions, probabilities
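predict tokenizes the entire input list in one shot, which is fine for a handful of texts but can exhaust memory on large inputs. A hedged sketch of a chunked wrapper around it; the batch_size parameter is an addition, not in the original:

def predict_batched(texts, model, tokenizer, device, max_length=128, batch_size=32):
    all_preds, all_probs = [], []
    # Feed the model fixed-size chunks instead of the whole list at once
    for i in range(0, len(texts), batch_size):
        preds, probs = predict(texts[i:i + batch_size], model, tokenizer, device, max_length)
        all_preds.extend(preds)
        all_probs.extend(probs)
    return all_preds, all_probs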
# Main function
def main():
    # Configuration
    config = {
        "train_batch_size": 16,
        "val_batch_size": 16,
        "learning_rate": 5e-5,
        "num_epochs": 5,
        "max_length": 128,
        "device_id": 7,  # GPU ID to use
        "model_dir": "model",
        "local_model_path": "roberta_tiny_model",  # local model path; if None, use the pretrained model below
        "pretrained_model_name": "uer/chinese_roberta_L-12_H-128",  # pretrained model name
    }

    # Set the device
    device = torch.device(f"cuda:{config['device_id']}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer and model: prefer the local checkpoint, fall back to the pretrained model
    model_path = config["local_model_path"] or config["pretrained_model_name"]
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)
    model.to(device)

    # Example data
    train_texts = ["This is a great product!", "I hate this service."]
    train_labels = [1, 0]
    val_texts = ["Awesome experience.", "Terrible product."]
    val_labels = [1, 0]

    # Create datasets and data loaders
    train_dataset = CustomDataset(train_texts, train_labels, tokenizer, config["max_length"])
    val_dataset = CustomDataset(val_texts, val_labels, tokenizer, config["max_length"])
    train_loader = DataLoader(train_dataset, batch_size=config["train_batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["val_batch_size"])

    # Define the optimizer
    optimizer = AdamW(model.parameters(), lr=config["learning_rate"])

    # Train the model
    train_model(model, train_loader, optimizer, device, num_epochs=config["num_epochs"])

    # Evaluate the model
    evaluate_model(model, val_loader, device)

    # Save the model
    save_model(model, tokenizer, config["model_dir"])

    # Reload the saved model (on CPU here) and run inference
    loaded_model, loaded_tokenizer = load_model(config["model_dir"], "cpu")
    new_texts = ["I love this!", "It's the worst."]
    predictions, probabilities = predict(new_texts, loaded_model, loaded_tokenizer, "cpu")
    for text, pred, prob in zip(new_texts, predictions, probabilities):
        print(f"Text: {text}")
        print(f"Predicted Label: {pred} (Probability: {prob})")


if __name__ == "__main__":
    main()
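To run end to end, save the listing as a script (a hypothetical bert_binary_cls.py) with a BERT-style checkpoint available at roberta_tiny_model/, or set "local_model_path" to None to download the configured "uer/chinese_roberta_L-12_H-128" checkpoint from the Hugging Face Hub, then launch it with python bert_binary_cls.py. Note that the toy dataset of two sentences per split only exercises the pipeline; it will not produce a meaningful classifier.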