当前位置: 首页 > news >正文

利用微调的deberta-v3-large来预测情感分类

前言:

昨天我们讲述了怎么利用emotion数据集进行deberta-v3-large大模型的微调,那今天我们就来输入一些数据来测试一下,看看模型的准确率,为了方便起见,我直接用测试集的前十条数据

代码:

from transformers import AutoModelForSequenceClassification,AutoTokenizer
import torch
import numpytokenizer = AutoTokenizer.from_pretrained("deberta-v3-large")
model = AutoModelForSequenceClassification.from_pretrained("result/checkpoint-500",num_labels=6)raw_inputs = ["im feeling rather rotten so im not very ambitious right now","im updating my blog because i feel shitty","i never make her separate from me because i don t ever want her to feel like i m ashamed with her","i left with my bouquet of red and yellow tulips under my arm feeling slightly more optimistic than when i arrived","i was feeling a little vain when i did this one","i cant walk into a shop anywhere where i do not feel uncomfortable","i felt anger when at the end of a telephone call","i explain why i clung to a relationship with a boy who was in many ways immature and uncommitted despite the excitement i should have been feeling for g
etting accepted into the masters program at the university of virginia","i like to have the same breathless feeling as a reader eager to see what will happen next","i jest i feel grumpy tired and pre menstrual which i probably am but then again its only been a week and im about as fit as a walrus on vacation for thesummer"
]
inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.argmax(-1).numpy())output_tensor = torch.softmax(outputs.logits, dim=1)numpy.set_printoptions(suppress=True, precision=15)
print(output_tensor.detach().numpy())

标注结果:

[0 0 0 1 0 4 3 1 1 3]

测试结果:

[0 0 0 1 0 4 4 2 1 3]
[[0.99185866    0.0011510316  0.00038844926 0.0026896652  0.00296234010.00094986777][0.9918577     0.0011512033  0.00038886679 0.0026923663  0.00295853150.000951257  ][0.99185807    0.0011446937  0.00038163515 0.0026456509  0.00303544850.00093440723][0.00041773843 0.9972398     0.0014854104  0.0002909223  0.000362315240.00020376328][0.99185014    0.0011451623  0.00038086114 0.0026396883  0.00305240350.00093187904][0.015044774   0.0025362356  0.00041989447 0.015223678   0.950097140.016678285  ][0.11319714    0.030935207   0.007336047   0.3035547     0.475454330.069522515  ][0.0011094044  0.18334262    0.8081213     0.0011003793  0.00072979650.005596481  ][0.0004444314  0.9972433     0.0014491597  0.00028465112 0.000374119760.00020446534][0.00241266    0.00079152075 0.00092184055 0.9924028     0.00241092480.0010602956 ]]

结果对比:

除了第七、第八条数据错误外,其他的八条数据都是正确的

代码解释:

1、raw_inputs:用户输入的数据,这个地方你可以使用一个while循环,然后使用input来与用户进行交互,需要注意的是这个必须是一个数组,哪怕用户只输入了一句文本。

2、return_tensors="pt":表示tokenizer返回的是PyTorch格式的数据

3、argmax(-1):将logits属性中的浮点数张量沿着最后一个轴(即-1轴)进行argmax操作,从而找到该张量中最大值所对应的标签编号。

4、softmax(outputs.logits, dim=1):dim指沿着哪个维度计算softmax,通常指定为1,表示对每一行进行softmax操作。如果不指定,则默认在最后一维计算softmax。

5、numpy.set_printoptions(suppress=True, precision=15):使用 numpy.set_printoptions() 函数来设置打印选项,从而调整打印输出格式。其中,suppress 选项可以关闭科学计数法,precision 选项可以设置打印精度。

http://www.lryc.cn/news/157127.html

相关文章:

  • opencv旋转图像
  • 容器资料: Docker和Singularity
  • 如何确认linux的包管理器是yum还是apt,确认之后安装其他程序的时候就需要注意安装命令
  • 数据分享|R语言分析上海空气质量指数数据:kmean聚类、层次聚类、时间序列分析:arima模型、指数平滑法...
  • MySQL 8.0.34安装教程
  • 用通俗易懂的方式讲解大模型分布式训练并行技术:概述
  • NodeJS入门以及文件模块fs模块
  • springboot集成Elasticsearch7.16,使用https方式连接并忽略SSL证书
  • 【已解决】pycharm 突然每次点击都开新页面,关不掉怎么办?
  • AndroidStudio最下方显示不出来Terminal等插件
  • python基础操作笔记
  • c++ 学习 之 指针常量 和 常量指针
  • Redis未授权访问漏洞实战
  • 【web开发】2、css基础
  • 循迹小车原理介绍和代码示例
  • redis未授权访问
  • 【数学建模竞赛】优化类赛题常用算法解析
  • Python实现SSA智能麻雀搜索算法优化LightGBM回归模型(LGBMRegressor算法)项目实战
  • OpenCV(二十一):椒盐噪声和高斯噪声的产生
  • 【设计模式】Head First 设计模式——构建器模式 C++实现
  • 基于Python+Django深度学习的身份证识别考勤系统设计与实现
  • Unity控制程序退出
  • C++ using的多种用法
  • Java环境的安装
  • 【ES6】js中的__proto__和prototype
  • 工程项目管理系统源码-简洁+好用+全面-工程项目管理
  • 后端SpringBoot+前端Vue前后端分离的项目(二)
  • 【5】openGL使用宏和函数进行错误检测
  • STM32 CAN快速配置(HAL库版本)
  • 【文末送书】全栈开发流程——后端连接数据源(二)