当前位置: 首页 > news >正文

DAY 42 Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例
# 定义一个存储梯度的列表
conv_gradients = []# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):# 模块:当前应用钩子的模块# grad_input:模块输入的梯度# grad_output:模块输出的梯度print(f"反向钩子被调用!模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度供后续分析conv_gradients.append((grad_input, grad_output))# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()# 释放钩子
hook_handle.remove()
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
# 选择一个随机图像
# idx = np.random.randint(len(testset))
idx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()# print("Grad-CAM可视化完成。已保存为grad_cam_result.png")

http://www.lryc.cn/news/621389.html

相关文章:

  • 一文讲透Go语言并发模型
  • PHP现代化全栈开发:实时应用与WebSockets实践
  • PIDGenRc函数中lpstrRpc的由来和InitializePidVariables函数的关系
  • 技术速递|通过 GitHub Models 在 Actions 中实现项目自动化
  • 状态管理、网络句柄、功能组和功能组状态的逻辑关系
  • 提升工作效率的利器:GitHub Actions Checkout V5
  • 【力扣56】合并区间
  • Linux软件编程(四)多任务与多进程管理
  • CMake进阶: externalproject_add用于在构建阶段下载、配置、构建和安装外部项目
  • Google Gemini 的深度研究终于进入 API 阶段
  • 入门概述(面试常问)
  • CodeTop 复习
  • C#WPF实战出真汁01--项目介绍
  • C++入门自学Day11-- List类型的自实现
  • Claude Code频繁出错怎么办?深入架构层面的故障排除指南
  • 力扣-5.最长回文子串
  • Python3 详解:从基础到进阶的完整指南
  • RS232串行线是什么?
  • 机器学习-支持向量机器(SVM)
  • 机器学习——TF-IDF算法
  • 2025天府杯数学建模A题分析
  • Docker存储卷备份策略于VPS服务器环境的实施标准与恢复测试
  • 【ai写代码】lua-判断表是否被修改
  • 【JDK】Linux 系统下 JDK 安装与环境变量配置全教程
  • Auto-Coder的CLI 和 Python API
  • TOTP算法与HOTP算法
  • 下标访问操作符 [] 与函数调用操作符 ()
  • 【软考中级网络工程师】知识点之常用网络诊断和配置命令
  • Qt---Qt函数库
  • 深度学习-卷积神经网络CNN-膨胀卷积、可分离卷积(空间可分离、深度可分离)、分组卷积