当前位置: 首页 > news >正文

21、Transformer Masked loss原理精讲及其PyTorch逐行实现

1. Transformer结构图

在这里插入图片描述

2. python

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0batch_size = 2seq_length = 3vocab_size = 4logits = torch.randn(batch_size,seq_length,vocab_size)print(f"logits=\n{logits}")logits_t = logits.transpose(-1,-2)print(f"logits_t=\n{logits_t}")label = torch.randint(0,vocab_size,(batch_size,seq_length))print(f"label=\n{label}")result_none = F.cross_entropy(logits_t,label,reduction="none")print(f"result_none=\n{result_none}")result_none_mean = torch.mean(result_none)result_mean = F.cross_entropy(logits_t,label)print(f"result_mean=\n{result_mean}")print(f"result_none_mean={result_none_mean}")
logits=
tensor([[[ 0.477,  2.017,  1.016, -0.299],[-0.189,  0.321, -0.885,  1.418],[ 0.027, -0.606,  0.079, -0.491]],[[ 1.911,  1.643, -0.327,  0.185],[-0.031, -1.463, -0.073,  1.391],[-0.710,  0.811,  1.521,  0.033]]])
logits_t=
tensor([[[ 0.477, -0.189,  0.027],[ 2.017,  0.321, -0.606],[ 1.016, -0.885,  0.079],[-0.299,  1.418, -0.491]],[[ 1.911, -0.031, -0.710],[ 1.643, -1.463,  0.811],[-0.327, -0.073,  1.521],[ 0.185,  1.391,  0.033]]])
label=
tensor([[0, 0, 0],[3, 0, 0]])
result_none=
tensor([[2.059, 2.098, 1.157],[2.444, 1.848, 2.832]])
result_mean=
2.0730881690979004
result_none_mean=2.0730881690979004
http://www.lryc.cn/news/520848.html

相关文章:

  • 构建高性能网络服务:从 Socket 原理到 Netty 应用实践
  • Spring Boot教程之五十六:用 Apache Kafka 消费 JSON 消息
  • Elasticsearch ES|QL 地理空间索引加入纽约犯罪地图
  • csp-j知识点:联合(Union)的基本概念
  • docker-compose 方式安装部署confluence
  • 深入理解计算机系统阅读笔记-第十二章
  • 网络原理(九):数据链路层 - 以太网协议 应用层 - DNS 协议
  • rtthread学习笔记系列(4/5/6/7/15/16)
  • 【拒绝算法PUA】3065. 超过阈值的最少操作数 I
  • 今日总结 2025-01-14
  • 关于扫描模型 拓扑 和 传递贴图工作流笔记
  • C#知识|泛型Generic概念与方法
  • centos 8 中安装Docker
  • vscode vue 自动格式化
  • Webpack 5 混淆插件terser-webpack-plugin生命周期作用时机和使用注意事项
  • MQTT(Message Queuing Telemetry Transport)协议
  • 【MySQL学习笔记】MySQL存储过程
  • Vue2+OpenLayers实现折线绘制、起始点标记和轨迹打点的完整功能(提供Gitee源码)
  • 基于Spring Boot的城市垃圾分类管理系统设计与实现(LW+源码+讲解)
  • linux: 文本编辑器vim
  • Eclipse Debug 调试
  • vue3+ts的<img :src=““ >写法
  • 《心血管成像的深度学习》论文精读
  • RDP、VNC、SSH 三种登陆方式的差异解析
  • 3d 可视化库 vister部署笔记
  • 操作系统八股文学习笔记
  • k8s基础(6)—Kubernetes-存储
  • K8S--配置存活、就绪和启动探针
  • 永久免费工业设备日志采集
  • 详解 Docker 启动 Windows 容器第二篇:技术原理与未来发展方向