当前位置: 首页 > news >正文

【Python/Pytorch - 网络模型】-- 手把手搭建U-Net模型

在这里插入图片描述
文章目录

文章目录

  • 00 写在前面
  • 01 基于Pytorch版本的UNet代码
  • 02 论文下载

00 写在前面

通过U-Net代码学习,可以学习基于Pytorch的网络结构模块化编程,对于后续学习其他更复杂网络模型,有很大的帮助作用。

在01中,可以根据U-Net的网络结构(开头图片),进行模块化编程。包括卷积模块定义、上采样模块定义、输出卷积层定义、损失函数定义、网络模型定义等。

在模型调试过程中,可以先通过简单测试代码,进行代码调试。

01 基于Pytorch版本的UNet代码

# 库函数调用
import torch
import torch.nn as nn
from network.ops import TotalVariation
from torchvision.models import vgg19# 卷积块定义
class conv_block(nn.Module):def __init__(self,ch_in,ch_out):super(conv_block,self).__init__()self.conv = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),#nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True),nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),#nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.conv(x)return x# 上采样部分定义
class up_conv(nn.Module):def __init__(self,ch_in,ch_out):super(up_conv,self).__init__()self.up = nn.Sequential(nn.Upsample(scale_factor=2),nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),#nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.up(x)return x# 输出卷积层定义
class outconv(nn.Module):def __init__(self, in_ch, out_ch):super(outconv, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),#nn.ReLU(inplace=True),)def forward(self, x):x = self.conv(x)return xclass UNET_MODEL(nn.Module):def __init__(self, img_ch=3, output_ch=1,filter_dim=64):super().__init__()self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.Conv1 = conv_block(ch_in=img_ch, ch_out=filter_dim)self.Conv2 = conv_block(ch_in=64, ch_out=128)self.Conv3 = conv_block(ch_in=128, ch_out=256)self.Conv4 = conv_block(ch_in=256, ch_out=512)self.Conv5 = conv_block(ch_in=512, ch_out=1024)self.Up5 = up_conv(ch_in=1024, ch_out=512)self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)self.Up4 = up_conv(ch_in=512, ch_out=256)self.Up_conv4 = conv_block(ch_in=512, ch_out=256)self.Up3 = up_conv(ch_in=256, ch_out=128)self.Up_conv3 = conv_block(ch_in=256, ch_out=128)self.Up2 = up_conv(ch_in=128, ch_out=64)self.Up_conv2 = conv_block(ch_in=128, ch_out=64)self.Conv11 = outconv(64, output_ch)def forward(self, x):# encoding pathx1 = self.Conv1(x)x2 = self.Maxpool(x1)x2 = self.Conv2(x2)x3 = self.Maxpool(x2)x3 = self.Conv3(x3)x4 = self.Maxpool(x3)x4 = self.Conv4(x4)x5 = self.Maxpool(x4)x5 = self.Conv5(x5)# decoding + concat pathd5 = self.Up5(x5)d5 = torch.cat((x4, d5), dim=1)d5 = self.Up_conv5(d5)d4 = self.Up4(d5)d4 = torch.cat((x3, d4), dim=1)d4 = self.Up_conv4(d4)d3 = self.Up3(d4)d3 = torch.cat((x2, d3), dim=1)d3 = self.Up_conv3(d3)d2 = self.Up2(d3)d2 = torch.cat((x1, d2), dim=1)d2 = self.Up_conv2(d2)T2 = self.Conv11(d2)return T2# 损失函数定义
class loss_fun(nn.Module):def __init__(self, regular):super().__init__()self.tv = TotalVariation()self.regular = regulardef forward(self, x, y):ychange = y[:, 0:1, :, :]mask = y[:, 1:2, :, :]return torch.add(torch.mean(torch.pow((x[:,:,:,:] - y[:,2:3,:,:])*ychange, 2)), self.regular* torch.mean(self.tv(x[:, :, :, :]*mask)))class loss_fun_total(nn.Module):def __init__(self, regular):super().__init__()self.tv = TotalVariation()self.regular = regulardef forward(self, x, y):loss1 = torch.mean(torch.pow((x[:,0:1,:,:] - y[:,0:1,:,:]*10), 2))return loss1# 测试代码
if __name__ == '__main__':input_channels = 4output_channels = 1x = torch.ones([32, 4, 256, 256])model = UNET_MODEL(input_channels, output_channels)print('model initialization finished!')f = model(x)print(f)

02 论文下载

U-Net: deep learning for cell counting, detection, and morphometry
U-Net: Convolutional Networks for Biomedical Image Segmentation

http://www.lryc.cn/news/369911.html

相关文章:

  • Ansible-doc 命令
  • 面试题:什么是线程的上下文切换?
  • 【简单讲解Perl语言】
  • 专硕初试科目一样,但各专业的复试线差距不小!江南大学计算机考研考情分析!
  • “华为Ascend 910B AI芯片挑战NVIDIA A100:效能比肩,市场角逐加剧“
  • 针对多智能体协作框架的元编程——METAGPT
  • Django自定义CSS
  • Rust基础学习-标准库
  • django连接达梦数据库
  • Python深度学习基于Tensorflow(17)基于Transformer的图像处理实例VIT和Swin-T
  • 树莓派4B_OpenCv学习笔记5:读取窗口鼠标状态坐标_TrackBar滑动条控件的使用
  • c、c#、c++嵌入式比较?
  • 如何使用ai人工智能作诗?7个软件帮你快速作诗
  • 调用华为API实现语音合成
  • docker实战命令大全
  • Java线程死锁
  • virtual box安装invalid installation directory
  • 概率分析和随机算法
  • 15_2 Linux Shell基础
  • Catia装配体零件复制
  • 实用小工具-python esmre库实现word查找
  • SSM框架整合,内嵌Tomcat。基于注解的方式集成
  • 系统架构设计师【论文-2016年 试题4】: 论微服务架构及其应用(包括写作要点和经典范文)
  • 面试题:String 、StringBuffer 、StringBuilder的区别
  • TLS指纹跟踪网络安全实践(C/C++代码实现)
  • 小白学RAG:大模型 RAG 技术实践总结
  • Doris Connector 结合 Flink CDC 实现 MySQL 分库分表
  • ModbusTCP、TCP/IP都走网线,一样吗?
  • 网络学习(13)|Spring Boot中获取HTTP请求头(Header)内容的详细解析
  • 【漏洞复现】宏景eHR pos_dept_post SQL注入漏洞