当前位置: 首页 > news >正文

Xception模型详解

简介

Xception的名称源自于"Extreme Inception",它是在Inception架构的基础上进行了扩展和改进。Inception架构是Google团队提出的一种经典的卷积神经网络架构,用于解决深度卷积神经网络中的计算和参数增长问题。

与Inception不同,Xception的主要创新在于使用了深度可分离卷积(Depthwise Separable Convolution)来替代传统的卷积操作。深度可分离卷积将卷积操作分解为两个步骤:深度卷积和逐点卷积。

深度卷积是一种在每个输入通道上分别应用卷积核的操作,它可以有效地减少计算量和参数数量。逐点卷积是一种使用1x1卷积核进行通道间的线性组合的操作,用于增加模型的表示能力。通过使用深度可分离卷积,Xception网络能够更加有效地学习特征表示,并在相同计算复杂度下获得更好的性能。

Xception 网络结构

一个标准的Inception模块(Inception V3)

简化后的Inception模块

简化后的Inception的等价结构

采用深度可分离卷积的思想,使 3×3 卷积的数量与 1×1卷积输出通道的数量相等

Xception模型,一共可以分为3个flow,分别是Entry flow、Middle flow、Exit flow。

在这里 Entry 与 Exit 都具有相同的部分,Middle 与这二者有所不同。

Xception模型的pytorch复现

(1)深度可分离卷积

class SeparableConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):super(SeparableConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding,dilation, groups=in_channels, bias=bias)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0,dilation=1, groups=1, bias=False)def forward(self, x):x = self.conv(x)x = self.pointwise(x)return x

(2)构建三个flow结构

class EntryFlow(nn.Module):def __init__(self):super(EntryFlow, self).__init__()self.headconv = nn.Sequential(nn.Conv2d(3, 32, 3, 2, bias=False),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.Conv2d(32, 64, 3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),)self.residual_block1 = nn.Sequential(SeparableConv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),SeparableConv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.MaxPool2d(3, stride=2, padding=1),)self.residual_block2 = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),SeparableConv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.MaxPool2d(3, stride=2, padding=1))self.residual_block3 = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(256, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.MaxPool2d(3, stride=2, padding=1))def shortcut(self, inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 2, bias=False),nn.BatchNorm2d(oup))def forward(self, x):x = self.headconv(x)residual = self.residual_block1(x)shortcut_block1 = self.shortcut(64, 128)x = residual + shortcut_block1(x)residual = self.residual_block2(x)shortcut_block2 = self.shortcut(128, 256)x = residual + shortcut_block2(x)residual = self.residual_block3(x)shortcut_block3 = self.shortcut(256, 728)x = residual + shortcut_block3(x)return xclass MiddleFlow(nn.Module):def __init__(self):super(MiddleFlow, self).__init__()self.shortcut = nn.Sequential()self.conv1 = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728))def forward(self, x):residual = self.conv1(x)input = self.shortcut(x)return input + residualclass ExitFlow(nn.Module):def __init__(self):super(ExitFlow, self).__init__()self.residual_with_exit = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 1024, 3, padding=1),nn.BatchNorm2d(1024),nn.MaxPool2d(3, stride=2, padding=1))self.endconv = nn.Sequential(SeparableConv2d(1024, 1536, 3, 1, 1),nn.BatchNorm2d(1536),nn.ReLU(inplace=True),SeparableConv2d(1536, 2048, 3, 1, 1),nn.BatchNorm2d(2048),nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1, 1)),)def shortcut(self, inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 2, bias=False),nn.BatchNorm2d(oup))def forward(self, x):residual = self.residual_with_exit(x)shortcut_block = self.shortcut(728, 1024)output = residual + shortcut_block(x)return self.endconv(output)

(3)构建网络(完整代码)

"""
Copyright (c) 2023, Auorui.
All rights reserved.Xception: Deep Learning with Depthwise Separable Convolutions<https://arxiv.org/pdf/1610.02357.pdf>
"""
import torch
import torch.nn as nnclass SeparableConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):super(SeparableConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding,dilation, groups=in_channels, bias=bias)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0,dilation=1, groups=1, bias=False)def forward(self, x):x = self.conv(x)x = self.pointwise(x)return xclass EntryFlow(nn.Module):def __init__(self):super(EntryFlow, self).__init__()self.headconv = nn.Sequential(nn.Conv2d(3, 32, 3, 2, bias=False),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.Conv2d(32, 64, 3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),)self.residual_block1 = nn.Sequential(SeparableConv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),SeparableConv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.MaxPool2d(3, stride=2, padding=1),)self.residual_block2 = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),SeparableConv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.MaxPool2d(3, stride=2, padding=1))self.residual_block3 = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(256, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.MaxPool2d(3, stride=2, padding=1))def shortcut(self, inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 2, bias=False),nn.BatchNorm2d(oup))def forward(self, x):x = self.headconv(x)residual = self.residual_block1(x)shortcut_block1 = self.shortcut(64, 128)x = residual + shortcut_block1(x)residual = self.residual_block2(x)shortcut_block2 = self.shortcut(128, 256)x = residual + shortcut_block2(x)residual = self.residual_block3(x)shortcut_block3 = self.shortcut(256, 728)x = residual + shortcut_block3(x)return xclass MiddleFlow(nn.Module):def __init__(self):super(MiddleFlow, self).__init__()self.shortcut = nn.Sequential()self.conv1 = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728))def forward(self, x):residual = self.conv1(x)input = self.shortcut(x)return input + residualclass ExitFlow(nn.Module):def __init__(self):super(ExitFlow, self).__init__()self.residual_with_exit = nn.Sequential(nn.ReLU(inplace=True),SeparableConv2d(728, 728, 3, padding=1),nn.BatchNorm2d(728),nn.ReLU(inplace=True),SeparableConv2d(728, 1024, 3, padding=1),nn.BatchNorm2d(1024),nn.MaxPool2d(3, stride=2, padding=1))self.endconv = nn.Sequential(SeparableConv2d(1024, 1536, 3, 1, 1),nn.BatchNorm2d(1536),nn.ReLU(inplace=True),SeparableConv2d(1536, 2048, 3, 1, 1),nn.BatchNorm2d(2048),nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1, 1)),)def shortcut(self, inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 2, bias=False),nn.BatchNorm2d(oup))def forward(self, x):residual = self.residual_with_exit(x)shortcut_block = self.shortcut(728, 1024)output = residual + shortcut_block(x)return self.endconv(output)class Xception(nn.Module):def __init__(self, num_classes=1000):super().__init__()self.num_classes = num_classesself.entry_flow = EntryFlow()self.middle_flow = MiddleFlow()self.exit_flow = ExitFlow()self.fc = nn.Linear(2048, num_classes)def forward(self, x):x = self.entry_flow(x)for i in range(8):x = self.middle_flow(x)x = self.exit_flow(x)x = x.view(x.size(0), -1)out = self.fc(x)return outif __name__=='__main__':import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = Xception(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Xception Total params: 19,838,076

参考文章

【精读AI论文】Xception ------(Xception: Deep Learning with Depthwise Separable Convolutions)_xception论文-CSDN博客

[ 轻量级网络 ] 经典网络模型4——Xception 详解与复现-CSDN博客

神经网络学习小记录22——Xception模型的复现详解_xception timm-CSDN博客

【卷积神经网络系列】十七、Xception_xception模块-CSDN博客 

http://www.lryc.cn/news/329612.html

相关文章:

  • 【合合TextIn】AI构建新质生产力,合合信息Embedding模型助力专业知识应用
  • Flutter 拦截系统键盘,显示自定义键盘
  • 内存泄漏是什么?如何避免内存泄漏?
  • linux 中的syslog的含义和用法
  • kubernetes(K8S)学习(一):K8S集群搭建(1 master 2 worker)
  • 巧克力(蓝桥杯)
  • Python爬虫之pyquery和parsel的使用
  • 移动硬盘怎么加密?移动硬盘加密软件有哪些?
  • openEuler 22.03 安装 .NET 8.0
  • 【转载】OpenCV ECC图像对齐实现与代码演示(Python / C++源码)
  • 每日一题(相交链表 )
  • C#WPF控件大全
  • 好书推荐 《AIGC重塑金融》
  • 【Linux】权限理解
  • 插入排序、归并排序、堆排序和快速排序的稳定性分析
  • 【pytest、playwright】多账号同时操作
  • 软考 系统架构设计师系列知识点之云原生架构设计理论与实践(8)
  • 【C++】stack、queue和优先级队列
  • 第十三届蓝桥杯国赛真题 Java C 组【原卷】
  • docker部署ubuntu
  • iOS问题记录 - App Store审核新政策:隐私清单 SDK签名(持续更新)
  • ES学习日记(二)-------集群设置
  • 农村集中式生活污水分质处理及循环利用技术指南
  • linux 一些命令
  • 移动硬盘损坏打不开?别急,这里有解决方案!
  • 微信小程序【从入门到精通】——服务器的数据交互
  • Python爬虫-懂车帝城市销量榜单
  • 《QDebug 2024年3月》
  • C# OpenCvSharp-HoughCircles(霍夫圆检测) 简单计数
  • MybatisPlus速成