当前位置: 首页 > news >正文

Stacked hourglass networks for human pose estimation代码学习

Stacked hourglass networks for human pose estimation
https://github.com/princeton-vl/pytorch_stacked_hourglass
这是一个用于人体姿态估计的模型,只能检测单个人
作者通过重复的bottom-up(高分辨率->低分辨率)和top-down(低分辨率->高分辨率)以及中间监督(深监督)来提升模型的性能

模型

残差

模型里的残差都是不改变分辨率的
在这里插入图片描述
在这里插入图片描述

class Conv(nn.Module):def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):super(Conv, self).__init__()self.inp_dim = inp_dimself.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=True)self.relu = Noneself.bn = Noneif relu:self.relu = nn.ReLU()if bn:self.bn = nn.BatchNorm2d(out_dim)def forward(self, x):assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass Residual(nn.Module):def __init__(self, inp_dim, out_dim):super(Residual, self).__init__()self.relu = nn.ReLU()self.bn1 = nn.BatchNorm2d(inp_dim)self.conv1 = Conv(inp_dim, out_dim // 2, 1, relu=False)self.bn2 = nn.BatchNorm2d(out_dim // 2)self.conv2 = Conv(out_dim // 2, out_dim // 2, 3, relu=False)self.bn3 = nn.BatchNorm2d(out_dim // 2)self.conv3 = Conv(out_dim // 2, out_dim, 1, relu=False)self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)if inp_dim == out_dim:self.need_skip = Falseelse:self.need_skip = Truedef forward(self, x):  # ([1, inp_dim, H, W])if self.need_skip:residual = self.skip_layer(x)  # ([1, out_dim, H, W])else:residual = x  # ([1, out_dim, H, W])out = self.bn1(x)out = self.relu(out)out = self.conv1(out)  # ([1, out_dim / 2, H, W])out = self.bn2(out)out = self.relu(out)out = self.conv2(out)  # ([1, out_dim / 2, H, W])out = self.bn3(out)out = self.relu(out)out = self.conv3(out)  # ([1, out_dim, H, W])out += residual  # ([1, out_dim, H, W])return out  # ([1, out_dim, H, W])

最前面

首先模型使用了一个卷积核为7∗77*777步长为2的卷积,然后使用了一个残差和下采样,将图像从256∗256256*256256256降到了64∗6464*646464
接着接了两个残差

对应论文这一段
在这里插入图片描述

self.pre = nn.Sequential(  # ([B, 3, 256, 256])Conv(3, 64, 7, 2, bn=True, relu=True),  # ([B, 64, 128, 128])Residual(64, 128),  # ([B, 128, 128, 128])Pool(2, 2),  # ([B, 128, 64, 64])Residual(128, 128),  # ([B, 128, 64, 64])Residual(128, inp_dim)  # ([B, 256, 64, 64]))

在这里插入图片描述

单个Hourglass

在每一次最大池化之前,模型会产生一个分支,一条最大池化,另一条会接卷积(残差)
合并之前,走最大池化的的分支会做一次上采样,然后两个分支按元素加
(对应论文这两句)
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
代码对应这个图
(然而论文的图里最前面的残差不知道怎么算。。。)
在这里插入图片描述

class Hourglass(nn.Module):def __init__(self, n, f, bn=None, increase=0):super(Hourglass, self).__init__()nf = f + increaseself.up1 = Residual(f, f)# Lower branchself.pool1 = Pool(2, 2)self.low1 = Residual(f, nf)self.n = n# Recursive hourglassif self.n > 1:self.low2 = Hourglass(n - 1, nf, bn=bn)else:self.low2 = Residual(nf, nf)self.low3 = Residual(nf, f)self.up2 = nn.Upsample(scale_factor=2, mode='nearest')def forward(self, x):  # ([1, f, H, W])up1 = self.up1(x)  # ([1, f, H, W])pool1 = self.pool1(x)  # ([1, f, H/2, W/2])low1 = self.low1(pool1)  # ([1, nf, H/2, W/2])low2 = self.low2(low1)  # ([1, nf, H/2, W/2])low3 = self.low3(low2)  # ([1, f, H/2, W/2])up2 = self.up2(low3)  # ([1, f, H, W])return up1 + up2  # ([1, f, H, W])

热力图

模型会接两个1∗11*111的卷积来产生热力图(heatmap)
在这里插入图片描述
(虽然不知道为啥代码里还有一个残差)
在这里插入图片描述

中间监督

将前一个Hourglass,heatmap,heatmap之前的特征通过2个1∗11*111的卷积加在一起
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

https://towardsdatascience.com/using-hourglass-networks-to-understand-human-poses-1e40e349fa15#:~:text=Hourglass%20networks%20are%20a%20type,image%20into%20a%20feature%20matrix.
https://medium.com/@monadsblog/stacked-hourglass-networks-14bee8c35678

http://www.lryc.cn/news/22771.html

相关文章:

  • SpringCloud(五)MQ消息队列
  • SQL语法基础汇总
  • 惠普星14Pro电脑开机不了显示错误代码界面怎么办?
  • 顺序表的构造及功能
  • cesium: 绘制线段(008)
  • HTML、CSS学习笔记4(3D转换、动画)
  • java的分布式锁
  • 17- TensorFlow实现手写数字识别 (tensorflow系列) (项目十七)
  • Polkadot 基础
  • spring源码编译
  • 防盗链是什么?带你了解什么是防盗链
  • Linux基础命令-fdisk管理磁盘分区表
  • (四)K8S 安装 Nginx Ingress Controller
  • 高频面试题
  • js 字节数组操作,TCP协议组装
  • JavaScript的引入并执行-包含动态引入与静态引入
  • 第四阶段01-酷鲨商城项目准备
  • Uncaught ReferenceError: jQuery is not defined
  • 面试阿里测开岗,被面试官针对,当场翻脸,把我的简历还给我,疑似被拉黑...
  • 2. 驱动开发--驱动开发环境搭建
  • 《数据库系统概论》学习笔记——第四章 数据库安全
  • 山洪径流过程模拟及洪水危险性评价
  • LeetCode HOT100 (23、32、33)
  • 电力监控仪表主要分类
  • 山野户外定位依赖GPS或者卫星电话就能完成么?
  • SAP 应收应付重组配置
  • 算法练习(八)计数质数(素数)
  • 用反射模拟IOC模拟getBean
  • 【Ap AutoSAR入门与实战开发02】-【Ap_s2s模块01】: s2s的背景
  • C语言数据结构(3)----无头单向非循环链表