当前位置: 首页 > news >正文

自然语言推断:微调BERT

微调BERT

自然语言推断任务设计了一个基于注意力的结构。现在,我们通过微调BERT来重新审视这项任务。自然语言推断是一个序列级别的文本对分类问题,而微调BERT只需要一个额外的基于多层感知机的架构,如下图中所示。

本节将下载一个预训练好的小版本的BERT,然后对其进行微调,以便在SNLI数据集上进行自然语言推断。

import json
import multiprocessing
import os
from mxnet import gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2lnpx.set_np()

加载预训练的BERT

原始的BERT模型有数以亿计的参数。在下面,我们提供了两个版本的预训练的BERT:“bert.base”与原始的BERT基础模型一样大,需要大量的计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip','225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip','c72329e68a732bef0452e4b96a1c341c8910f81f')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,num_heads, num_layers, dropout, max_len, devices):data_dir = d2l.download_extract(pretrained_model)# 定义空词表以加载预定义词表vocab = d2l.Vocab()vocab.idx_to_token = json.load(open(os.path.join(data_dir,'vocab.json')))vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,num_heads=4, num_layers=2, dropout=0.2,max_len=max_len, key_size=256, query_size=256,value_size=256, hid_in_features=256,mlm_in_features=256, nsp_in_features=256)# 加载预训练BERT参数bert.load_state_dict(torch.load(os.path.join(data_dir,'pretrained.params')))return bert, vocab

为了便于在大多数机器上演示,我们将在本节中加载和微调经过预训练BERT的小版本(“bert.small”)。在练习中,我们将展示如何微调大得多的“bert.base”以显著提高测试精度。

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model('bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,num_layers=2, dropout=0.1, max_len=512, devices=devices)

 

 

 

http://www.lryc.cn/news/288655.html

相关文章:

  • 立创EDA学习:设计收尾工作
  • ShardingSphere之ShardingJDBC客户端分库分表上
  • rust for循环步长-1,反向逆序遍历
  • 编译与运行环境(C语言)
  • 再谈Android View绘制流程
  • 分布式定时任务系列8:XXL-job源码分析之远程调用
  • python+Qt5 UOS 摄相头+麦克风测试,摄相头自动解析照片二维条码,麦克风解析音频文件
  • MongoDB日期存储与查询、@Query、嵌套字段查询实战总结
  • Windows版本Node.js常见问题及操作解决方式(小白入门必备)
  • 09.Elasticsearch应用(九)
  • ROS2常用命令工具
  • Linux之快速入门
  • C语言——操作符详解1
  • C++学习| QT快速入门
  • Android App开发-简单控件(1)——文本显示
  • [GYCTF2020]Ezsqli1
  • 【npm包】如何发布自己的npm包
  • 《WebKit技术内幕》学习之十五(2):Web前端的未来
  • 【教学类-综合练习-11】20240116 大4班 最后一次
  • 【阻塞队列】阻塞队列的模拟实现及在生产者和消费者模型上的应用
  • Cocos Creator使用VS Code调试代码配置
  • 【投稿优惠|EI优质会议】2024年材料化学与清洁能源国际学术会议(IACMCCE 2024)
  • ubuntu设置右键打开terminator、code
  • PHP AES加解密:用代码为数据加上保护的盾牌
  • Socket实现服务器和客户端
  • 智能GPT图书管理系统(SpringBoot2+Vue2)、接入GPT接口,支持AI智能图书馆
  • 面试经典 150 题 ---- 合并两个有序数组
  • 防火墙在企业园区出口安全方案中的应用(ENSP实现)
  • 单片机学习笔记---矩阵键盘密码锁
  • 8-小程序数据promise化、共享、分包