当前位置: 首页 > news >正文

如何得到深度学习模型的参数量和计算复杂度

1.准备好网络模型代码

import torch
import torch.nn as nn
import torch.optim as optim# BP_36: 输入2个节点,中间层36个节点,输出25个节点
class BP_36(nn.Module):def __init__(self):super(BP_36, self).__init__()self.fc1 = nn.Linear(2, 36)  # 输入2个节点,中间层36个节点self.fc2 = nn.Linear(36, 25)  # 输出25个节点def forward(self, x):x = torch.relu(self.fc1(x))  # 使用ReLU激活函数x = self.fc2(x)return x# BP_64: 输入2个节点,中间层64个节点,输出25个节点
class BP_64(nn.Module):def __init__(self):super(BP_64, self).__init__()self.fc1 = nn.Linear(2, 64)  # 输入2个节点,中间层64个节点self.fc2 = nn.Linear(64, 25)  # 输出25个节点def forward(self, x):x = torch.relu(self.fc1(x))  # 使用ReLU激活函数x = self.fc2(x)return x# Bi-LSTM: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_LSTM(nn.Module):def __init__(self):super(Bi_LSTM, self).__init__()self.lstm = nn.LSTM(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向LSTMself.fc1 = nn.Linear(72, 25)  # LSTM的输出72维,经过线性层后输出25个节点def forward(self, x):# x的形状应该是(batch_size, seq_len, input_size)x, _ = self.lstm(x)  # 输出LSTM的结果x = self.fc1(x)return x# Bi-GRU: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_GRU(nn.Module):def __init__(self):super(Bi_GRU, self).__init__()self.gru = nn.GRU(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向GRUself.fc1 = nn.Linear(72, 25)  # GRU的输出72维,经过线性层后输出25个节点def forward(self, x):# x的形状应该是(batch_size, seq_len, input_size)x, _ = self.gru(x)  # 输出GRU的结果x = self.fc1(x)return x

2.运行计算参数量和复杂度的脚本

import torch
# from net import BP_36
# from net import BP_64
# from net import Bi_LSTM
from net import Bi_GRUfrom ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 统计Transformer模型的参数量和计算复杂度
model_transformer = Bi_GRU()
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (256,2), as_strings=True, print_per_layer_stat=False)
print('模型参数量:' + params_transformer)
print('模型计算复杂度:' + flops_transformer)
http://www.lryc.cn/news/514601.html

相关文章:

  • 2025年股指期货每月什么时候交割?
  • 自从学会Git,感觉打开了一扇新大门
  • Ansys Discovery 中的网格划分方法:探索模式
  • 关于 AWTK 和 Weston 在旋转屏幕时的资源消耗问题
  • grouped.get_group((‘B‘, ‘A‘))选择分组
  • HTML——66.单选框
  • Couchbase 和数据湖技术的区别、联系和相关性分析
  • springboot3 性能优化
  • C++之运算符重载详解篇
  • 深度学习应用工程化中的节能减排最佳实践
  • 电脑文件msvcp110.d丢失的解决方法
  • xdoj isbn号码
  • qt的utc时间转本地时间
  • mariadb变更数据存放目录
  • 分布式专题(11)之Zookeeper特性与节点数据类型详解
  • Java项目实战II基于小程序的驾校管理系统(开发文档+数据库+源码)
  • Unity Pico 应用失去焦点后,追踪功能被禁用(原生 UI 界面弹出)
  • 第十四届蓝桥杯Scratch省赛中级组—智能计价器
  • AWS S3文件存储工具类
  • 【leetcode100】二叉树的中序遍历
  • 开源GTKSystem.Windows.Forms框架:C# Winform跨平台运行深度解析
  • C++软件设计模式之责任链模式
  • 021-spring-springmvc-组件
  • 基于SpringBoot和OAuth2,实现通过Github授权登录应用
  • macos 支持外接高分辩率显示器开源控制软件
  • C++26 新特性预览(Preview)
  • MySQL5.7.26-Linux-安装(2024.12)
  • 2025-1-2-sklearn学习(30)模型选择与评估-验证曲线: 绘制分数以评估模型 真珠帘卷玉楼空,天淡银河垂地。
  • 【优选算法】查找总价格为目标值的两个商品
  • 利用 NineData 实现 PostgreSQL 到 Kafka 的高效数据同步