当前位置: 首页 > news >正文

人群计数CSRNet的pytorch实现

本文中对CSRNet: Dilated Convolutional Neural Networks for Understanding the Highly Congested Scenes(CVPR 2018)中的模型进行pytorch实现

import torch;import torch.nn as nn
from torchvision.models import vgg16
vgg=vgg16(pretrained=1)import warnings
warnings.filterwarnings("ignore")
vgg10=torch.nn.Sequential(torch.nn.Conv2d(3,64,3,stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(64, 64, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(2,2),torch.nn.Conv2d(64, 128, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(128, 128, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(2,2),torch.nn.Conv2d(128, 256, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(256, 256, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(256, 256, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(2,2),  #尝试不进行下采样以达到不进行上采样torch.nn.Conv2d(256, 512, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(512, 512, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(512, 512, 3, stride=1,padding=1),torch.nn.ReLU(inplace=True),)
class CSRNET(torch.nn.Module):def __init__(self, load_weights=False):super(CSRNET,self).__init__()self.vgg10=vgg10self.dconv1 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)self.dconv2 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)self.dconv3 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)self.dconv4 = torch.nn.Conv2d(512, 256, 3, dilation=2, stride=1, padding=2)self.dconv5 = torch.nn.Conv2d(256, 128, 3, dilation=2, stride=1, padding=2)self.dconv6 = torch.nn.Conv2d(128, 64, 3, dilation=2, stride=1, padding=2)self.finalconv=torch.nn.Conv2d(64,1,1)self.relu=torch.nn.functional.reluif not load_weights:self.vgg10.load_state_dict(vgg.features[0:23].state_dict())def forward(self,x):y=self.vgg10(x)y = self.relu(self.dconv1(y))y = self.relu(self.dconv1(y))y = self.relu(self.dconv2(y))y = self.relu(self.dconv3(y))y = self.relu(self.dconv4(y))y = self.relu(self.dconv5(y))y = self.relu(self.dconv6(y))h=self.finalconv(y)
http://www.lryc.cn/news/249882.html

相关文章:

  • 【HTTP协议】简述HTTP协议的概念和特点
  • 经典神经网络——AlexNet模型论文详解及代码复现
  • flutter开发实战-轮播Swiper更改Custom_layout样式中Widget层级
  • 【Flutter】graphic图表实现自定义tooltip
  • 手机上的记事本怎么打开?安卓手机通用的记事本APP
  • 一起学docker系列之十五深入了解 Docker Network:构建容器间通信的桥梁
  • 前端OFD文件预览(vue案例cafe-ofd)
  • Java[list/set]通用遍历方法之Iterator
  • ubuntu/vscode下的c/c++开发之-CMake语法与练习
  • Java(119):ExcelUtil工具类(org.apache.poi读取和写入Excel)
  • Kong处理web服务跨域
  • Kotlin学习——kt里的作用域函数scope function,let,run,with,apply,also
  • informer辅助笔记:utils/timefeatures.py
  • [Verilog语法]:===和!==运算符使用注意事项
  • mybatis 高并发查询性能问题
  • 我在Vscode学OpenCV 图像处理一(阈值处理、形态学操作【连通性,腐蚀和膨胀,开闭运算,礼帽和黑帽,内核】)
  • Yolov8实现瓶盖正反面检测
  • GAN:WGAN前作
  • 数据库应用:MongoDB 文档与索引管理
  • Python批处理PDF文件,PDF附件轻松批量提取
  • Python可迭代对象排序:深入排序算法与定制排序
  • 基于matlab的图像去噪算法设计与实现
  • NFTScan 正式上线 Starknet NFTScan 浏览器和 NFT API 数据服务
  • 2023年亚太杯APMCM数学建模大赛A题水果采摘机器人的图像识别
  • mysql which is not in SELECT list; this is incompatible with DISTINCT解决方案
  • linux /proc 文件系统
  • java开发之个微群聊自动添加好友
  • Git .gitignore 忽略文件不生效解决方法
  • 【Java】16. HashMap
  • KMP基础架构