DenseNet详解,附模型代码(pytorch)
文章目录
- 前言
- DenseNet的提出
- 与ResNet的对比
- add
- concat
- 网络结构
- DenseLayer
- DenseBlock
- Transition
- Densenet
- DenseNet121\161\169\201\264
- 参考资料
前言
在当时的计算机视觉领域,从LeNet开始,卷积神经网络逐步开始成为最主流的方法。像AlexNet,VGG,GoogLeNet等等,大家都始终没有放弃去寻找一个最优的网络架构。尤其是当ResNet的出现,其成为了深度学习方向最主要的网络结构之一。 因为ResNet 可以训练出更深的 CNN 模型,其让深度学习也成为了可能,走向深度,从而实现更高的准确度。它的核心在于层与层之间的短路连接 (skip connection), skip connection 有助于训练过程中的梯度的反向传播,一定程度上减缓因为梯度消失导致网络训练不动,甚至效果下降的情况。关于ResNet的文章,我也讲解过,大家感兴趣的可以去看看。
DenseNet的提出
今天我们将要介绍DenseNet,其同样也是为了去探索最优的网络架构。
DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。
为什么建立前后层的连接很重要,尤其是在深层网络中?
**就是我们希望网络通过一系列非线性变换,抽象提取出其所蕴含的深层语义信息,其是符合我们所学习的数据的分布的,但是当网络深了之后,其提取的信息是有偏移的,即不是原来数据的分布。因为每一次的非线性变换其是会损失数据信息的,每一层的偏移会导致最后所学习到数据分布是有问题的,从而导致模型的表现不佳。**以下这图相信能更好的帮助理解。
与ResNet的对比
DenseNet和 ResNet、Inception 网络不同的是,DenseNet 并没有主要从网络的深度和宽度入手,DenseNet 的作者从 feature 入手,通过对 feature 的细腻操控,达到了更好的效果和更少的参数。DenseNet 的中心思想是与其多次学习冗余的特征,特征复用,是一种更好的特征提取方式
ResNet: 通过建立前面层和后面层的短路连接(skip connection), 帮助实现训练过程中更有效的反向传播,训练出更深的 CNN 网络;
DenseNet: 采用了比ResNet更极端的方法,通过密集连接机制,互相连接所有的层,每个层会将前面所有层的输出,在 channel 维度上进行 concat 操作,作为当前层的输入,进而实现特征重用。使用过该种方法不仅仅缓解了梯度消失的现象,也使得其在参数和计算量更少的情况下实现比 ResNet 更优的性能;
同时这里大家注意两者的Skip connection的方式是不同的,ResNet的是add的方式,而DenseNet是concat的方式。
那么二者有什么区别呢?
add
我们来看,以下是 keras 中对 add 的实现源码,pytorch的封装更复杂一些,不过原理都是一样的,看这个就行:
def _merge_function(self, inputs):output = inputs[0]for i in range(1, len(inputs)):output += inputs[i]return output
其中 inputs 为待融合的特征图,inputs[0]、inputs[1]……等的通道数一样,且特征图宽与高也一样。
从代码中可以很容易地看出,add 方式有以下特点:
- 做的是对应通道对应位置的值的相加,通道数不变
- 描述图像的特征个数不变,但是每个特征下的信息却增加了。
concat
同样的,我们通过阅读下面代码实例帮助理解 concat 的工作原理:
import torch# 创建两个张量
t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[7, 8, 9], [10, 11, 12]])# 沿第1维拼接
result_1 = torch.cat([t1, t2], dim=1)
print(result_1)
# 输出: tensor([[ 1, 2, 3, 7, 8, 9],
# [ 4, 5, 6, 10, 11, 12]])
在模型网路当中,数据通常为 4 个维度,即 num×channels×height×width ,因此默认值 1 表示的是 channels 通道进行拼接。如:
combine = torch.cat([d1, add1, add2, add3, add4], 1)
从代码中可以很容易地看出,concat 方式有以下特点:
- 做的是通道的合并,通道数变多了
- 描述图像的特征个数变多,但是每个特征下的信息却不变。
所以到这里,我们就能够很清晰的知道add操作和concat操作的不同了。
操作 | 描述 | 优点 | 缺点 | 补充 |
---|---|---|---|---|
add | - 相当于加了一种prior - 要求两路输入的对应通道特征图语义类似 | - 计算量少 | - 特征提取能力差 | - 对应通道信息类似时,可融合多通道信息 - 尺度不一致时,小尺度特征可能被淹没 |
concat | - 通过训练学习整合两个特征图通道之间的信息 | - 特征提取能力强 | - 计算量大(是add的2倍) | - 能提取更合适的信息,效果更好 |
网络结构
其实Densenet的构建我们主要就是从三个方面入手的,首先我们要构建Densenet,其实也类似于Resnet的搭建,分块进行构建。我们要搭建Densenet,就要构建DenseBlock,然后要搭建DenseBlock,就要构建好里面的每个DenseLayer,最后我们需要连接不同的DenseBlock,又需要Transition部分。最后搭建Densenet,就将不同的部分按照一定的顺序连接起来即可。其实就是如下图所示的。
DenseLayer
DenseLayer的搭建还是很简单的,就是BN-ReLU-Conv三件套连着来两次。这里有个细节要注意的就是由于越到后面输入会越大,这里Densenet为了减少计算量,在第一个Conv的时候使用1x1卷积调整通道数到bnsize∗growthratebn_{size}*growth_{rate}bnsize∗growthrate,一般bn_size设置为4。从而能够降低特征数量,提升计算效率。
然后最后forward函数里面我们最后输出的是torch.cat([x,new_feature],1),将我们新生成的feature加入到输入中,然后作为下一个DenseLayer的输入,这样就实现了dense connection的思想。
class _DenseLayer(nn.Module):def __init__(self,num_input_features,growth_rate,bn_size,drop_rate):super(_DenseLayer, self).__init__()self.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(num_input_features,bn_size*growth_rate,kernel_size=1,stride=1,padding=0,bias=False)self.norm2 = nn.BatchNorm2d(bn_size*growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size*growth_rate,growth_rate,kernel_size=3,stride=1,padding=1,bias=False)self.drop_rate = drop_ratedef forward(self,x):new_feature = self.norm1(x)new_feature = self.relu1(new_feature)new_feature = self.conv1(new_feature)new_feature = self.norm2(new_feature)new_feature = self.relu2(new_feature)new_feature = self.conv2(new_feature)if self.drop_rate > 0:new_feature = F.dropout(new_feature, p=self.drop_rate, training=self.training)return torch.cat([x,new_feature],1)
DenseBlock
主要就是看每个block里面有多少个layer嘛,主要每个layer的输入channel就行了,是递增的,根据增长率growth_rate。
class _DenseBlock(nn.ModuleDict):_version = 2def __init__(self,num_layers,num_input_features,growth_rate,bn_size,drop_rate):super(_DenseBlock, self).__init__()for i in range(num_layers):layer = _DenseLayer(num_input_features + i * growth_rate,growth_rate,bn_size,drop_rate)self.add_module('denselayer%d'%(i+1),layer)def forward(self,features):for name,layer in self.items():features = layer(features)return features
Transition
对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition层可以起到压缩模型的作用。
这里我们按照论文中所讲的压缩模型,就是AvgPool2d的设置 ,论文中讲述的压缩系数0.5,即设置2x2 AvgPooling即可。减少shape,起到下采样的作用。
class _Transition(nn.Sequential):def __init__(self,num_input_features,num_output_features):super(_Transition, self).__init__()self.norm = nn.BatchNorm2d(num_input_features)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(num_input_features,num_output_features,kernel_size=1,stride=1,padding=0,bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
Densenet
这部分就比较简单了,对于Densenet,首先经过一个初步的特征编码,7x7的卷积,BN层,激活层,池化层,然后就是四个DenseBlock层的搭建,并且每个DenseBlock使用Transition层进行连接,最后就是BN和全连接层了。这是整体的架构的搭建,在来看一些细节的处理,对于我们所搭建的网络模型,我们肯定是需要进行参数的初始化的,对于卷积层的参数采用凯明初始化,BN层的参数初始化为权重为1,偏置为0。
class DenseNet(nn.Module):def __init__(self,growth_rate=32,block_config=(6,12,24,16),num_init_features=64,bn_size=4,drop_rate=0,num_classes=1000):super(DenseNet, self).__init__()self.features = nn.Sequential(OrderedDict([("conv0",nn.Conv2d(3,num_init_features,kernel_size=7,stride=2,padding=3,bias=False)),("norm0",nn.BatchNorm2d(num_init_features)),("relu0",nn.ReLU(inplace=True)),("pool0",nn.MaxPool2d(kernel_size=3, stride=2)),]))num_features = num_init_featuresfor i,num_layers in enumerate(block_config):block = _DenseBlock(num_layers,num_features,growth_rate,bn_size,drop_rate)self.features.add_module('denseblock%d'%(i+1),block)num_features = num_features + num_layers * growth_rateif i != len(block_config)-1:transition = _Transition(num_features,num_features // 2)self.features.add_module('transition%d'%(i+1),transition)num_features = num_features // 2self.features.add_module('norm5',nn.BatchNorm2d(num_features))self.classifier = nn.Linear(num_features,num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)if isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.relu(features,inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))out = torch.flatten(out, 1)out = self.classifier(out)return out
DenseNet121\161\169\201\264
然后不同深度的DenseNet网络,就是通过控制其是使用的block每层具体的layer数量来控制,所以我们可以搭建多个不同深度的ResNet模型。
def densenet121(pretrained=True,**kwargs):model = DenseNet(growth_rate=32,block_config=(6,12,24,16),**kwargs)if pretrained:pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')state_dict = model_zoo.load_url(model_urls['densenet121'])for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]model.load_state_dict(state_dict)return modeldef densenet161(pretrained=True,**kwargs):model = DenseNet(growth_rate=48,block_config=(6, 12, 36, 24),**kwargs)if pretrained:pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')state_dict = model_zoo.load_url(model_urls['densenet161'])for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]model.load_state_dict(state_dict)return modeldef densenet169(pretrained=True,**kwargs):model = DenseNet(growth_rate=32,block_config=(6,12,32,32),**kwargs)if pretrained:pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')state_dict = model_zoo.load_url(model_urls['densenet169'])for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]model.load_state_dict(state_dict)return modeldef densenet201(pretrained=True,**kwargs):model = DenseNet(growth_rate=32,block_config=(6,12,48,32),**kwargs)if pretrained:pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')state_dict = model_zoo.load_url(model_urls['densenet201'])for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]model.load_state_dict(state_dict)return modeldef densenet264(**kwargs):model = DenseNet(growth_rate=32,block_config=(6,12,64,48),**kwargs)return model
参考资料
希望这篇文章能够给大家带来些思考,让大家能够有所收获,同时以上内容有所参考下列文章进行学习,并包含了我自己的思考。后续将会给大家带来使用DensenNet解决相关问题的实际应用部署。
DensetNet 介绍 - lucky_light - 博客园