当前位置: 首页 > news >正文

【AI大模型】BERT微调文本分类任务实战

本文将详细指导你如何使用BERT模型微调进行文本分类任务,涵盖从环境配置到模型部署的完整流程。

环境配置

首先安装必要的库:

pip install transformers[torch] datasets pandas numpy scikit-learn matplotlib wandb

完整代码实现

import torch
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report
from transformers import (BertTokenizer,BertForSequenceClassification,TrainingArguments,Trainer,EarlyStoppingCallback
)
import matplotlib.pyplot as plt
import wandb# 初始化Weights & Biases(可选)
wandb.init(project="bert-text-classification", name="bert-base-uncased")# 1. 数据集准备
def load_custom_dataset():"""加载自定义数据集"""# 示例:加载CSV文件(实际使用时替换为你的数据路径)df = pd.read_csv("your_dataset.csv")# 数据集应包含'text'和'label'列# 确保标签为整数格式(0,1,2,...)# 划分训练集、验证集、测试集train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)# 转换为Hugging Face数据集格式train_dataset = Dataset.from_pandas(train_df)val_dataset = Dataset.from_pandas(val_df)test_dataset = Dataset.from_pandas(test_df)return train_dataset, val_dataset, test_dataset# 加载公开数据集(示例使用IMDB影评)
def load_public_dataset():dataset = load_dataset("imdb")return dataset["train"], dataset["test"], dataset["unsupervised"]  # 使用unsupervised作为验证集# 选择数据集来源
# train_dataset, val_dataset, test_dataset = load_custom_dataset()
train_dataset, test_dataset, val_dataset = load_public_dataset()# 2. 数据预处理
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)# 确定类别数量(从数据集中获取)
num_labels = len(set(train_dataset["label"]))def preprocess_function(examples):"""预处理函数:分词、截断、填充"""return tokenizer(examples["text"],max_length=256,truncation=True,padding="max_length",return_tensors="pt")# 应用预处理
encoded_train = train_dataset.map(preprocess_function, batched=True)
encoded_val = val_dataset.map(preprocess_function, batched=True)
encoded_test = test_dataset.map(preprocess_function, batched=True)# 3. 模型初始化
model = BertForSequenceClassification.from_pretrained(model_name,num_labels=num_labels,output_attentions=False,output_hidden_states=False
)# 4. 训练参数配置
training_args = TrainingArguments(output_dir="./results",          # 输出目录num_train_epochs=3,              # 训练轮数per_device_train_batch_size=16,  # 训练批次大小per_device_eval_batch_size=64,   # 评估批次大小learning_rate=2e-5,              # 学习率weight_decay=0.01,               # 权重衰减evaluation_strategy="epoch",     # 每轮评估save_strategy="epoch",           # 每轮保存logging_dir="./logs",            # 日志目录logging_steps=100,               # 每100步记录日志load_best_model_at_end=True,     # 训练结束时加载最佳模型metric_for_best_model="f1",      # 使用F1分数选择最佳模型report_to="wandb",               # 报告到Weights & Biasesfp16=True,                       # 使用混合精度训练(如果GPU支持)
)# 5. 评估指标计算
def compute_metrics(p):"""计算评估指标"""predictions, labels = ppredictions = np.argmax(predictions, axis=1)acc = accuracy_score(labels, predictions)f1 = f1_score(labels, predictions, average="weighted")# 完整分类报告(可选)if len(set(labels)) <= 10:  # 类别较少时显示完整报告print("\nClassification Report:")print(classification_report(labels, predictions))return {"accuracy": acc, "f1": f1}# 6. 训练器设置
trainer = Trainer(model=model,args=training_args,train_dataset=encoded_train,eval_dataset=encoded_val,compute_metrics=compute_metrics,callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]  # 早停策略
)# 7. 模型训练
print("开始训练BERT模型...")
train_result = trainer.train()# 保存训练指标
metrics = train_result.metrics
trainer.save_metrics("train", metrics)
trainer.save_model("./best_model")# 8. 模型评估
print("\n在测试集上评估模型...")
test_metrics = trainer.evaluate(encoded_test)
print(f"测试集性能: {test_metrics}")# 9. 可视化训练过程
def plot_training_metrics(log_history):"""绘制训练指标图表"""train_loss, eval_loss, eval_f1 = [], [], []steps = []for entry in log_history:if "loss" in entry and "epoch" in entry:train_loss.append(entry["loss"])steps.append(entry["step"])elif "eval_loss" in entry:eval_loss.append(entry["eval_loss"])eval_f1.append(entry["eval_f1"])plt.figure(figsize=(12, 10))# 训练损失plt.subplot(2, 1, 1)plt.plot(steps[:len(train_loss)], train_loss, 'b-', label="Training Loss")plt.plot(steps[len(train_loss)-len(eval_loss):], eval_loss, 'r-', label="Validation Loss")plt.title("Training & Validation Loss")plt.xlabel("Training Steps")plt.ylabel("Loss")plt.legend()plt.grid(True)# F1分数plt.subplot(2, 1, 2)plt.plot(steps[len(train_loss)-len(eval_f1):], eval_f1, 'g-')plt.title("Validation F1 Score")plt.xlabel("Training Steps")plt.ylabel("F1 Score")plt.grid(True)plt.tight_layout()plt.savefig("./training_metrics.png")plt.show()# 绘制指标图表
plot_training_metrics(trainer.state.log_history)# 10. 模型推理示例
def predict(text):"""使用训练好的模型进行预测"""inputs = tokenizer(text,max_length=256,truncation=True,padding="max_length",return_tensors="pt")# 移动到GPU(如果可用)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")inputs = {k: v.to(device) for k, v in inputs.items()}model.to(device)# 预测with torch.no_grad():outputs = model(**inputs)# 获取预测结果logits = outputs.logitsprobabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]predicted_class = torch.argmax(logits, dim=1).item()return predicted_class, probabilities# 示例预测
sample_text = "This movie was absolutely fantastic! The acting was superb and the storyline captivating."
predicted_class, probabilities = predict(sample_text)
print(f"\n示例文本: '{sample_text}'")
print(f"预测类别: {predicted_class}")
print(f"类别概率: {probabilities}")# 完成Weights & Biases记录
wandb.finish()

关键步骤详解

1. 数据集准备

  • 支持加载自定义CSV数据集(需包含"text"和"label"列)

  • 也支持加载Hugging Face公开数据集(如IMDB)

  • 自动划分训练集(70%)、验证集(15%)、测试集(15%)

2. 数据预处理

  • 使用BERT的分词器(Tokenizer)处理文本

  • 设置最大长度256(可根据需求调整)

  • 自动截断和填充保证统一长度

  • 将文本转换为模型可接受的输入格式

3. 模型初始化

  • 加载预训练的bert-base-uncased模型

  • 添加分类头部,输出维度等于类别数量

  • 自动从数据集中推断类别数量

4. 训练配置

  • 学习率:2e-5(BERT微调的常用学习率)

  • 批次大小:训练16/评估64(根据GPU显存调整)

  • 训练轮数:3(可增加至5-10轮以获得更好效果)

  • 早停机制:验证集性能连续2轮不提升时停止训练

5. 评估指标

  • 主要评估指标:准确率(Accuracy)和F1分数

  • 输出完整分类报告(当类别数≤10时)

  • 支持多类和多标签分类任务

6. 训练过程可视化

  • 实时记录训练损失和验证损失

  • 绘制训练过程中的指标变化

  • 支持Weights & Biases在线监控(可选)

7. 模型部署与推理

  • 保存最佳模型到./best_model目录

  • 提供predict()函数用于单文本预测

  • 输出预测类别及各类别概率

优化建议

性能提升技巧

  1. 学习率调度:添加学习率warmup和余弦衰减

    warmup_ratio=0.1,
    lr_scheduler_type="cosine",

  2. 分层学习率:BERT底层使用较小学习率

    optimizer = AdamW([{"params": model.bert.parameters(), "lr": 1e-5},{"params": model.classifier.parameters(), "lr": 2e-5}]
    )

  3. 数据增强:提升小数据集性能

    • 同义词替换

    • 随机插入/删除

    • 回译(Back Translation)

处理不平衡数据

from torch import nn# 计算类别权重
class_counts = np.bincount(train_dataset["label"])
class_weights = 1. / class_counts
class_weights = torch.tensor(class_weights, dtype=torch.float32)# 自定义损失函数
class WeightedCrossEntropyLoss(nn.Module):def __init__(self, weights):super().__init__()self.weights = weightsdef forward(self, inputs, targets):ce_loss = nn.CrossEntropyLoss(reduction="none")(inputs, targets)weights = self.weights[targets]return (ce_loss * weights).mean()# 在Trainer中设置
trainer = Trainer(...compute_metrics=compute_metrics,callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],loss_function=WeightedCrossEntropyLoss(class_weights)
)

常见问题解决

内存不足问题

  1. 减小批次大小(per_device_train_batch_size

  2. 使用梯度累积:

    gradient_accumulation_steps=4

  3. 启用梯度检查点:

    model.gradient_checkpointing_enable()

  4. 使用混合精度训练:

    fp16=True

过拟合问题

  1. 增加Dropout概率:

    model = BertForSequenceClassification.from_pretrained(model_name,num_labels=num_labels,hidden_dropout_prob=0.3,  # 默认0.1attention_probs_dropout_prob=0.3
    )

  2. 增加权重衰减:

    weight_decay=0.1

  3. 添加更多正则化技术:

    label_smoothing_factor=0.1

进阶扩展

多语言分类

# 使用多语言BERT模型
model_name = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

领域自适应

# 先继续预训练(MLM任务)在领域数据上
from transformers import BertForMaskedLMmlm_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# 在领域数据上训练MLM模型...
# 然后使用领域适应的权重初始化分类模型
model = BertForSequenceClassification.from_pretrained("./domain_adapted_mlm", num_labels=num_labels)

通过本指南,你可以高效地微调BERT模型解决各类文本分类问题。根据具体任务需求调整参数和数据处理方式,通常只需少量训练即可获得优异性能。

http://www.lryc.cn/news/585461.html

相关文章:

  • spring boot 详解以及原理
  • 力扣-141.环形链表
  • 力扣_二叉搜索树_python版本
  • leetcode 3169. 无需开会的工作日 中等
  • 在 Spring Boot 中优化长轮询(Long Polling)连接频繁建立销毁问题
  • 100G系列光模块产品与应用场景介绍
  • 7.12 卷积 | 最小生成树 prim
  • ICCV2025接收论文速览(1)
  • python之set详谈
  • 基于图神经网络的社交网络影响力预测模型
  • 【操作系统】 Linux 系统调用(一)
  • 线程通信与进程通信的区别笔记
  • c++11——左值、右值、完美转发、移动语义
  • 注意力机制十问
  • JavaAI时代:重塑企业级智能开发新范式
  • slam全局路径规划算法详解(Dijkstra、A*)
  • 【软考高项】信息系统项目管理师-第2章 信息技术发展(2.1 计算机软硬件)
  • Leaflet面试题及答案(21-40)
  • PLC框架-1.3- 汇川PN伺服(3号报文)
  • 全球化 2.0 | 印尼金融科技公司通过云轴科技ZStack实现VMware替代
  • 在HTML中CSS三种使用方式
  • Vue + Element UI 实现选框联动进而动态控制选框必填
  • WebSocket 重连与心跳机制:打造坚如磐石的实时连接
  • 千辛万苦3面却倒在性格测试?这太离谱了吧!
  • 飞算JavaAI:重塑Java开发的“人机协同“新模式
  • Mani-GS 运行指南
  • Cursor、飞算JavaAI、GitHub Copilot、Gemini CLI 等热门 AI 开发工具合集
  • django queryset 去重
  • Nginx服务器集群:横向扩展与集群解决方案
  • 巨人网络持续加强AI工业化管线,Lovart国内版有望协同互补