当前位置: 首页 > news >正文

从零构建属于自己的GPT系列4:模型训练3(训练过程解读、序列填充函数、损失计算函数、评价函数、代码逐行解读)

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

6 序列填充函数

def collate_fn(batch):input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=5)labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100)return input_ids, labels

7 损失计算函数

def caculate_loss(logit, target, pad_idx, smoothing=True):if smoothing:logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2))target = target[..., 1:].contiguous().view(-1)eps = 0.1n_class = logit.size(-1)one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1)one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)log_prb = F.log_softmax(logit, dim=1)non_pad_mask = target.ne(pad_idx)loss = -(one_hot * log_prb).sum(dim=1)loss = loss.masked_select(non_pad_mask).mean()  # average laterelse:# loss = F.cross_entropy(predict_logit, target, ignore_index=pad_idx)logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))labels = target[..., 1:].contiguous().view(-1)loss = F.cross_entropy(logit, labels, ignore_index=pad_idx)return loss

8 评价函数

def calculate_acc(logit, labels, ignore_index=-100):logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))labels = labels[..., 1:].contiguous().view(-1)_, logit = logit.max(dim=-1)  # 对于每条数据,返回最大的index# 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1non_pad_mask = labels.ne(ignore_index)n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item()n_word = non_pad_mask.sum().item()return n_correct, n_word

9 训练过程解读

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

http://www.lryc.cn/news/256151.html

相关文章:

  • 光学遥感显著目标检测初探笔记总结
  • HttpComponents: 领域对象的设计
  • 使用wire重构商品微服务
  • 大三上实训内容
  • IOT安全学习路标
  • java中线程的状态是如何转换的?
  • 处理合并目录下的Excel文件数据并指定列去重
  • Numpy数组的去重 np.unique()(第15讲)
  • ROS-log功能区别
  • 学习git后,真正在项目中如何使用?
  • Qt国际化翻译Linguist使用
  • ShardingSphere数据分片之分表操作
  • 基于ssm鲸落文化线上体验馆论文
  • LeetCode Hot100 131.分割回文串
  • SAP UI5 walkthrough step9 Component Configuration
  • 【数据结构和算法】--- 栈
  • CentOS7.0 下rpm安装MySQL5.5.60
  • 智慧能源:数字孪生压缩空气储能管控平台
  • 【链表OJ—反转链表】
  • TCP一对一聊天
  • 基于Java的招聘系统的设计与实现
  • spring boot整合mybatis进行部门管理管理的增删改查
  • 微软 Power Platform 零基础 Power Pages 网页搭建高阶实际案例实践(四)
  • 如何在任何STM32上面安装micro_ros
  • 肖sir__ 项目讲解__项目数据
  • 微服务实战系列之J2Cache
  • 12.ROS导航模块:gmapping、AMCL、map_server、move_base案例
  • C++中string类的使用
  • LeeCode每日刷题12.8
  • 硕士毕业论文格式修改要点_word