当前位置: 首页 > news >正文

简化的动态稀疏视觉Transformer的PyTorch代码

存一串代码(简化的动态稀疏视觉Transformer的PyTorch代码)


import torch 
import torch.nn  as nn 
import torch.nn.functional  as F class DynamicSparseAttention(nn.Module): def __init__(self, dim, num_heads=8, dropout=0.1): super().__init__() self.num_heads  = num_heads self.head_dim  = dim // num_heads self.scale  = self.head_dim  ** -0.5 self.qkv  = nn.Linear(dim, dim * 3, bias=False) self.attn_drop  = nn.Dropout(dropout) self.proj  = nn.Linear(dim, dim) self.proj_drop  = nn.Dropout(dropout) def forward(self, x): B, N, C = x.shape  qkv = self.qkv(x).reshape(B,  N, 3, self.num_heads,  self.head_dim).permute(2,  0, 3, 1, 4) q, k, v = qkv.unbind(0)  attn = (q @ k.transpose(-2,  -1)) * self.scale  attn = attn.softmax(dim=-1)  attn = self.attn_drop(attn)  x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x)  x = self.proj_drop(x)  return x class HierarchicalRoutingBlock(nn.Module): def __init__(self, dim, num_heads=8, mlp_ratio=4., dropout=0.1): super().__init__() self.norm1  = nn.LayerNorm(dim) self.attn  = DynamicSparseAttention(dim, num_heads, dropout) self.norm2  = nn.LayerNorm(dim) self.mlp  = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(), nn.Dropout(dropout), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(dropout) ) def forward(self, x): x = x + self.attn(self.norm1(x))  x = x + self.mlp(self.norm2(x))  return x class DynamicSparseVisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768, num_heads=8, depth=12, mlp_ratio=4., dropout=0.1): super().__init__() self.patch_embed  = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size) self.pos_embed  = nn.Parameter(torch.zeros(1,  (img_size // patch_size) ** 2, dim)) self.dropout  = nn.Dropout(dropout) self.blocks  = nn.ModuleList([HierarchicalRoutingBlock(dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]) self.norm  = nn.LayerNorm(dim) self.head  = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x): x = self.patch_embed(x).flatten(2).transpose(1,  2) x = x + self.pos_embed  x = self.dropout(x)  for blk in self.blocks:  x = blk(x) x = self.norm(x)  x = x[:, 0] x = self.head(x)  return x # 使用 
model = DynamicSparseVisionTransformer() 
x = torch.randn(1,  3, 224, 224) 
output = model(x) 
print(output.shape)  

代码解释
DynamicSparseAttention:实现动态稀疏注意力模块。
HierarchicalRoutingBlock:实现层次化路由块,包含注意力模块和多层感知机。
DynamicSparseVisionTransformer:实现完整的动态稀疏视觉Transformer模型,包括补丁嵌入、位置嵌入、层次化路由块和分类头。

http://www.lryc.cn/news/535503.html

相关文章:

  • PADS多层板减少层数
  • 你需要提供管理员权限才能删除此文件夹解决方法
  • 螺旋折线(蓝桥杯18G)
  • 常见的数据仓库有哪些?
  • 数据科学之数据管理|NumPy数据管
  • LSTM 学习笔记 之pytorch调包每个参数的解释
  • ASUS/华硕飞行堡垒9 FX506H FX706H 原厂Win10系统 工厂文件 带ASUS Recovery恢复
  • Unity使用iTextSharp导出PDF-04图形
  • JDBC如何连接数据库
  • Unity URP的2D光照简介
  • 【IC】AI处理器核心--第二部分 用于处理 DNN 的硬件设计
  • 从 0 开始本地部署 DeepSeek:详细步骤 + 避坑指南 + 构建可视化(安装在D盘)
  • 如何本地部署DeepSeek集成Word办公软件
  • Centos10 Stream 基础配置
  • 时间序列分析(三)——白噪声检验
  • ThinkPHP8视图赋值与渲染
  • 对贵司需求的PLC触摸的远程调试的解决方案
  • 2.12寒假作业
  • 记使用AScript自动化操作ios苹果手机
  • 【Apache Paimon】-- 16 -- 利用 paimon-flink-action 同步 kafka 数据到 hive paimon 表中
  • 基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
  • 算法之 数论
  • Java 大视界 -- 人工智能驱动下 Java 大数据的技术革新与应用突破(83)
  • 【04】RUST特性
  • PlantUml常用语法
  • 保存字典类型的文件用什么格式比较好
  • 开源模型应用落地-Qwen1.5-MoE-A2.7B-Chat与vllm实现推理加速的正确姿势(一)
  • 一竞技瓦拉几亚S4预选:YB 2-0击败GG
  • deepseek+kimi一键生成PPT
  • mybatis 是否支持延迟加载?延迟加载的原理是什么?