当前位置: 首页 > news >正文

用于神经网络的FLOP和Params计算工具

用于神经网络的FLOP和Params计算工具

1. FlopCountAnalysis

pip install fvcore
import torch
from torchvision.models import resnet152, resnet18
from fvcore.nn import FlopCountAnalysis, parameter_count_tablemodel = resnet152(num_classes=1000)tensor = (torch.rand(1, 3, 224, 224),)#分析FLOPs
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())def print_model_parm_nums(model):total = sum([param.nelement() for param in model.parameters()])print('  + Number of params: %.2fM' % (total / 1e6))print_model_parm_nums(model)

2. flopth

https://github.com/vra/flopth

pip install flopth 

Running on models in torchvision.models

$ flopth -m alexnet 
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| module_name   | module_type       | in_shape    | out_shape   | params   | params_percent   | params_percent_vis             | flops    | flops_percent   | flops_percent_vis   |
+===============+===================+=============+=============+==========+==================+================================+==========+=================+=====================+
| features.0    | Conv2d            | (3,224,224) | (64,55,55)  | 23.296K  | 0.0381271%       |                                | 70.4704M | 9.84839%        | ####                |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.1    | ReLU              | (64,55,55)  | (64,55,55)  | 0.0      | 0.0%             |                                | 193.6K   | 0.027056%       |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.2    | MaxPool2d         | (64,55,55)  | (64,27,27)  | 0.0      | 0.0%             |                                | 193.6K   | 0.027056%       |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.3    | Conv2d            | (64,27,27)  | (192,27,27) | 307.392K | 0.50309%         |                                | 224.089M | 31.3169%        | ###############     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.4    | ReLU              | (192,27,27) | (192,27,27) | 0.0      | 0.0%             |                                | 139.968K | 0.0195608%      |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.5    | MaxPool2d         | (192,27,27) | (192,13,13) | 0.0      | 0.0%             |                                | 139.968K | 0.0195608%      |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.6    | Conv2d            | (192,13,13) | (384,13,13) | 663.936K | 1.08662%         |                                | 112.205M | 15.6809%        | #######             |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.7    | ReLU              | (384,13,13) | (384,13,13) | 0.0      | 0.0%             |                                | 64.896K  | 0.00906935%     |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.8    | Conv2d            | (384,13,13) | (256,13,13) | 884.992K | 1.44841%         |                                | 149.564M | 20.9018%        | ##########          |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.9    | ReLU              | (256,13,13) | (256,13,13) | 0.0      | 0.0%             |                                | 43.264K  | 0.00604624%     |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.10   | Conv2d            | (256,13,13) | (256,13,13) | 590.08K  | 0.965748%        |                                | 99.7235M | 13.9366%        | ######              |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.11   | ReLU              | (256,13,13) | (256,13,13) | 0.0      | 0.0%             |                                | 43.264K  | 0.00604624%     |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| features.12   | MaxPool2d         | (256,13,13) | (256,6,6)   | 0.0      | 0.0%             |                                | 43.264K  | 0.00604624%     |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| avgpool       | AdaptiveAvgPool2d | (256,6,6)   | (256,6,6)   | 0.0      | 0.0%             |                                | 9.216K   | 0.00128796%     |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.0  | Dropout           | (9216)      | (9216)      | 0.0      | 0.0%             |                                | 0.0      | 0.0%            |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.1  | Linear            | (9216)      | (4096)      | 37.7528M | 61.7877%         | ############################## | 37.7487M | 5.27547%        | ##                  |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.2  | ReLU              | (4096)      | (4096)      | 0.0      | 0.0%             |                                | 4.096K   | 0.000572425%    |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.3  | Dropout           | (4096)      | (4096)      | 0.0      | 0.0%             |                                | 0.0      | 0.0%            |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.4  | Linear            | (4096)      | (4096)      | 16.7813M | 27.4649%         | #############                  | 16.7772M | 2.34465%        | #                   |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.5  | ReLU              | (4096)      | (4096)      | 0.0      | 0.0%             |                                | 4.096K   | 0.000572425%    |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
| classifier.6  | Linear            | (4096)      | (1000)      | 4.097M   | 6.70531%         | ###                            | 4.096M   | 0.572425%       |                     |
+---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+FLOPs: 715.553M
Params: 61.1008M

Running on custom models

# file path: /tmp/my_model.py
# model name:  MyModel
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)self.conv3 = nn.Conv2d(3, 3, kernel_size=3, padding=1)self.conv4 = nn.Conv2d(3, 3, kernel_size=3, padding=1)def forward(self, x1):x1 = self.conv1(x1)x1 = self.conv2(x1)x1 = self.conv3(x1)x1 = self.conv4(x1)return x1
$ flopth -m MyModel -p /tmp/my_model.py -i 3 224 224
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| module_name   | module_type   | in_shape    | out_shape   |   params | params_percent   | params_percent_vis   | flops    | flops_percent   | flops_percent_vis   |
+===============+===============+=============+=============+==========+==================+======================+==========+=================+=====================+
| conv1         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| conv2         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| conv3         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
| conv4         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
+---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+FLOPs: 16.8591M
Params: 336.0

3. calflops

https://github.com/MrYxJ/calculate-flops.pytorch/tree/main

pip install calflops
from calflops import calculate_flops
from torchvision import modelsmodel = models.alexnet()
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(model=model, input_shape=input_shape,output_as_string=True,output_precision=4)
print("Alexnet FLOPs:%s   MACs:%s   Params:%s \n" %(flops, macs, params))
#Alexnet FLOPs:4.2892 GFLOPS   MACs:2.1426 GMACs   Params:61.1008 M 
  1. from thop import profile

https://github.com/Lyken17/pytorch-OpCounter

pip install thop
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))
class YourModule(nn.Module):# your definition
def count_your_model(model, x, y):# your rule hereinput = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model})
http://www.lryc.cn/news/238224.html

相关文章:

  • CUDA核函数,如何设置grid和block即不超过大小又能够遍历整个volume
  • 【Linux】软连接和硬链接:创建、管理和解除链接的操作
  • Matlab群体智能优化算法之海象优化算法(WO)
  • go语言学习-结构体
  • Stable Diffusion进阶玩法说明
  • PDF控件Spire.PDF for .NET【转换】演示:将PDF 转换为 HTML
  • 二分查找——34. 在排序数组中查找元素的第一个和最后一个位置
  • MFC中的主窗口以及如何通过代码找到主窗口
  • Typora下载安装 (Mac和Windows)图文详解
  • 32位单片机PY32F040,主频72M,外设丰富,支持断码LCD
  • Shell判断:模式匹配:case(二)
  • 从android.graphics.Path中取出Point点,Kotlin
  • 力扣C++学习笔记——C++ 给vector去重
  • Flutter笔记:使用相机
  • 包装类型的缓存机制
  • 【BUG】第一次创建vue3+vite项目启动报错Error: Cannot find module ‘worker_threads‘
  • 多目标应用:基于非支配排序的鲸鱼优化算法NSWOA求解微电网多目标优化调度(MATLAB代码)
  • 网络爬虫|Selenium——find_element_by_xpath()的几种方法
  • 【Kingbase FlySync】命令模式:部署双轨并行,并实现切换同步
  • echarts 多toolti同时触发图表实现
  • 2023.11.22使用flask做一个简单的图片浏览器
  • 万字解析设计模式之桥接模式、外观模式
  • 常用系统函数
  • 键盘控制ROS车运动
  • linux上交叉编译qt库
  • Nacos介绍与使用
  • 网工内推 | 字节原厂,正式编,网络工程师,最高30K*15薪
  • Git 远程仓库(Github)
  • Mybatis Plus分页实现逻辑整理(结合芋道整合进行解析)
  • C#编程题分享(2)