当前位置: 首页 > news >正文

利用前向勾子获取神经网络中间层的输出并将其进行保存(示例详解)

代码示例:

# 激活字典,用于保存每次的中间特征
activation = {}# 将 forward_hook 函数定义在 upsample_v2 外部
def forward_hook(name):def hook(module, input, output):activation[name] = output.detach()return hookdef upsample_v2(in_channels, out_channels, upscale, kernel_size=3):layers = []# Define mid channel stages (three times reduction)mid_channels = [256, 128, 64]  # 512 32 32 -> 256 64 64 -> 128 128 128 -> 64 256 256 -> 2 256 256scale_factor_per_step = upscale ** (1/3)  # Calculate the scaling for each stepcurrent_in_channels = in_channels# Upsample and reduce channels in 3 stepsfor step, mid_channel in enumerate(mid_channels):# Conv layer to reduce number of channelsconv = nn.Conv2d(current_in_channels, mid_channel, kernel_size=kernel_size, padding=1, bias=False)nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')layers.append(conv)# ReLU activationrelu = nn.ReLU()layers.append(relu)# Upsampling layerup = nn.Upsample(scale_factor=scale_factor_per_step, mode='bilinear', align_corners=True)layers.append(up)layers[-1].register_forward_hook(forward_hook(f'step_{step}'))# Update current in_channels for the next layercurrent_in_channels = mid_channelconv = nn.Conv2d(current_in_channels, out_channels, kernel_size=kernel_size, padding=1, bias=False)nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')layers.append(conv)return nn.Sequential(*layers)
def forward_hook(name):def hook(module, input, output):activation[name] = output.detach()return hook

forward_hook布置了抓取函数。其中,module代表你下面勾的那一层,input代表那一层的输入,output定义那一层的输出,我们常常只使用output。

layers[-1].register_forward_hook(forward_hook(f'step_{step}'))

这里定义了我需要捕获的那一层,layers[-1]代表我要捕获当前layers的最后一层,即上采用层,由于循环了三次,所以最后勾取的应当是三份中间层输出。

http://www.lryc.cn/news/469966.html

相关文章:

  • CTF-RE 从0到N: S盒
  • MT-Pref数据集:包含18种语言的18k实例,涵盖多个领域。实验表明它能有效提升Tower模型在WMT23和FLORES基准测试中的翻译质量。
  • 【C++ 真题】B2099 矩阵交换行
  • AAPL: Adding Attributes to Prompt Learning for Vision-Language Models
  • MySQLDBA修炼之道-开发篇(一)
  • Spring MVC 知识点全解析
  • python 基于FastAPI实现一个简易的在线用户统计 服务
  • glibc中xdr的一个bug
  • Android Framework定制sim卡插入解锁pin码的界面
  • cc2530 Basic RF 讲解 和点灯讲解(1_1)
  • Android H5页面性能分析策略
  • 【前端面试】Typescript
  • 程序语言的内存管理:垃圾回收GC(Java)、手动管理(C语言)与所有权机制(Rust)(手动内存管理、手动管理内存)
  • 研究生论文学习记录
  • 毕业设计选题:基于Django+Vue的图书馆管理系统
  • #网络安全#NGSOC与传统SOC的区别
  • GCN+BiLSTM多特征输入时间序列预测(Pytorch)
  • LinkedList和链表之刷题课(下)
  • ollama 在 Linux 环境的安装
  • C语言二刷指针篇
  • LeetCode题练习与总结:回文对--336
  • CesiumJS 案例 P7:添加指定长宽的图片图层(原点分别为图片图层的中心点、左上角顶点、右上角顶点、左下角顶点、右下角顶点)
  • Redis 主从同步 问题
  • 【SQL Server】探讨 IN 和 EXISTS之间的区别
  • 清理pip和conda缓存
  • git rebase和merge的区别
  • 【elkb】linux麒麟v10安装ELKB 8.8.X版本(ARM架构)
  • bluez hid host介绍,连接键盘/鼠标/手柄不是梦,安排
  • GPT打数模——电商品类货量预测及品类分仓规划
  • 华为OD机试 - 螺旋数字矩阵 - 矩阵(Python/JS/C/C++ 2024 D卷 100分)