当前位置: 首页 > news >正文

UNet网络制作

UNet网络制作

代码参考UNet数据集制作及代码实现_哔哩哔哩_bilibili,根据该UP主的代码,加上我的个人整理和理解。(这个UP主的代码感觉很好,很规范

UNet网络由三部分组成:卷积块,下采样层,上采样层。

卷积块

UNet网络中卷积块进行了两次卷积。

class Conv_Block(nn.Module):def __init__(self, in_channel, out_channel):super(Conv_Block, self).__init__()self.layer = nn.Sequential(# padding_mode = "reflect" 增强特征提取nn.Conv2d(in_channel,out_channel, 3, 1, 1, padding_mode="reflect", bias = False),nn.BatchNorm2d(out_channel), # 二维批归一化层,归一化卷积层的输出。用于加速训练和增强模型的泛化能力。nn.Dropout2d(0.3), # 二维随机失活层,以概率0.3随机抑制特征,用于防止过拟合。nn.LeakyReLU(), # 带有负斜率的修正线性单元激活函数,引入非线性变换。nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode="reflect", bias = False),nn.BatchNorm2d(out_channel),nn.Dropout(0.3),nn.LeakyReLU())def forward(self, x):return self.layer(x) 

下采样层

UNet网络的下采样层中进行了一次卷积。下采样将图像大小减半,通道数不变,同时保留更多的重要特征。

class DownSample(nn.Module):def __init__(self, channel):super(DownSample, self).__init__()self.layer = nn.Sequential(nn.Conv2d(channel, channel, 3, 2, 1, padding_mode="reflect", bias = False),nn.BatchNorm2d(channel),nn.LeakyReLU())def forward(self, x):return self.layer(x)

上采样层

UNet网络的上采样层中进行了一次卷积操作和双线性插值上采样。卷积用于减低通道数,并将其与上一层的特征图进行拼接。用于恢复图像大小,同时提取更加精细的特征。

class UpSample(nn.Module):def __init__(self, channel):super(UpSample, self).__init__()self.layer = nn.Conv2d(channel, channel // 2, 1, 1)def forward(self, x, feature_map):up = F.interpolate(x, scale_factor=2, mode="nearest")out = self.layer(up)return torch.cat((out, feature_map), dim = 1)

网络模型

UNet网络模型由编码器和解码器两部分组成。编码器包含了四个 Conv_Block 和四个 DownSample 层,用于逐步提取图像的高级特征。解码器包含了四个 UpSample 和四个 Conv_Block 层,用于通过上采样和特征融合从编码器中恢复图像的细节。最后通过一个卷积层和 Sigmoid 激活函数得到二分类输出,用于分割图像。

class UNet(nn.Module):def __init__(self):super(UNet,self).__init__()self.c1 = Conv_Block(3, 64)self.d1 = DownSample(64)self.c2 = Conv_Block(64, 128)self.d2 = DownSample(128)self.c3 = Conv_Block(128, 256)self.d3 = DownSample(256)self.c4 = Conv_Block(256,512)self.d4 = DownSample(512)self.c5 = Conv_Block(512, 1024)self.u1 = UpSample(1024)self.c6 = Conv_Block(1024, 512)self.u2 = UpSample(512)self.c7 = Conv_Block(512, 256)self.u3 = UpSample(256)self.c8 = Conv_Block(256, 128)self.u4 = UpSample(128)self.c9 = Conv_Block(128, 64)self.out = nn.Conv2d(64,3,3,1,1)# 二分类self.Th = nn.Sigmoid()def forward(self, x):R1 = self.c1(x)R2 = self.c2(self.d1(R1))R3 = self.c3(self.d2(R2))R4 = self.c4(self.d3(R3))R5 = self.c5(self.d4(R4))O1 = self.c6(self.u1(R5, R4))O2 = self.c7(self.u2(O1, R3))O3 = self.c8(self.u3(O2, R2))O4 = self.c9(self.u4(O3, R1))return self.Th(self.out(O4))

测试

一致则正确。

if __name__ == "__main__":x = torch.randn(2,3,256,256)net=UNet()print(net(x).shape)

UNet网络训练

http://www.lryc.cn/news/175970.html

相关文章:

  • 智能热水器丨打造智能家居新体验
  • Python 十进制转化二进制1.0(简易版)
  • WebGL 选中一个表面
  • open ai chartgpt 安装插件 txyz.ai
  • 【算法思想】贪心
  • freeswitch-01
  • Zookeeper-集群介绍与核心理论
  • 动态分配的内存位置在哪里?
  • Vue3中的Ref与Reactive:深入理解响应式编程
  • Windows10/11显示文件扩展名 修改文件后缀名教程
  • 【C++】手撕string(string的模拟实现)
  • 用python3编译cv_bridge
  • 招商信诺人寿基于 Apache Doris 统一 OLAP 技术栈实践
  • 我的python安装在哪儿了?python安装路径怎么查?
  • 视频汇聚/安防监控平台EasyCVR指定到新的硬盘进行存储录像,如何自动挂载该磁盘?
  • 读博时的建议或心得
  • 3分钟,免费制作一个炫酷实用的数据可视化大屏!
  • 自费访学|金融公司高管赴世界名校伯克利交流
  • Databend 开源周报第112期
  • 如何学习maya mel语言的经验分享
  • 睿趣科技:新手抖音开店卖什么产品好
  • 【新版】系统架构设计师 - 案例分析 - 架构设计<Web架构>
  • 竞赛选题 基于视觉的身份证识别系统
  • git详细教程
  • [old]TeamDev DotNetBrowser Crack
  • Zynq-Linux移植学习笔记之63- linux内核崩溃的重启
  • 【精华】ubuntu编译openpose
  • 第二届全国高校计算机技能竞赛——Java赛道
  • 使用Webpack设置TS引用模块,解决Module not found: Error: Can‘t resolve ‘./m1‘ in ...问题
  • 北斗GPS网络时钟系统(子母钟系统)助力智慧教室建设