当前位置: 首页 > news >正文

深度学习基础之参数量(3)

一般的CNN网络的参数量估计代码

class ResidualBlock(nn.Module):def __init__(self, in_planes, planes, norm_fn='group', stride=1):super(ResidualBlock, self).__init__()print(in_planes, planes, norm_fn, stride)self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)self.relu = nn.ReLU(inplace=True)num_groups = planes // 8if norm_fn == 'group':self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)if not stride == 1:self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)elif norm_fn == 'batch':self.norm1 = nn.BatchNorm2d(planes)self.norm2 = nn.BatchNorm2d(planes)if not stride == 1:self.norm3 = nn.BatchNorm2d(planes)elif norm_fn == 'instance':self.norm1 = nn.InstanceNorm2d(planes)self.norm2 = nn.InstanceNorm2d(planes)if not stride == 1:self.norm3 = nn.InstanceNorm2d(planes)elif norm_fn == 'none':self.norm1 = nn.Sequential()self.norm2 = nn.Sequential()if not stride == 1:self.norm3 = nn.Sequential()if stride == 1:self.downsample = Noneelse:self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)def forward(self, x):print(x.shape)#exit()y = xy = self.relu(self.norm1(self.conv1(y)))y = self.relu(self.norm2(self.conv2(y)))if self.downsample is not None:x = self.downsample(x)return self.relu(x + y)R=ResidualBlock(384, 384, norm_fn='instance', stride=1)
summary(R.to("cuda" if torch.cuda.is_available() else "cpu"), (384, 32, 32))

transformer结构的参数量的估计结果

import torch
import torch.nn as nn
from thop import profile
from torchsummary import summary# 定义一个简单的Transformer模型
class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.embedding = nn.Embedding(input_dim, hidden_dim)self.transformer_layers = nn.Transformer(d_model=hidden_dim,nhead=num_heads,num_encoder_layers=num_layers,num_decoder_layers=num_layers)self.fc = nn.Linear(hidden_dim, input_dim)def forward(self, src, tgt):src = self.embedding(src)tgt = self.embedding(tgt)output = self.transformer_layers(src, tgt)output = self.fc(output)return output# 创建Transformer模型实例
model2 = Transformer(input_dim=512, hidden_dim=512, num_heads=8, num_layers=6)# 使用thop进行FLOPS估算
flops, params = profile(model2, inputs=(torch.randint(0, 512, (128,)), torch.randint(0, 512, (64,))))
print(f"FLOPS: {flops / 1e9} G FLOPS")  # 打印FLOPS,以十亿FLOPS(GFLOPS)为单位# 计算参数量并打印
num_params = sum(p.numel() for p in model2.parameters() if p.requires_grad)
print(f"Total number of trainable parameters: {num_params}")

http://www.lryc.cn/news/187964.html

相关文章:

  • 红队专题-从零开始VC++远程控制软件RAT-C/S-[2]界面编写及上线
  • 磁盘满了对日志打印(Logback)的影响
  • 【算法与数据结构】--算法基础--数据结构概述
  • QECon大会亮相产品,全栈测试平台推荐:RunnerGo
  • 前端小案例-图片存放在远端服务器
  • 【鼠标右键菜单添加用VSCode打开文件或文件夹】
  • 【jvm--堆】
  • 【数据库——MySQL(实战项目1)】(1)图书借阅系统
  • [GXYCTF 2019]Ping Ping Ping题目解析
  • HTTP协议的请求协议和响应协议的组成,HTTP常见的状态信息
  • 【LeetCode】剑指 Offer Ⅱ 第6章:栈(6道题) -- Java Version
  • vue3的element-plus的el-dialog的样式修改无效问题
  • 归纳所猜半结论推出完整结论:CF1592F1
  • WPFdatagrid结合comboBox
  • Markdown类图之继承、实现、关联、依赖、组合、聚合总结(十五)
  • @MultipartConfig注解
  • Python并发编程简介
  • WebSocket介绍及部署
  • 自动求导,计算图示意图及pytorch实现
  • 睿伴科创上线了
  • 域名抢注和域名注册
  • 【20】c++设计模式——>组合模式
  • Jetpack:004-如何使用文本组件
  • JVM(八股文)
  • C#WPF标记扩展应用实例
  • 四维曲面如何画?matlab
  • 软件培训测试高级工程师多测师肖sir__html之作业11
  • 详解一典型的反激式开关电源方案
  • AI 大框架基于python来实现基带处理之TensorFlow(信道估计和预测模型,信号解调和解码模型)
  • 阿里云上了新闻联播