当前位置: 首页 > news >正文

BERT模型基本原理及实现示例

  BERT(Bidirectional Encoder Representations from Transformers)是Google在2018年提出的预训练语言模型,其核心思想是通过双向Transformer结构捕捉上下文信息,为下游NLP任务提供通用的语义表示。

一、模型架构

  BERT基于Transformer的编码器(Encoder)堆叠而成,摒弃了解码器(Decoder)。每个Encoder层包含:

  自注意力机制(Self-Attention):计算输入序列中每个词与其他词的关系权重,动态聚合上下文信息。
前馈神经网络(FFN):对注意力输出进行非线性变换。
残差连接与层归一化:缓解深层网络训练中的梯度消失问题。

  与传统单向语言模型(如GPT)不同,BERT通过同时观察左右两侧的上下文(双向注意力)捕捉词语的完整语义。

二、预训练任务

  BERT通过以下两个无监督任务预训练模型:

1.遮蔽语言模型(Masked Language Model, MLM)

  训练过程中,输入句子中一部分词会被 (MASK) 标记替换,模型需根据上下文信息预测这些被遮蔽的词。这种任务迫使模型在训练时同时考虑文本前后信息,学习更丰富的语言表征。具体操作是,随机选择 15% 的词汇用于预测,其中 80% 情况下用 (MASK) 替换,10% 情况下用任意词替换,10% 情况下保持原词汇不变。

2.下一句预测(Next Sentence Prediction, NSP)

  旨在训练模型理解句子间的连贯性。训练时,模型接收一对句子作为输入,判断两个句子是否是连续的文本序列。通过该任务,模型能学习到句子乃至篇章层面的语义信息。在实际预训练中,会从文本语料库中随机选择 50% 正确语句对和 50% 错误语句对进行训练。

三、输入表示

  BERT的输入由三部分嵌入相加组成:

  Token Embeddings:词向量(WordPiece分词)。
Segment Embeddings:区分句子A和B(用于NSP任务)。
Position Embeddings:Transformer本身无位置感知,需显式加入位置编码。

四、微调(Fine-tuning)

  预训练后,BERT可通过简单的微调适配下游任务:

  分类任务(如情感分析):用(CLS)标记的输出向量接分类层。
序列标注(如NER):用每个Token的输出向量预测标签。
问答任务:用两个向量分别预测答案的起止位置。

  微调时只需添加少量任务特定层,大部分参数复用预训练模型。

五、Python实现示例

(环境:Python 3.11,paddle 1.0.2, paddlenlp 2.6.1)

import paddle
from paddlenlp.transformers import BertTokenizer, BertForSequenceClassification# 1. 加载预训练模型和分词器
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_classes=2)  # 2分类任务# 2. 准备示例数据 (正面和负面情感文本)
texts = ["这部电影太棒了,演员表演出色!", "非常糟糕的体验,完全不推荐。"]
labels = [1, 0]  # 1表示正面,0表示负面# 3. 数据预处理
encoded_inputs = tokenizer(texts, max_length=128, padding=True, truncation=True, return_tensors='pd')
input_ids = encoded_inputs['input_ids']
token_type_ids = encoded_inputs['token_type_ids']# 转换为Paddle张量
labels = paddle.to_tensor(labels)# 4. 模型前向计算
outputs = model(input_ids, token_type_ids=token_type_ids)
logits = outputs# 5. 计算损失和预测
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)# 获取预测结果
predictions = paddle.argmax(logits, axis=1)# 打印结果
print("Loss:", loss.item())
print("Predictions:", predictions.numpy())
print("True labels:", labels.numpy())

在这里插入图片描述



End.

http://www.lryc.cn/news/584745.html

相关文章:

  • 强化学习 MDP
  • 从代码生成到智能运维的革命性变革
  • 集成平台业务编排设计器
  • 在虚拟机中安装Linux系统
  • 下一代防火墙-终端安全防护
  • 【数据结构】顺序表(sequential list)
  • Python3邮件发送全指南:文本、HTML与附件
  • 力扣61.旋转链表
  • 【会员专享数据】2013-2024年我国省市县三级逐日SO₂数值数据(Shp/Excel格式)
  • 【Linux基础命令使用】VIM编辑器的使用
  • WinUI3入门17:本地文件存储LocalApplicationData在哪里
  • 企业数据开发治理平台选型:13款系统优劣对比
  • Building Bridges(搭建桥梁)
  • HVV注意事项(个人总结 非技术)
  • 在VMware中安装虚拟机
  • 数据结构 --- 队列
  • XCZU47DR-2FFVG1517I Xilinx FPGA AMD ZynqUltraScale+ RFSoC
  • 超声波刻刀适用于一些对切割精度要求高、材料厚度较薄或质地较软的场景,典型应用场景如下:
  • 测试开发和后端开发到底怎么选?
  • UGF开发记录_3_使用Python一键转换Excle表格为Txt文本
  • 穿梭时空的智慧向导:Deepoc具身智能如何赋予导览机器人“人情味”
  • Qt中处理多个同类型对象共享槽函数应用
  • 广州华锐互动在各领域打造的 VR 成功案例展示​
  • pycharm无法识别pip安装的包
  • 【佳易王中药材划价软件】:让中药在线管理高效化、复制文本即可识别金额自动计算#中药房管理工具#快速划价系统#库存与账单一体化解决方案,软件程序操作教程详解
  • 多线程 JAVA
  • MySQL索引操作全指南:创建、查看、优化
  • H5微应用四端调试工具—网页版:深入解析与使用指南
  • 7月10号总结 (1)
  • C++ Lambda 表达式详解