当前位置: 首页 > news >正文

8.Mobilenetv2网络代码实现

代码如下:

import math
import os
import numpy as npimport torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo#1.建立带有bn的卷积网络
def conv_bn(inp, oup, stride):return nn.Sequential(nn.Conv2d(inp,oup,3,stride,bias=False),nn.BatchNorm2d(oup),nn.ReLU6(inplace=True))#2.建立卷积核是1x1的卷积网络
def conv_1x1_bn(inp, oup):return nn.Sequential(nn.Conv2d(inp,oup,1,1,0,bias=False),nn.BatchNorm2d(oup),nn.ReLU6(inplace=True))class InvertedResidual(nn.Module):def __init__(self, inp, oup, stride, expand_ratio):super(InvertedResidual,self).__init__()self.stride=strideassert stride in [1,2]hidden_dim=round(inp*expand_ratio)self.use_res_connect=self.stride==1 and inp==oupif expand_ratio == 1:self.conv=nn.Sequential(# --------------------------------------------##   进行3x3的逐层卷积,进行跨特征点的特征提取# --------------------------------------------#nn.Conv2d(hidden_dim,hidden_dim,3,stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.ReLU6(inplace=True),# -----------------------------------##   利用1x1卷积进行通道数的调整# -----------------------------------#nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)else:self.conv=nn.Sequential(# -----------------------------------##   利用1x1卷积进行通道数的上升# -----------------------------------#nn.Conv2d(inp,hidden_dim,1,1,0,bias=False),nn.BatchNorm2d(hidden_dim),nn.ReLU6(inplace=True),# --------------------------------------------##   进行3x3的逐层卷积,进行跨特征点的特征提取# --------------------------------------------#nn.Conv2d(hidden_dim,hidden_dim,3,stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.ReLU6(inplace=True),# -----------------------------------##   利用1x1卷积进行通道数的下降# -----------------------------------#nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup))def forward(self,x):if self.use_res_connect:return x+self.conv(x)else:return self.conv(x)#搭建MobileNetV2网络
class MobileNetV2(nn.Module):def __init__(self, n_class=1000, input_size=224, width_mult=1.):super(MobileNetV2, self).__init__()block=InvertedResidualinput_channel=32last_channel=1280interverted_residual_setting = [# t, c, n, s[1, 16, 1, 1],  # 256, 256, 32 -> 256, 256, 16[6, 24, 2, 2],  # 256, 256, 16 -> 128, 128, 24   2[6, 32, 3, 2],  # 128, 128, 24 -> 64, 64, 32     4[6, 64, 4, 2],  # 64, 64, 32 -> 32, 32, 64       7[6, 96, 3, 1],  # 32, 32, 64 -> 32, 32, 96[6, 160, 3, 2],  # 32, 32, 96 -> 16, 16, 160     14[6, 320, 1, 1],  # 16, 16, 160 -> 16, 16, 320]assert input_size % 32 == 0input_channel = int(input_channel * width_mult)self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel# 512, 512, 3 -> 256, 256, 32self.features=[conv_bn(3,input_channel,2)]for t,c,n,s in interverted_residual_setting:output_channel=int(c*width_mult)for i in range(n):if i==0:self.features.append(block(input_channel,output_channel,s, expand_ratio=t))else:self.features.append(block(input_channel,output_channel,1, expand_ratio=t))# input_channel修改为该轮的输出层数input_channel = output_channelself.features.append(conv_1x1_bn(input_channel, self.last_channel))self.features=nn.Sequential(*self.features)self.classifier=nn.Sequential(nn.Dropout(0.2),nn.Linear(self.last_channel,n_class))self._initialize_weights()def forward(self,x):x=self.features(x)x=x.mean(3).mean(2)x=self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m, nn.Linear):n = m.weight.size(1)m.weight.data.normal_(0, 0.01)m.bias.data.zero_()if __name__ == '__main__':print("........................................")#数据集生成input=torch.randn(1,3,224,224)print(input.shape)#MobileNetV2的输出ss=MobileNetV2()# print(ss)output=ss(input)print(output.shape)

http://www.lryc.cn/news/192351.html

相关文章:

  • Spring Boot Controller
  • 在网络安全、爬虫和HTTP协议中的重要性和应用
  • Web测试框架SeleniumBase
  • jvm打破砂锅问到底- 为什么要标记或记录跨代引用
  • 小程序长期订阅
  • Studio One6.5中文版本版下载及功能介绍
  • 07-Zookeeper分布式一致性协议ZAB源码剖析
  • 云原生安全应用场景有哪些?
  • Step 1 搭建一个简单的渲染框架
  • Excel 插入和提取超链接
  • 基础架构开发-操作系统、编译器、云原生、嵌入式、ic
  • C++-Mongoose(3)-http-server-https-restful
  • git多分支、git远程仓库、ssh方式连接远程仓库、协同开发(避免冲突)、解决协同冲突(多人在同一分支开发、 合并分支)
  • ChatGPT或将引发现代知识体系转变
  • 【爬虫实战】用pyhon爬百度故事会专栏
  • 焦炭反应性及反应后强度试验方法
  • 链表(3):双链表
  • 【TES720D】基于复旦微的FMQL20S400全国产化ARM核心模块
  • Python 列表切片陷阱:引用、复制与深复制
  • macbook电脑删除app怎么才能彻底清理?
  • 【数据结构】二叉树--链式结构的实现 (遍历)
  • reids基础数据结构
  • gitlab 维护
  • ABB机器人RWS连接方法
  • Spring Boot的循环依赖问题
  • postgresql|数据库|恢复备份的时候报错:pg_restore: implied data-only restore的处理方案
  • Elasticsearch:使用 Langchain 和 OpenAI 进行问答
  • 安全巡检管理系统—隐患排查治理
  • 第9期ThreadX视频教程:自制个微秒分辨率任务调度实现方案(2023-10-11)
  • C++ 11 lamdba表达式详解