当前位置: 首页 > news >正文

MMRotate ReDet ReFPN 报错 `assert input.type == self.in_type`

在跑实验时,使用 configs/redet/redet_re50_refpn_1x_dota_le90.py,结果报错:

Traceback (most recent call last):File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 196, in <module>main()File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 183, in maintrain_detector(File "h:\workspace\deeplearning\mmrotate\mmrotate\apis\train.py", line 145, in train_detectorrunner.run(data_loaders, cfg.workflow)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in runepoch_runner(data_loaders[i], **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 53, in trainself.run_iter(data_batch, train_mode=True, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 31, in run_iteroutputs = self.model.train_step(data_batch, self.optimizer,File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\parallel\data_parallel.py", line 77, in train_stepreturn self.module.train_step(*inputs[0], **kwargs[0])File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 248, in train_steplosses = self(**data)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_funcreturn old_func(*args, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 172, in forwardreturn self.forward_train(img, img_metas, **kwargs)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 127, in forward_trainx = self.extract_feat(img)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 69, in extract_featx = self.neck(x)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_funcreturn old_func(*args, **kwargs)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 298, in forwardlaterals = [File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 299, in <listcomp>self.lateral_convs[i](inputs[i + self.start_level])File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 148, in forwardx = self.conv(x)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\e2cnn\nn\modules\r2_conv\r2convolution.py", line 326, in forwardassert input.type == self.in_type
AssertionError

按照以下提示修改 mmrotate/models/necks/re_fpn.py 三处地方,其余地方不变。

# 1. 引入 build_enn_divide_feature 函数
from ..utils import (build_enn_divide_feature,build_enn_feature, build_enn_norm_layer, ennConv,ennInterpolate, ennMaxPool, ennReLU
)class ConvModule(enn.EquivariantModule):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias='auto',conv_cfg=None,norm_cfg=None,activation='relu',inplace=False,order=('conv', 'norm', 'act')):super(ConvModule, self).__init__()assert conv_cfg is None or isinstance(conv_cfg, dict)assert norm_cfg is None or isinstance(norm_cfg, dict)# 2. 用 build_enn_divide_feature 替换 build_enn_featureself.in_type = build_enn_divide_feature(in_channels)self.out_type = build_enn_divide_feature(out_channels)# 后续保持不变...def forward(self, x, activate=True, norm=True):"""Forward function of ConvModule."""# 3. 如果传入的是普通 Tensor,则封装为 GeometricTensorif isinstance(x, torch.Tensor):x = enn.GeometricTensor(x, self.in_type)for layer in self.order:if layer == 'conv':x = self.conv(x)elif layer == 'norm' and norm and self.with_norm:x = self.norm(x)elif layer == 'act' and activate and self.with_activatation:x = self.activate(x)return x
http://www.lryc.cn/news/600085.html

相关文章:

  • Linux的磁盘存储管理实操——(下二)——逻辑卷管理LVM的扩容、缩容
  • ComfyUI中运行Wan 2.1工作流,电影级视频,兼容Mac, Windows
  • 一些常见的网络攻击方式
  • 与 TRON (波场) 区块链进行交互的命令行工具 (CLI): tstroncli
  • 关闭chrome自带的跨域限制,简化本地开发
  • 【Chrome】下载chromedriver的地址
  • 中国航天集团实习第一周总结
  • 低速信号设计之 SWD 篇
  • 随机抽签服务API集成指南
  • python学习DAY22打卡
  • 如何评估一个RWA项目的可信度?关键指标解析
  • 图书推荐-由浅入深的大模型构建《从零构建大模型》
  • C语言————原码 补码 反码 (日渐清晰版)
  • openGauss数据库在CentOS 7 中的单机部署与配置
  • 在幸狐RV1106板子上用gcc14.2本地编译安装ssh客户端/服务器、vim编辑器、sl和vsftpd服务器
  • 基础很薄弱如何规划考研
  • 解密负载均衡:如何轻松提升业务性能
  • QT开发---多线程编程
  • 【SpringAI实战】ChatPDF实现RAG知识库
  • XORIndex:朝鲜不断发展的供应链恶意软件再次瞄准 npm 生态系统
  • 从分治的思想下优化快速排序算法
  • 免模型控制
  • 蓝桥杯java算法例题
  • 计算机网络(第八版)— 第2章课后习题参考答案
  • [NLP]多电源域设计的仿真验证方法
  • 数字化转型-AI落地金字塔法则
  • 【日志】unity俄罗斯方块——边界限制检测
  • 深度学习篇---图像数据采集
  • 【VLAs篇】06:从动作词元化视角谈VLA模型的综述
  • JavaSE-图书信息管理系统