当前位置: 首页 > news >正文

深入理解残差网络(ResNet):原理与PyTorch实现

引言:深度神经网络的瓶颈

随着深度学习的发展,研究者发现简单地增加网络层数反而会导致模型性能下降——这种现象称为退化问题(Degradation Problem)。传统深层网络面临的主要挑战:

  1. 梯度消失/爆炸:反向传播时梯度指数级衰减或增大

  2. 训练困难:深层网络难以收敛到理想状态

  3. 性能饱和:深度增加时准确率反而下降

一、残差学习的核心思想

2015年,何恺明团队提出的残差网络(ResNet)通过引入跳跃连接(Skip Connection) 解决了这一难题,核心公式:

y=F(x,{Wi})+xy=F(x,{Wi​})+x

其中:

  • $\mathbf{x}$:输入特征

  • $\mathcal{F}$:残差函数

  • $\mathbf{y}$:输出特征

这种设计允许网络学习输入与输出之间的残差(差值),而非直接映射。当原始映射接近恒等映射时,学习残差$\mathcal{F} = \mathbf{y} - \mathbf{x}$比学习完整映射更容易。

二、残差块结构解析

import torch
import torch.nn as nnclass ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()# 第一卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二卷积层self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 跳跃连接处理维度变化self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):identity = self.shortcut(x)  # 保留原始输入out = self.conv1(x)out = self.bn1(out)out = nn.ReLU()(out)out = self.conv2(out)out = self.bn2(out)out += identity  # 关键步骤:添加跳跃连接out = nn.ReLU()(out)return out

三、ResNet网络架构

ResNet由多个残差块堆叠而成,不同深度的网络配置:

网络层ResNet-18ResNet-34ResNet-50
conv17×7, 64, stride 2
pool13×3 max pool, stride 2
conv2_x3×3, 64
3×3, 64 ×2
×31×1, 64
3×3, 64
1×1, 256 ×3
conv3_x3×3, 128
3×3, 128 ×2
×41×1, 128
3×3, 128
1×1, 512 ×4
conv4_x3×3, 256
3×3, 256 ×2
×61×1, 256
3×3, 256
1×1, 1024 ×6
conv5_x3×3, 512
3×3, 512 ×2
×31×1, 512
3×3, 512
1×1, 2048 ×3
全连接层1000-d fc

四、PyTorch实现ResNet-18

class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):super().__init__()self.in_channels = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 残差层self.layer1 = self._make_layer(block, 64, layers[0], stride=1)self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 分类头self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512, num_classes)def _make_layer(self, block, out_channels, blocks, stride=1):layers = []# 第一个块可能需要下采样layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels# 剩余块for _ in range(1, blocks):layers.append(block(out_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x# 实例化ResNet-18
def resnet18(num_classes=1000):return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes)

五、关键优势分析

  1. 梯度传播优化:跳跃连接提供梯度高速公路

  2. 恒等映射保障:当残差接近0时自动退化为恒等函数

  3. 参数效率:相比传统网络,参数量更少但性能更好

六、训练技巧与注意事项

  1. 权重初始化:使用He初始化

    for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

  2. 学习率调度:余弦退火策略

  3. 数据增强:随机裁剪、水平翻转

  4. 优化器选择:SGD with momentum (0.9)

http://www.lryc.cn/news/575257.html

相关文章:

  • 搭建自己的WEB应用防火墙
  • RabbitMq中启用NIO
  • 【评估指标】IoU 交并比
  • 工业“三体”联盟:ethernet ip主转profinet网关重塑设备新规则
  • 智哪儿专访 | Matter中国提速:开放标准如何破局智能家居“生态孤岛”?
  • Selenium 二次封装通用页面基类 BasePage —— Python 实践
  • GBDT:梯度提升决策树——集成学习中的预测利器
  • Git上传代码如何解决Merge冲突
  • 时序数据库 TDengine 助力华锐 D5 平台实现“三连降”:查询快了,机器少了,成本也低了
  • 【目标检测】平均精度(AP)与均值平均精度(mAP)计算详解
  • MicroPython网络编程:AP模式与STA模式详解
  • 大塘至浦北高速分布式光伏项目,让‘交通走廊’变身‘绿色能源带’
  • 深度学习入门--(二)感知机
  • python的kivy框架界面布局方法详解
  • react中使用3D折线图跟3D曲面图
  • Vue Devtools “Open in Editor” 配置教程(适用于 VSCode 等主流编辑器)
  • 大语言模型(LLM)初探:核心概念与应用场景
  • 【MongoDB】MongoDB从零开始详细教程 核心概念与原理 环境搭建 基础操作
  • DeepSeek模型接入LangChain流程(详细教程)
  • 永磁同步电机无速度算法--基于同步旋转坐标系锁相环的滑模观测器
  • PYTHON从入门到实践6-字典
  • MCP2518FD发送时有时候多发数据包问题
  • 【预告 大模型应用开发实战专栏 升级】将增加《大模型 Agent 应用实战指南》专题赋能 Agent 开发者
  • OpenGL模板缓冲:实现亮显外轮廓效果
  • C# LINQ语法
  • Python 爬虫入门:从数据爬取到转存 MySQL 数据库
  • Cookie 在 HTTP 中的作用HTTP 中的状态码
  • 北斗导航 | 基于改进奇偶矢量法的CAT I精密进近RAIM算法
  • 半导体芯闻--20250625
  • Linux离线安装jdk-11