当前位置: 首页 > news >正文

实现pytorch注意力机制-one demo

主要组成部分:

1. 定义注意力层

定义一个Attention_Layer类,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。

2. 定义前向传播:

定义了注意力层的前向传播过程,包括计算注意力权重和输出。

3. 数据准备

生成一个随机的数据集,包含3个句子,每个句子10个词,每个词128个特征。

4. 实例化注意力层:

实例化一个注意力层,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。

5. 前向传播

将数据传递给注意力层的前向传播方法。

6. 分析结果

获取第一个句子的注意力权重。

7. 可视化注意力权重

使用matplotlib库可视化了注意力权重。

**主要函数和类:**
Attention_Layer类:定义了注意力层的结构和前向传播过程。
forward方法:定义了注意力层的前向传播过程。
torch.from_numpy函数:将numpy数组转换为PyTorch张量。
matplotlib库:用于可视化注意力权重。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt# 定义注意力层
class Attention_Layer(nn.Module):def __init__(self, hidden_dim, is_bi_rnn):super(Attention_Layer,self).__init__()self.hidden_dim = hidden_dimself.is_bi_rnn = is_bi_rnnif is_bi_rnn:self.Q_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)self.K_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)self.V_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)else:self.Q_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)self.K_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)self.V_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)def forward(self, inputs, lens):# 获取输入的大小size = inputs.size()Q = self.Q_linear(inputs) K = self.K_linear(inputs).permute(0, 2, 1)V = self.V_linear(inputs)max_len = max(lens)sentence_lengths = torch.Tensor(lens)mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]mask = mask.unsqueeze(dim = 1)mask = mask.expand(size[0], max_len, max_len)padding_num = torch.ones_like(mask)padding_num = -2**31 * padding_num.float()alpha = torch.matmul(Q, K)alpha = torch.where(mask, alpha, padding_num)alpha = F.softmax(alpha, dim = 2)out = torch.matmul(alpha, V)return out# 准备数据
data = np.random.rand(3, 10, 128)  # 3个句子,每个句子10个词,每个词128个特征
lens = [7, 10, 4]  # 每个句子的长度# 实例化注意力层
hidden_dim = 64
is_bi_rnn = True
att_L = Attention_Layer(hidden_dim, is_bi_rnn)# 前向传播
att_out = att_L(torch.from_numpy(data).float(), lens)# 分析结果
attention_weights = att_out[0, :, :].detach().numpy()  # 获取第一个句子的注意力权重# 可视化注意力权重
plt.imshow(attention_weights, cmap='hot', interpolation='nearest')
plt.colorbar()
plt.show()

在这里插入图片描述

http://www.lryc.cn/news/538825.html

相关文章:

  • 深入Flask:如何优雅地处理HTTP请求与响应
  • JVM ②-双亲委派模型 || 垃圾回收GC
  • jQuery介绍(快速、简洁JavaScript库,诞生于2006年,主要目标是简化HTML文档操作、事件处理、动画和Ajax交互)
  • python旅游推荐系统+爬虫+可视化(协同过滤算法)
  • Ubuntu 22.04.5 LTS 安装企业微信,(2025-02-17安装可行)
  • 【Excel笔记_6】条件格式和自定义格式设置表中数值超过100保留1位,超过1000保留0位,低于100为默认
  • UDP与TCP
  • Web开发技术概述
  • 解压rar格式的软件有哪些?8种方法(Win/Mac/手机/网页端)
  • uniapp开发:首次进入 App 弹出隐私协议窗口
  • 执行pnpm run dev报错:node:events:491 throw er; // Unhandled ‘error‘ event的解决方案
  • OpenCV机器学习(4)k-近邻算法(k-Nearest Neighbors, KNN)cv::ml::KNearest类
  • JVM中的线程池详解:原理→实践
  • SNARKs 和 UTXO链的未来
  • JavaScript设计模式 -- 外观模式
  • 百达翡丽(Patek Philippe):瑞士制表的巅峰之作(中英双语)
  • 阿里云一键部署DeepSeek-V3、DeepSeek-R1模型
  • 分享一款AI绘画图片展示和分享的小程序
  • 【练习】【双指针】力扣热题100 283. 移动零
  • QT 互斥锁
  • 什么是算法的空间复杂度和时间复杂度,分别怎么衡量。
  • VMware Workstation 17.0 Pro创建虚拟机并安装Ubuntu22.04与ubuntu20.04(双版本同时存在)《包含小问题总结》
  • Windows 10 ARM工控主板CAN总线实时性能测试
  • 如何在不依赖函数调用功能的情况下结合工具与大型语言模型
  • 【Linux AnolisOS】关于Docker的一系列问题。尤其是拉取东西时的网络问题,镜像源问题。
  • 【Elasticsearch】Mapping概述
  • GPT-4o悄然升级:能力与个性双突破,AI竞技场再掀波澜
  • 如何选择合适的超参数来训练Bert和TextCNN模型?
  • C# SpinLock 类 使用详解
  • 【linux】在 Linux 上部署 DeepSeek-r1:32/70b:解决下载中断问题