当前位置: 首页 > article >正文

python学习打卡day47

DAY 47 注意力热图可视化

昨天代码中注意力热图的部分顺移至今天

知识点回顾:

热力图

作业:对比不同卷积层热图可视化的结果

# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):"""可视化模型的注意力热力图,展示模型关注的图像区域"""model.eval()  # 设置为评估模式with torch.no_grad():for i, (images, labels) in enumerate(test_loader):if i >= num_samples:  # 只可视化前几个样本breakimages, labels = images.to(device), labels.to(device)# 创建一个钩子,捕获中间特征图activation_maps = []def hook(module, input, output):activation_maps.append(output.cpu())# 为最后一个卷积层注册钩子(获取特征图)hook_handle = model.conv3.register_forward_hook(hook)# 前向传播,触发钩子outputs = model(images)# 移除钩子hook_handle.remove()# 获取预测结果_, predicted = torch.max(outputs, 1)# 获取原始图像img = images[0].cpu().permute(1, 2, 0).numpy()# 反标准化处理img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)# 获取激活图(最后一个卷积层的输出)feature_map = activation_maps[0][0].cpu()  # 取第一个样本# 计算通道注意力权重(使用SE模块的全局平均池化)channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]# 按权重对通道排序sorted_indices = torch.argsort(channel_weights, descending=True)# 创建子图fig, axes = plt.subplots(1, 4, figsize=(16, 4))# 显示原始图像axes[0].imshow(img)axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')axes[0].axis('off')# 显示前3个最活跃通道的热力图for j in range(3):channel_idx = sorted_indices[j]# 获取对应通道的特征图channel_map = feature_map[channel_idx].numpy()# 归一化到[0,1]channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)# 调整热力图大小以匹配原始图像from scipy.ndimage import zoomheatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))# 显示热力图axes[j+1].imshow(img)axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')axes[j+1].axis('off')plt.tight_layout()plt.show()# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

@浙大疏锦行

http://www.lryc.cn/news/2403815.html

相关文章:

  • MySQL中的内置函数
  • Ansible自动化运维全解析:从设计哲学到实战演进
  • YOLOv8n行人检测实战:从数据集准备到模型训练
  • 国标GB28181设备管理软件EasyGBS远程视频监控方案助力高效安全运营
  • 网络寻路--图论
  • LangChain4j 学习教程项目
  • 【Go语言基础【15】】数组:固定长度的连续存储结构
  • 【读论文】U-Net: Convolutional Networks for Biomedical Image Segmentation 卷积神经网络
  • Komiko 视频到视频功能炸裂上线!
  • Linux 文件系统与 I/O 编程核心原理及实践笔记
  • vite+tailwind封装组件库
  • Gin框架实战指南:从入门到进阶
  • 【Java学习笔记】包装类
  • 【高效开发工具系列】Blackmagic Disk Speed Test for Mac:专业硬盘测速工具
  • QtDBus模块功能及架构解析
  • 光学字符识别(OCR)理论概述与实践教程
  • 关键字--sizeof
  • Ubuntu20.04启动python的虚拟环境
  • 网页在线客服系统自动欢迎语实现方案(PHP+MySQL)
  • UniRig:如何在矩池云一站式解决 3D 模型绑定难题
  • 用函数实现模块化程序设计(适合考研、专升本)
  • 玩转抖音矩阵:核心玩法与高效运营规则
  • spring:继承接口FactoryBean获取bean实例
  • 字符串字典序最大后缀问题详解
  • VScode打开后一直显示正在重新激活终端 问题的解决方法
  • pe文件结构(TLS)
  • 二进制安全-OpenWrt-uBus
  • 分页查询的实现
  • 中型零售业数据库抉择:MySQL省成本,SQL SERVER?
  • 使用 Windows 完成 iOS 应用上架:Appuploader对比其他证书与上传方案