当前位置: 首页 > news >正文

STANet代码复现出现的问题

1

IndexError: boolean index did not match indexed array along dimension 0; dimension is 4194304 but corresponding boolean dimension is 65536

定位到导致错误的代码,是metric.py,Collect values for Confusion Matrix 收集混淆矩阵的值时出错
在这里插入图片描述
这一段是为了比较真实值与预测值,但是预测值的维度却比真实值大了64倍。
参考here
问题分析:
简要说就是torch版本更新后,有些函数输出的BCWH变成了BWHC。
运行报错后,一行一行溯源,发现问题出在pred的shape和Label的shape不匹配。
CDF0和CDFA中,forward是对backbone的计算的特征图进行相似度计算,然后这个相似度通过阈值1选择后作为pred的结果的。
以下为猜测,没有找到实料。我猜测老版本torch中F.pairwise_distance生成的结果是BCWH,因此可以直接拿来插值然后和label做比较。但新版本应该是变成了BWHC。用默认的Resnet18(即netF)生成的特征层应该为B*64©64(W)64(H),F.pairwise_distance生成的结果为B64(W)64(H)1©,插值后就变为B64256256,所以导致报错的dimension前面数值总是后面的64倍。
解决办法:
在CDFA和CDF0中,找到forward函数,将F.pairwise_distance生成的结果进行通道和行列变换。

    def forward(self):"""Run forward pass; called by both functions <optimize_parameters> and <test>."""self.feat_A = self.netF(self.A)  # f(A)self.feat_B = self.netF(self.B)   # f(B)# 距离度量self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True) # 特征距离 B*W*H*C# 在此添加两个打印,可以输出看一下# print(self.dist.shape)# torch.Size([2, 64, 64, 1])# 在此新增以下代码行self.dist = self.dist.permute(0, 3, 1, 2)  # 需要变换成B*C*W*H# print(self.dist.shape)# torch.Size([2, 1, 64, 64])# print(self.dist.shape)self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)#self.pred_L = (self.dist > 1).float()self.pred_L_show = self.pred_L.long() #将数字或字符串转换为一个长整型return self.pred_L
http://www.lryc.cn/news/66924.html

相关文章:

  • Java 中String对象详解
  • k8s nfs运行问题、etcd问题、calico网络问题
  • Qt--QString字符串类、QTimer定时器类
  • 2023.5.13>>Eclipse+exe4j打包Java项目及获取exe所在文件的路径
  • Centos系统的使用基本教程
  • IDEA生成ER图、UML类图、时序图、流程图等的插件推荐或独立工具推荐
  • Python心经(3)
  • 单工,半双工,全双工通讯
  • 【2023-05-09】 设计模式(单例,工厂)
  • 批量任务导致页面卡死解决方案
  • 避免“文献综抄”,5种写作结构助你完成文献综述→
  • Java异常和反射
  • Accesss数据库的那点事
  • 网络基础学习:osi网络七层模型
  • EndNote X9 引用参考 单击文献编号,不能跳转到文尾文献列表处,咋解决?文献编号 不能跳转 ,怎么办?
  • 用免费蜜罐工具配置Modbus工控蜜罐
  • DataGridXL中快速搜索单元格和底部全屏模式区域隐藏
  • DotNet几种微服务框架,你用过吗?
  • Nature | 生成式人工智能如何构建更好的抗体
  • 【hive】基于Qt5和libuv udp 的lan chat
  • Java版本工程项目管理系统源码,助力工程企业实现数字化管理
  • 什么是零拷贝?
  • 计算机专业含金量高的证书
  • 原装二手Keithley 2401低压源表 吉时利2401数字源表
  • gradle-8.1.1-all 快速下载百度网盘下载
  • C#开发的OpenRA游戏之基地工程车部署命令产生过程
  • C++ 智能指针的原理、分类、使用
  • 学习笔记——SVG.js中形状元素的创建及其相关方法
  • Linux一学就会——系统文件I/O
  • OpenCV-Python图像阈值