当前位置: 首页 > news >正文

PyTorch回忆(三)U-net

# unet_tutorial.py
import torch
import torch.nn as nn
import torch.nn.functional as F# -------- 把打印函数写成一个工具 ----------
def print_shape(name, x):print(f"{name:<12}: {tuple(x.shape)}")# ----------- 基础卷积块:两次(Conv+BN+ReLU) -----------
class DoubleConv(nn.Module):def __init__(self, in_c, out_c):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_c, out_c, 3, padding=1, bias=False),nn.BatchNorm2d(out_c),nn.ReLU(inplace=True),nn.Conv2d(out_c, out_c, 3, padding=1, bias=False),nn.BatchNorm2d(out_c),nn.ReLU(inplace=True))def forward(self, x):return self.conv(x)# ----------- 下采样:MaxPool + DoubleConv -----------
class Down(nn.Module):def __init__(self, in_c, out_c):super().__init__()self.pool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_c, out_c))def forward(self, x):return self.pool_conv(x)# ----------- 上采样:转置卷积或插值 + 拼接 + DoubleConv -----------
class Up(nn.Module):def __init__(self, in_c, out_c, bilinear=True):super().__init__()if bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_c, out_c)else:self.up = nn.ConvTranspose2d(in_c, in_c//2, 2, stride=2)self.conv = DoubleConv(in_c, out_c)def forward(self, x1, x2):x1 = self.up(x1)# 处理尺寸不一致diffY = x2.size(2) - x1.size(2)diffX = x2.size(3) - x1.size(3)x1 = F.pad(x1, [diffX//2, diffX-diffX//2,diffY//2, diffY-diffY//2])x = torch.cat([x2, x1], dim=1)return self.conv(x)# ----------- 输出1×1卷积 -----------
class OutConv(nn.Module):def __init__(self, in_c, n_classes):super().__init__()self.conv = nn.Conv2d(in_c, n_classes, 1)def forward(self, x):return self.conv(x)# ----------- 完整 UNet -----------
class UNet(nn.Module):def __init__(self, n_channels=3, n_classes=1, bilinear=True):super().__init__()self.inc   = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)factor = 2 if bilinear else 1self.down4 = Down(512, 1024 // factor)self.up1   = Up(1024, 512 // factor, bilinear)self.up2   = Up(512, 256 // factor, bilinear)self.up3   = Up(256, 128 // factor, bilinear)self.up4   = Up(128, 64, bilinear)self.outc  = OutConv(64, n_classes)def forward(self, x):print_shape("input", x)x1 = self.inc(x)      ; print_shape("x1", x1)x2 = self.down1(x1)   ; print_shape("x2", x2)x3 = self.down2(x2)   ; print_shape("x3", x3)x4 = self.down3(x3)   ; print_shape("x4", x4)x5 = self.down4(x4)   ; print_shape("x5", x5)x = self.up1(x5, x4)  ; print_shape("up1", x)x = self.up2(x, x3)   ; print_shape("up2", x)x = self.up3(x, x2)   ; print_shape("up3", x)x = self.up4(x, x1)   ; print_shape("up4", x)logits = self.outc(x) ; print_shape("logits", logits)return logitsif __name__ == "__main__":model = UNet()print(model)input = torch.randn(2,3,256,256)_ = model(input)
input       : (2, 3, 256, 256)
x1          : (2, 64, 256, 256)   # 卷积后尺寸不变
x2          : (2, 128, 128, 128)  # MaxPool 2×
x3          : (2, 256, 64, 64)
x4          : (2, 512, 32, 32)
x5          : (2, 512, 16, 16)    # 512=1024//2 (bilinear=True)
up1         : (2, 256, 32, 32)    # 上采样2×,与x4拼接→512→256
up2         : (2, 128, 64, 64)
up3         : (2, 64, 128, 128)
up4         : (2, 64, 256, 256)
logits      : (2, 1, 256, 256)    # 输出与原图同尺寸

在这里插入图片描述

  1. nn.MaxPool2d(2)里面是几,就是原来的几分之一,因为默认没有填充,步长等于kernel_size,这里就是尺寸变为原来的一半,输入特征图大小为 (4, 4) → 输出特征图大小为 (2, 2)。
  2. 原论文padding=0,所以每做一次 3×3 卷积,特征图宽/高各减少 2 像素,这里直接保持尺寸不变了
  3. 加factor为了统一双线性插值转置卷积,转置卷积nn.ConvTranspose2d(...)会使通道数减半,双线性插值不会。现在做上采样(Upsampling)时,“先插值再卷积” 已经成了最常用、最稳妥的方案,其次是转置卷积(Transpose Convolution),如果在做 U-Net 或通用语义分割,直接采用 双线性插值 + 卷积 就行;如果做 GAN 或需要学习式上采样,再考虑转置卷积。
  4. nn.ReLU(inplace=True)直接在原始输入张量上进行修改,而不是创建一个新的张量来存储结果。这样可以节省内存
    nn.Conv2d(in_c, out_c, 3, padding=1, bias=False)bias=False是不为卷积层添加可学习的偏置项。BatchNorm 层本身包含一个可学习的偏置参数 β,可以替代卷积层的偏置。在现代 CNN 架构中,如果卷积层后紧跟 BatchNorm,通常都会设置 bias=False。
http://www.lryc.cn/news/620995.html

相关文章:

  • java 学习 贪心 + 若依 + 一些任务工作
  • FTP服务器搭建(Linux)
  • opencv:傅里叶变换有什么用?怎么写傅里叶变换?
  • 软件著作权产生与登记关键点
  • 从单机到分布式:用飞算JavaAI构建可扩展的TCP多人聊天系统
  • 算法基础 第3章 数据结构
  • 数学建模-非线性规划模型
  • 深入理解提示词工程:从入门到精通的AI对话艺术
  • Mybatis实现页面增删改查
  • 数仓分层架构设计全解析:从理论到实践的深度思考
  • 一台联想 ThinkCentre M7100z一体机开机黑屏无显示维修记录
  • 【跨越 6G 安全、防御与智能协作:从APT检测到多模态通信再到AI代理语言革命】
  • 解决“Win7共享文件夹其他电脑网络无法发现共享电脑名称”的问题
  • 机器视觉之图像处理篇
  • c/c++ UNIX 域Socket和共享内存实现本机通信
  • 从概率填充到置信度校准:GPT-5如何从底层重构AI的“诚实”机制
  • 【网络安全测试】手机APP安全测试工具NowSecure 使用指导手册(有关必回)
  • PHP 开发全解析:从基础到实战的进阶之路
  • 【CV 目标检测】R-CNN①——Overfeat
  • GPT-5 提示词优化全攻略:用 Prompt Optimizer 快速迁移与提升,打造更稳更快的智能应用
  • RH134 管理基本存储知识点
  • 【车联网kafka】用钟表齿轮理解 Kafka 时间轮​(第七篇)
  • PlantSimulation知识点2025.8.14
  • pycharm远程连接服务器跑实验详细操作
  • 云计算-Docker Compose 实战:从OwnCloud、WordPress、SkyWalking、Redis ,Rabbitmq等服务配置实例轻松搞定
  • UML函数原型中stereotype的含义,有啥用?
  • UE5 C++ 删除文件
  • 4.Ansible部署文件到主机
  • 配置docker pull走http代理
  • 【网络】HTTP总结复盘