当前位置: 首页 > news >正文

TVRNet网络PyTorch实现

文章目录

    • 文章地址
    • 网络各层结构
    • 代码实现

文章地址

  • An End-to-End Traffic Visibility Regression Algorithm
  • 文章通过训练搜集得到的真实道路图像数据集(Actual Road dense image Dataset, ARD),通过专业的能见度计和多人标注,获得可靠的能见度标签数据集。构建网络,进行训练,获得了较好的能见度识别网络。网络包括特征提取​、多尺度映射​、特征融合​、非线性输出(回归范围为[0,1],需要经过(0,0),(1,1)改用修改的sigmoid函数,相较于ReLU更好)。结构如下​
    在这里插入图片描述

网络各层结构

在这里插入图片描述

  • 我认为红框位置与之相应的参数不匹配,在Feature Extraction部分Reshape之后得到的特征图大小为4124124。紧接着接了一个卷积层Conv,显示输入是3128128
  • 第二处红框,MaxPool的kernel设置为88,特征图没有进行padding,到全连接层的输入变为64117*117,参数不对应
    在这里插入图片描述

代码实现

"""Based on the ideas of the below paper, using PyTorch to build TVRNet.Reference: Qin H, Qin H. An end-to-end traffic visibility regression algorithm[J]. IEEE Access, 2021, 10: 25448-25454.​@weishuo
"""import torch
from torch import nn
import mathclass Inception(nn.Module):def __init__(self, in_planes, out_planes):super(Inception, self).__init__()self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1)self.conv5 = nn.Conv2d(in_planes, out_planes, kernel_size=5, padding=2)self.conv7 = nn.Conv2d(in_planes, out_planes, kernel_size=7, padding=3)def forward(self, x):out_1 = self.conv1(x)out_3 = self.conv3(x)out_5 = self.conv5(x)out_7 = self.conv7(x)out = torch.cat((out_1, out_3, out_5, out_7), dim=1)return outdef modify_sigmoid(x):return 1 / (1 + torch.exp(-10*(x-0.5)))class TVRNet(nn.Module):def __init__(self, in_planes, out_planes):super(TVRNet, self).__init__()# (B, 3, 224, 224)  ——>  (B, 3, 220, 220)self.FeatureExtraction_onestep = nn.Sequential(nn.Conv2d(in_planes, 20, kernel_size=5, padding=0),nn.ReLU(inplace=True),)self.FeatureExtraction_maxpool = nn.MaxPool2d((5, 1))self.MultiScaleMapping = nn.Sequential(Inception(4, 16),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=8))self.FeatureIntegration = nn.Sequential(nn.Linear(46656, 100),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(100, out_planes))self.NonLinearRegression = modify_sigmoiddef forward(self, x):x = self.FeatureExtraction_onestep(x)x = x.view((x.shape[0], 1, x.shape[1], -1))x = self.FeatureExtraction_maxpool(x)x = x.view(x.shape[0], x.shape[2], int(math.sqrt(x.shape[3])), int(math.sqrt(x.shape[3])))# print(x.shape)x = self.MultiScaleMapping(x)# print(x.shape)x = x.view(x.shape[0], -1)x = self.FeatureIntegration(x)out = self.NonLinearRegression(x)return outif __name__ == '__main__':a = torch.randn(1,3,224,224)net = TVRNet(3,3)b = net(a)print(b.shape)
http://www.lryc.cn/news/208934.html

相关文章:

  • opencv之坑(八)——putText中文乱码解决
  • nrf52832 开发板入手笔记:资料搜集
  • PHP如何批量修改二维数组中值
  • Python 算法高级篇:归并排序的优化与外部排序
  • LeetCode--1991.找到数组的中间位置
  • 物联网数据采集网关连接设备与云平台的关键桥梁
  • 专家级数据恢复:UFS Explorer Professional Recovery Crack
  • 2023/10/23 mysql学习
  • 软考系统架构师知识点集锦六:项目管理
  • MacOS系统Chrome开发者模式下载在线视频
  • uniapp v3+ts 使用 u-upload上传图片以及视频
  • 为什么虚拟dom会提高性能?
  • 2015年亚太杯APMCM数学建模大赛A题海上丝绸之路发展战略的影响求解全过程文档及程序
  • js中HTMLCollection如何循环
  • Kafka - 3.x 副本不完全指北
  • 二分归并法将两个数组合并
  • ROS自学笔记十六:URDF优化_xacro文件
  • XMLHttpRequest拦截请求和响应
  • 前端 读取/导入 Excel文档
  • 聊聊springboot的TomcatMetricsBinder
  • 《动手学深度学习 Pytorch版》 10.6 自注意力和位置编码
  • 2023年第四届MathorCup高校数学建模挑战赛——大数据竞赛B题 实现代码
  • larvel 中的api.php_Laravel 开发 API
  • 虚拟机构建部署单体项目及前后端分离项目
  • C++之特殊类的设计
  • Java练习题2020 -1
  • LuaTable转C#的列表List和字典Dictionary
  • Redis快速上手篇七(集群)
  • Mac 安装nvm
  • python 从mssql取出datetime2类型之后格式化