MMRotate ReDet ReFPN 报错 `assert input.type == self.in_type`
在跑实验时，使用配置文件 `configs/redet/redet_re50_refpn_1x_dota_le90.py` 进行训练，结果报错：
```
Traceback (most recent call last):
  File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 196, in <module>
    main()
  File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 183, in main
    train_detector(
  File "h:\workspace\deeplearning\mmrotate\mmrotate\apis\train.py", line 145, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 53, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 31, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\parallel\data_parallel.py", line 77, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 248, in train_step
    losses = self(**data)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func
    return old_func(*args, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 172, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 127, in forward_train
    x = self.extract_feat(img)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 69, in extract_feat
    x = self.neck(x)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func
    return old_func(*args, **kwargs)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 298, in forward
    laterals = [
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 299, in <listcomp>
    self.lateral_convs[i](inputs[i + self.start_level])
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 148, in forward
    x = self.conv(x)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\e2cnn\nn\modules\r2_conv\r2convolution.py", line 326, in forward
    assert input.type == self.in_type
AssertionError
```
该断言失败说明：传入 ReFPN 中 e2cnn 卷积层的输入特征类型（`input.type`）与该层期望的 `self.in_type` 不一致。按照以下提示修改 `mmrotate/models/necks/re_fpn.py` 中的三处地方，其余代码保持不变：
# Change 1 of 3: add build_enn_divide_feature to this import list.
from ..utils import (build_enn_divide_feature, build_enn_feature,
                     build_enn_norm_layer, ennConv, ennInterpolate,
                     ennMaxPool, ennReLU
)


# Equivariant conv / norm / activation block used inside ReFPN
# (default order: conv -> norm -> act).
# NOTE: this is an excerpt showing only the three patched spots of
# re_fpn.py; every elided part must stay exactly as in the existing file.
# ``enn`` (presumably ``e2cnn.nn``) and ``torch`` are not imported in this
# excerpt -- they are assumed to be imported at the top of re_fpn.py.
class ConvModule(enn.EquivariantModule):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 activation='relu',
                 inplace=False,
                 order=('conv', 'norm', 'act')):
        """Build the module (only the patched beginning is shown).

        Args:
            in_channels: Passed to ``build_enn_divide_feature`` to build
                ``self.in_type``.
            out_channels: Passed to ``build_enn_divide_feature`` to build
                ``self.out_type``.
            conv_cfg (dict | None): Must be ``None`` or a dict.
            norm_cfg (dict | None): Must be ``None`` or a dict.
            kernel_size, stride, padding, dilation, groups, bias,
            activation, inplace, order: Not used in the visible code;
                presumably consumed by the elided rest of ``__init__``.

        Raises:
            AssertionError: If ``conv_cfg`` or ``norm_cfg`` is neither
                ``None`` nor a dict.
        """
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        # Change 2 of 3: build the input/output field types with
        # build_enn_divide_feature instead of build_enn_feature.
        # Presumably this makes self.in_type equal to the field type of the
        # features this module receives, which is exactly what e2cnn's
        # `assert input.type == self.in_type` checks -- TODO confirm against
        # build_enn_divide_feature in ..utils and the backbone's output types.
        self.in_type = build_enn_divide_feature(in_channels)
        self.out_type = build_enn_divide_feature(out_channels)
        # ... the rest of __init__ is unchanged (elided in this excerpt) ...

    def forward(self, x, activate=True, norm=True):
        """Forward function of ConvModule.

        Applies the steps named in ``self.order`` one after another.

        Args:
            x: Input features. A plain ``torch.Tensor`` is first wrapped
                into an ``enn.GeometricTensor`` typed as ``self.in_type``;
                anything else (presumably already a GeometricTensor) is
                used as is.
            activate (bool): When falsy, the 'act' step is skipped.
            norm (bool): When falsy, the 'norm' step is skipped.

        Returns:
            ``x`` after all enabled steps have been applied.
        """
        # Change 3 of 3: wrap a raw torch.Tensor so the e2cnn layers get a
        # GeometricTensor. Inputs that are not torch.Tensor instances are
        # passed through untouched, so a mismatched field type there would
        # presumably still be rejected by the e2cnn conv's input-type assert.
        if isinstance(x, torch.Tensor):
            x = enn.GeometricTensor(x, self.in_type)
        # self.order, self.conv, self.norm, self.activate, self.with_norm and
        # self.with_activatation are not assigned in the visible code;
        # presumably the elided part of __init__ defines them. Keep the
        # 'with_activatation' spelling so it matches that attribute name.
        for layer in self.order:
            if layer == 'conv':
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activatation:
                x = self.activate(x)
        return x