当前位置: 首页 > news >正文

MMSeg绘制模型指定层的Heatmap热力图

文章首发及后续更新:https://mwhls.top/4475.html,无图/无目录/格式错误/更多相关请至首发页查看。
新的更新内容请到mwhls.top查看。
欢迎提出任何疑问及批评,非常感谢!

摘要:绘制模型指定层的热力图

可视化环境安装

  • 可用的环境版本:
    • mmseg 1.0.0rc5
    • mmdet 3.0.0rc6
    • mmcv 2.0.0rc4
    • mmengine 0.6.0
    • 注:不要用在其它版本跑的文件覆盖它,我最开始一直没成功就是因为我想偷懒直接复制我的模型过去,但是模型调用了在原版本存在,但新版本不存在的方法,导致一直报错。
  • 安装以上环境,参考该 issue 代码可正常推理,代码如下
    • 还有其它 issue 也提到了 featmap,可以在 mmseg 的 GitHub 搜 cam 关键词,或者点这里。
import torch
import cv2
import numpy as npfrom mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnormconfig_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'register_all_modules()model = init_model(config_path, checkpoint_path, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)cv2.imshow('cam', out)
cv2.waitKey(0)

指定位置可视化

  • 修改后的可视化代码 Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm# prefix = "mmsegmentation-1.0.0rc5/"
prefix = ""
config = prefix + r"log\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py"
checkpoint = prefix + r"log\7_ttpla_p2t_t_20k\iter_8000.pth"config = prefix + r"log\9_ttpla_r50_20k\ttpla_r50_20k.py"
checkpoint = prefix + r"log\9_ttpla_r50_20k\iter_8000.pth"img_path = prefix + r"img.png"def draw_heatmap(featmap):vis = SegLocalVisualizer()ori_img = cv2.imread(img_path)out = vis.draw_featmap(featmap, ori_img)cv2.imshow('cam', out)cv2.waitKey(0)def generate_featmap(config, checkpoint, img_path):register_all_modules()model = init_model(config, checkpoint, device='cpu')model = revert_sync_batchnorm(model)vis = SegLocalVisualizer()ori_img = cv2.imread(img_path)img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)logits = model(img)out = vis.draw_featmap(logits[0], ori_img)cv2.imshow('cam', out)cv2.waitKey(0)if __name__ == "__main__":generate_featmap(config, checkpoint, img_path)
  • 如下,在模型内调用 draw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):"""Forward function."""from Startup import draw_heatmapdraw_heatmap(x[0])if self.deep_stem:x = self.stem(x)else:x = self.conv1(x)x = self.norm1(x)x = self.relu(x)x = self.maxpool(x)outs = []for i, layer_name in enumerate(self.res_layers):res_layer = getattr(self, layer_name)x = res_layer(x)if i in self.out_indices:outs.append(x)from Startup import draw_heatmapdraw_heatmap(x[0])return tuple(outs)

效果展示

Heatmap1.png Heatmap2.png Heatmap3.png Heatmap4.png Heatmap5.png Heatmap6.png
http://www.lryc.cn/news/30601.html

相关文章:

  • 关于Paul C.R. - Inductance_ Loop and partial-Wiley (2009)一书的概括
  • 基于支持向量机SVM的面部表情分类预测
  • java内存模型的理解
  • 自己写一个简单的IOC
  • 用Python批量重命名文件
  • iis之web服务器搭建、部署(详细)~千锋
  • javascript的ajax
  • SpringBoot入门 - 开发中还有哪些常用注解
  • 网络基础(三)
  • Go语言函数高级篇
  • ubuntu16.04 python代码自启动和可执行文件自启动
  • 应用层协议 HTTP HTTPS
  • 图神经网络 pytorch GCN torch_geometric KarateClub 数据集
  • 【博学谷学习记录】超强总结,用心分享丨人工智能 自然语言处理 文本特征处理小结
  • 2023年中职网络安全竞赛解析——隐藏信息探索
  • 实用操作--迁移到Spring Boot 3 和 Spring 6 需要关注的JAVA新特性
  • 等保检测风险处理方案
  • java 包装类 万字详解(通俗易懂)
  • 为什么我复制的中文url粘贴出来会是乱码的? 浏览器url编码和解码
  • 移动端适配
  • 【FPGA】Verilog:时序电路应用 | 序列发生器 | 序列检测器
  • Biomod2 (下):物种分布模型建模
  • Linux性能学习(2.2):内存_进程线程内存分配机制探究
  • BPMN2.0规范及流程引擎选型方案
  • VMware虚拟机安装Linux教程
  • 多人协作|RecyclerView列表模块新架构设计
  • SpringBoot (六) 整合配置文件 @Value、ConfigurationProperties
  • docker 入门篇
  • MapReduce的shuffle过程详解
  • 【软件使用】MarkText下载安装与汉化设置 (markdown快捷键收藏)