
Convolutional Neural Networks (CNN): MobileNet and ResNet (Part 5)

Contents

  • Convolutional Neural Networks (CNN): MobileNet and ResNet (Part 5)
  • MobileNet V1/V2 & ResNet (with runnable PyTorch code)
    • 1. Model quick-reference table
    • 2. MobileNet V1 core recap
    • 3. MobileNet V2 improvements
    • 4. ResNet core recap
    • 5. Minimal PyTorch reproduction code
      • 5.1 Project structure
      • 5.2 `models.py`
      • 5.3 `main.py`
    • 6. Summary & which to choose?
    • 7. References


MobileNet V1/V2 & ResNet (with runnable PyTorch code)

This article covers the core ideas, structural differences, and design details of the classic lightweight mobile networks MobileNet V1/V2 and the deep residual network ResNet, and provides minimal reproduction code that runs directly with python main.py.


1. Model quick-reference table

| Model | Year | Key ideas | Params* | Top-1* |
| --- | --- | --- | --- | --- |
| ResNet-50 | 2015 | residual learning, skip connections | 25.6 M | 76.0 % |
| MobileNet V1 | 2017 | depthwise separable convolution, α/ρ multipliers | 4.2 M | 70.6 % |
| MobileNet V2 | 2018 | Inverted Residual, Linear Bottleneck | 3.4 M | 72.0 % |

* Official ImageNet-1k figures.


2. MobileNet V1 core recap

  1. Depthwise separable convolution

    • 3×3 DWConv: one filter per channel (channel-wise)
    • 1×1 PWConv: mixes information across channels (point-wise)
    • Cuts computation by roughly 8~9× compared with a standard 3×3 convolution (see the sketch after this list)
  2. Two hyperparameters

    • Width multiplier α uniformly scales the channel counts
    • Resolution multiplier ρ uniformly scales the input resolution
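
The 8~9× figure falls straight out of a multiply-add count. A quick back-of-the-envelope check (the kernel, channel, and spatial sizes below are arbitrary illustrations, not values from the paper):

# Rough multiply-add counts: standard 3x3 conv vs. depthwise-separable conv.
def conv_madds(k, c_in, c_out, h, w):
    # standard convolution: every output channel looks at every input channel
    return k * k * c_in * c_out * h * w

def dw_sep_madds(k, c_in, c_out, h, w):
    depthwise = k * k * c_in * h * w   # one k x k filter per input channel
    pointwise = c_in * c_out * h * w   # 1x1 conv mixes channels
    return depthwise + pointwise

std = conv_madds(3, 256, 256, 56, 56)
sep = dw_sep_madds(3, 256, 256, 56, 56)
print(f'standard / separable = {std / sep:.1f}x')   # ~8.7x, matching the 8~9x claim

The ratio works out to 1/c_out + 1/k² of the standard cost, so with a 3×3 kernel it approaches 1/9 as the output channel count grows.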

3. MobileNet V2 improvements

| Improvement | Motivation | What it does |
| --- | --- | --- |
| Inverted Residual | fixes V1's information collapse | 1×1 expand → 3×3 DWConv → 1×1 project |
| Linear Bottleneck | keeps ReLU from destroying features | no ReLU after the final 1×1 |
| ReLU6 | quantization-friendly on mobile | min(max(x, 0), 6) |
| Shortcut | reuses features | enabled only when stride = 1 and in_channels == out_channels |

Structure diagram: (figure from the original post not reproduced here)
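
Because the figure is not reproduced here, the block is traced step by step below: a minimal sketch of the expand → depthwise → project pattern, with the channel sizes and the expansion factor of 6 chosen purely as example values.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Inverted residual, spelled out (BatchNorm omitted for brevity):
# 1x1 expand -> 3x3 depthwise -> 1x1 linear projection (no ReLU6 after it).
x = torch.randn(1, 32, 56, 56)                      # N, C, H, W
expand  = nn.Conv2d(32, 192, 1, bias=False)         # 1x1: 32 -> 32*6 channels
dwconv  = nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=False)
project = nn.Conv2d(192, 32, 1, bias=False)         # linear bottleneck: no activation

h = F.relu6(expand(x))
h = F.relu6(dwconv(h))
out = project(h) + x                                # shortcut: stride=1 and in==out channels
print(out.shape)                                    # torch.Size([1, 32, 56, 56])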


4. ResNet core recap

  1. Residual block

    • Skip connection: F(x) + x
    • Mitigates vanishing gradients, making networks with 1000+ layers trainable.
  2. Two block types

    • BasicBlock (small models: ResNet-18/34)
    • Bottleneck (large models: ResNet-50/101/152), which uses 1×1 convs to reduce and then restore channels, cutting computation.
  3. Downsampling strategy

    • When stride = 2 or the channel count changes, the shortcut path uses a 1×1 conv with stride 2 to align dimensions (see the sketch after this list).
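
A minimal sketch of the downsampling shortcut in point 3, with BatchNorm omitted and example tensor sizes chosen arbitrarily:

import torch
import torch.nn as nn
import torch.nn.functional as F

# When the main path halves the resolution and doubles the channels,
# a 1x1 conv with stride 2 brings the skip path to the same shape.
x = torch.randn(1, 64, 32, 32)
conv1 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)     # main path downsamples
conv2 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
shortcut = nn.Conv2d(64, 128, 1, 2, bias=False)     # 1x1 conv, stride 2

fx = conv2(F.relu(conv1(x)))                        # F(x): 1 x 128 x 16 x 16
out = F.relu(fx + shortcut(x))                      # F(x) + x after dimension alignment
print(out.shape)                                    # torch.Size([1, 128, 16, 16])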

5. Minimal PyTorch reproduction code

Environment: Python ≥ 3.8, PyTorch ≥ 1.10
Run: python main.py trains on CIFAR-10 for 5 epochs by default as a demo.

5.1 Project structure

mbv2_resnet_demo/
├─ main.py          # training & validation
├─ models.py        # the three network definitions
└─ utils.py         # shared utilities
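
utils.py appears in the tree above but is never shown in the post; a plausible minimal version, which is purely an assumption about its contents, would be a small evaluation helper:

# utils.py -- contents assumed, not shown in the original post
import torch

@torch.no_grad()
def accuracy(model, loader, device):
    # Top-1 accuracy of `model` over `loader`.
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return correct / total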

5.2 models.py

import torch.nn as nn
import torch.nn.functional as F


# ----------- MobileNet V1 -----------
class DepthwiseSeparable(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_c, in_c, 3, stride, 1, groups=in_c, bias=False)
        self.bn1 = nn.BatchNorm2d(in_c)
        self.pointwise = nn.Conv2d(in_c, out_c, 1, 1, 0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)

    def forward(self, x):
        x = F.relu(self.bn1(self.depthwise(x)))
        x = F.relu(self.bn2(self.pointwise(x)))
        return x


class MobileNetV1(nn.Module):
    cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2),
           512, 512, 512, 512, 512, (1024, 2), 1024]

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1024, num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for c in self.cfg:
            out_planes = c if isinstance(c, int) else c[0]
            stride = 1 if isinstance(c, int) else c[1]
            layers.append(DepthwiseSeparable(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layers(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)


# ----------- MobileNet V2 -----------
class InvertedResidual(nn.Module):
    def __init__(self, in_c, out_c, stride, expand_ratio=6):
        super().__init__()
        hidden = int(round(in_c * expand_ratio))
        self.use_res_connect = stride == 1 and in_c == out_c
        layers = []
        if expand_ratio != 1:
            layers.append(nn.Conv2d(in_c, hidden, 1, 1, 0, bias=False))
            layers.append(nn.BatchNorm2d(hidden))
            layers.append(nn.ReLU6(inplace=True))
        layers.extend([
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            nn.Conv2d(hidden, out_c, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_c),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)


class MobileNetV2(nn.Module):
    cfg = [(1, 16, 1, 1), (6, 24, 2, 1), (6, 32, 3, 2),
           (6, 64, 4, 2), (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1)]

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(320, num_classes)

    def _make_layers(self):
        layers = []
        in_c = 32
        for t, c, n, s in self.cfg:
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(InvertedResidual(in_c, c, stride, t))
                in_c = c
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu6(self.bn1(self.conv1(x)))
        x = self.layers(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)


# ----------- ResNet-18 -----------
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, 1, stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, 1)
        self.layer2 = self._make_layer(128, 2, 2)
        self.layer3 = self._make_layer(256, 2, 2)
        self.layer4 = self._make_layer(512, 2, 2)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride):
        layers = [BasicBlock(self.in_planes, planes, stride)]
        self.in_planes = planes * BasicBlock.expansion
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.in_planes, planes, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        x = self.pool(x).flatten(1)
        return self.fc(x)
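
As a quick sanity check (my addition, not part of the original post), the three models can be instantiated on a CIFAR-10-sized input to confirm that they build and to compare parameter counts:

# sanity_check.py -- optional, verifies output shapes and parameter counts
import torch
from models import MobileNetV1, MobileNetV2, ResNet18

x = torch.randn(2, 3, 32, 32)                       # a CIFAR-10-sized batch of 2
for net in (MobileNetV1(), MobileNetV2(), ResNet18()):
    n_params = sum(p.numel() for p in net.parameters()) / 1e6
    print(f'{net.__class__.__name__}: {n_params:.2f} M params, output {net(x).shape}')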

5.3 main.py

import torch, torchvision, time
from models import MobileNetV1, MobileNetV2, ResNet18

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
epochs = 5
lr = 1e-3

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=torchvision.transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, num_workers=4)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', train=False,
                                 transform=torchvision.transforms.ToTensor()),
    batch_size=batch_size, shuffle=False, num_workers=4)


def train(model):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        tic = time.time()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
        print(f'Epoch {epoch+1} done in {time.time()-tic:.1f}s')


if __name__ == '__main__':
    net = MobileNetV2()  # swap in ResNet18() / MobileNetV1()
    print(net)
    train(net)
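
Note that main.py builds test_loader but never uses it; one way to close the loop (my addition, relying on the hypothetical utils.accuracy helper sketched in 5.1) is to swap in a __main__ block that evaluates after training:

# Addition, not in the original post: evaluate on the otherwise unused test_loader.
from utils import accuracy

if __name__ == '__main__':
    net = MobileNetV2()          # swap in ResNet18() / MobileNetV1()
    train(net)
    print(f'Test accuracy: {accuracy(net, test_loader, device):.3f}')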

6. Summary & which to choose?

| Scenario | Preferred model | Why |
| --- | --- | --- |
| High-accuracy server-side | ResNet-50/101 | high accuracy, stable training |
| Real-time on mobile | MobileNetV2 | few parameters + high frame rate |
| Ultra-lightweight MCU | MobileNetV1 with α = 0.25 | extreme compression, < 1 M after quantization |

7. References

  • MobileNetV1 paper: arXiv:1704.04861
  • MobileNetV2 paper: arXiv:1801.04381
  • ResNet paper: arXiv:1512.03385

Writing this up takes effort; likes, bookmarks, and shares are much appreciated!
