当前位置: 首页 > news >正文

learn掩码张量

目录

1、什么是掩码张量

2、掩码张量的作用

3、代码演示

(1)、定义一个上三角矩阵,k=0或者 k默认为 0

(2)、k=1

(3)、k=-1

4、掩码张量代码实现

(1)、输出效果

(2)、输出效果分析


1、什么是掩码张量

  • 掩就是代表遮掩,码就是张量中的数值,它的尺寸不定,里面只有 1 和 0 的元素,代表的位置被遮掩或者不被遮掩,至于是 0 位置被遮掩还是 1 位置被遮掩可以自己定义,因此它的作用就是让另外一个张量中的数值被遮掩,也可以说成是被替换,它的表现形式是一个张量

2、掩码张量的作用

  • 在transformers中,掩码张量的主要作用应用在 attention时,有一些生成的attention张量中的值计算有可能已知了未来信息而得到的,未来信息被看到是因为训练时会把整个输出结果都一次性进行 Embedding,但是理论上解码器的输出却不是一次就能产生最终结果的,而是一次次通过上次结果综合得出的。因此,未来的信息可能被提前利用,所以,我们会进行遮掩

3、代码演示

(1)、定义一个上三角矩阵,k=0或者 k默认为 0

attn_shape = (1,3,3) # 定义掩码张量的形状
sub_mask = np.triu(np.ones(attn_shape), k = 0).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
print(sub_mask)

[[[1 1 1]
  [0 1 1]
  [0 0 1]]]

(2)、k=1

attn_shape = (1,3,3) # 定义掩码张量的形状
sub_mask = np.triu(np.ones(attn_shape), k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
print(sub_mask)

[[[0 1 1]
  [0 0 1]
  [0 0 0]]]

(3)、k=-1

attn_shape = (1,3,3) # 定义掩码张量的形状
sub_mask = np.triu(np.ones(attn_shape), k = -1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
print(sub_mask)

[[[1 1 1]
  [1 1 1]
  [0 1 1]]]

4、掩码张量代码实现

import numpy as np
import torch
def subsequent_mask(size):""":param size: 生成向后遮掩的掩码张量,参数 size 是掩码张量的最后两个维度大小,它的最后两个维度形成一个方阵:return:"""attn_shape = (1,size,size) # 定义掩码张量的形状subsequent_mask = np.triu(np.ones(attn_shape),k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形return torch.from_numpy(1 - subsequent_mask) # 先将numpy 类型转化为 tensor,再做三角的翻转,将位置为 0 的地方变为 1,将位置为 1 的方变为 0
size = 5
sm = subsequent_mask(size)
print("sm :",sm)
# 掩码张量的可视化
import matplotlib.pyplot as plt
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])

(1)、输出效果

(2)、输出效果分析

  • 通过观察可视化方阵,黄色是 1 的部分,这里代表被遮掩,紫色代表没有被遮掩的信息,横坐标代表目标词汇的位置,纵坐标代表可查看的位置
  • 我们看到,在 0 的位置我们以看望过去都是黄色的,都被遮掩了,1的位置一眼望过去还是黄色,说明第一次词还没有产生,从第二个位置看过去,就能看到位置 1 的词,其他位置看不到,以此类推

http://www.lryc.cn/news/182150.html

相关文章:

  • 激活函数介绍
  • docker方式启动一个java项目-Nginx本地有代码,并配置反向代理
  • 前端和后端是Web开发选哪个好?
  • HTTP协议,请求响应
  • idea配置文件属性提示消息解决方案
  • EdgeView 4 for Mac:重新定义您的图像查看体验
  • 流程自动化(RPA)的好处有哪些?
  • 医学影像系统【简称PACS】源码
  • 大家都在用哪些敏捷开发项目管理软件?
  • python机器学习基础教程01-环境搭建
  • TinyWebServer学习笔记-Config
  • 数据结构与算法--算法
  • JVM:如何通俗的理解并发的可达性分析
  • 传统机器学习聚类算法——总集篇
  • Ajax
  • SQL_ERROR_INFO: “Duplicate entry ‘9003‘ for key ‘examination_info.exam_id‘“
  • 解决每次重启ganache虚拟环境,十个账号秘钥都会改变问题
  • sheng的学习笔记-【中文】【吴恩达课后测验】Course 2 - 改善深层神经网络 - 第一周测验
  • (粗糙的笔记)动态规划
  • Kaggle - LLM Science Exam上:赛事概述、数据收集、BERT Baseline
  • 数据分析三剑客之一:Numpy详解及实战
  • 【C语言】函数的定义、传参与调用(二)
  • Sentinel安装
  • 【JVM】并发可达性分析-三色标记算法
  • 黑豹程序员-架构师学习路线图-百科:Git/Gitee(版本控制)
  • 《Jetpack Compose从入门到实战》第一章 全新的 Android UI 框架
  • 基于Spring Boot的中小型医院网站的设计与实现
  • uniapp iOS离线打包——如何创建App并提交版本审核?
  • 论文笔记:Contrastive Trajectory Similarity Learning withDual-Feature Attention
  • 整数和字符串比较的坑