当前位置: 首页 > news >正文

卷积神经网络(CNN)的计算量和参数怎么准确估计?

🍉 CSDN 叶庭云https://yetingyun.blog.csdn.net/


在这里插入图片描述

1. 卷积层(Convolutional Layer)

a) 计算量估计:

卷积层的 FLOPs = 2 * H_out * W_out * C_in * C_out * K_h * K_w

详细解释:

  • H_out, W_out:输出特征图的高度和宽度
  • C_in:输入通道数
  • C_out:输出通道数(卷积核数量)
  • K_h, K_w:卷积核的高度和宽度
  • 乘以 2 是因为每次卷积操作包含一次乘法和一次加法

注意:输出特征图的尺寸可以通过以下公式计算:
H_out = (H_in - K_h + 2P) / S + 1
W_out = (W_in - K_w + 2P) / S + 1
其中,H_in 和 W_in 是输入特征图的高度和宽度,P 是填充(padding),S 是步长(stride)。

b) 参数数量估计:

卷积层的参数数 = (K_h * K_w * C_in + 1) * C_out

解释:

  • K_h * K_w * C_in 是每个卷积核的权重数量
  • 加 1 是因为每个卷积核还有一个偏置项(bias)
  • 乘以 C_out 是因为有 C_out 个卷积核

2. 全连接层(Fully Connected Layer)

a) 计算量估计:

全连接层的 FLOPs = 2 * N_in * N_out

解释:

  • N_in:输入神经元数量
  • N_out:输出神经元数量
  • 乘以 2 同样是因为每个连接包含一次乘法和一次加法

b) 参数数量估计:

全连接层的参数数 = (N_in + 1) * N_out

解释:

  • N_in * N_out 是权重的数量
  • 加 1 再乘以 N_out 是因为每个输出神经元有一个偏置项

3. 池化层(Pooling Layer)

a) 计算量估计:

对于最大池化(Max Pooling):FLOPs ≈ H_out * W_out * C * K_h * K_w
对于平均池化(Average Pooling):FLOPs ≈ 2 * H_out * W_out * C * K_h * K_w

解释:

  • H_out, W_out:输出特征图的尺寸
  • C:通道数(与输入相同)
  • K_h, K_w:池化窗口的高度和宽度

b) 参数数量:池化层通常没有可学习的参数

4. 激活函数(Activation Functions)

激活函数的计算量通常较小,但在精确计算时可以考虑:

ReLU 的 FLOPs ≈ H * W * C (仅需要比较操作)
Sigmoid / {/} /Tanh 的 FLOPs 会更多,因为涉及指数计算

5. 批归一化层(Batch Normalization)

a) 计算量估计:

BN 层的 FLOPs ≈ 4 * H * W * C

解释:需要计算均值、方差、归一化和缩放 / {/} / 平移

b) 参数数量:

BN 层的参数数 = 2 * C (每个通道有一个缩放因子和一个平移因子)

6. 总体估算

要估算整个 CNN 的计算量和参数数量,需要:

  1. 分析网络架构中的每一层
  2. 根据上述方法计算每层的 FLOPs 和参数数
  3. 将所有层的结果相加

注意事项:

  • 实际运行时的计算量可能与理论估计有差异,因为现代硬件和优化技术可能会影响实际性能。
  • 某些操作(如数据传输)虽然不直接体现在 FLOPs 中,但也会影响实际运行时间。
  • 在设计神经网络时,平衡计算复杂度和模型性能是很重要的。

http://www.lryc.cn/news/450798.html

相关文章:

  • Ruby基础语法
  • 插入排序C++
  • 修改ID不能用关键字作为ID校验器-elementPlus
  • 一文详解WebRTC、RTSP、RTMP、SRT
  • 全国职业院校技能大赛(大数据赛项)-平台搭建Zookeeper笔记
  • 不同领域神经网络一般选择什么模型作为baseline(基准模型)
  • 华为-IPv6与IPv4网络互通的6to4自动隧道配置实验
  • 【spring中event】事件简单使用
  • leetcode每日一题day19(24.9.29)——买票需要的时间
  • 智源研究院推出全球首个中文大模型辩论平台FlagEval Debate
  • python实用脚本(二):删除xml标签下的指定类别
  • vue3 父子组件调用
  • 线性模型到神经网络
  • 【架构】前台、中台、后台
  • Stable Diffusion 蒙版:填充、原图、潜空间噪声(潜变量噪声)、潜空间数值零(潜变量数值零)
  • ffmpeg录制视频功能
  • 【LeetCode】每日一题 2024_10_1 最低票价(记忆化搜索/DP)
  • [C++] 小游戏 征伐 SLG DNF 0.0.1 版本 zty出品
  • 黑马头条day7-app端文章搜索
  • 嵌入式必懂微控制器选型:STM32、ESP32、AVR与PIC的比较分析
  • Python selenium库学习使用实操二
  • 基于Hive和Hadoop的电信流量分析系统
  • 访问docker容器中服务的接口,报错提示net::ERR_CONNECTION_REFUSED
  • 【mysql相关总结】
  • uniapp 微信小程序 微信支付
  • CSS 效果:实现动态展示双箭头
  • Linux 创建开发用的账户
  • 检查一个CentOS服务器的配置的常用命令
  • Redis 简单的消息队列
  • C++:继承和多态,自定义封装栈,队列