当前位置: 首页 > article >正文

AI笔记 - 模型调试 - 调试方式

模型调试方式

  • 基础信息
  • 打印模型信息
  • 计算参数量和计算量
    • 过滤原则
    • profile方法
    • get_model_complexity_info方法
    • FlopCountAnalysis方法

基础信息

# 打印执行的设备数量:device_count:1
print(f"device_count:{torch.cuda.device_count()}")# 打印当前网络执行的设备信息:device: cuda:0
print(f"device: {next(self.net.parameters()).device}")  # 应该输出: cuda:0

打印模型信息

#操作	    代码示例
#-----------------------------------------------------
#遍历所有模块	for name, module in model.named_modules():
#-----------------------------------------------------
#打印参数详情	module.named_parameters()
#-----------------------------------------------------
#打印缓冲区	module.named_buffers()
#-----------------------------------------------------
#过滤特定层	isinstance(module, nn.Conv2d)
#-----------------------------------------------------
#统计计算量	profile(module, inputs=(input,))
#-----------------------------------------------------import torchvision.models as modelsmodel = models.resnet50(weights=None).cuda()  # 不加载预训练权重以减少下载时间
input = torch.randn(1, 3, 224, 224).cuda()
for name, p in model.named_parameters():
print(f"params name:{name}, shape:{p.shape}, device:{p.device}")
print(f"dtype: {p.dtype}, 是否需要梯度:{p.requires_grad}")#params name:conv1.weight, shape:torch.Size([64, 3, 7, 7]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
#params name:bn1.weight, shape:torch.Size([64]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
#params name:bn1.bias, shape:torch.Size([64]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
...for name, module in model.named_modules():print(f"模块名称:{name}, 模块类型:{type(module).__name__}")# 打印可训练参数(weight/bias)for param_name, param in module.named_parameters():print(f"  - 参数:{param_name} | 形状:{param.shape} | 设备:{param.device} | 需梯度:{param.requires_grad} | 数据类型:{param.dtype}")# 打印缓冲区(如BatchNorm的running_mean)for buffer_name, buffer in module.named_buffers():print(f"  - 缓冲区: {buffer_name} | 形状: {buffer.shape} | 设备: {buffer.device}")# 模块名称:, 模块类型:ResNet
#  - 参数:conv1.weight | 形状:torch.Size([64, 3, 7, 7]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.weight | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.bias | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32......
#  - 缓冲区: bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.running_var | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.num_batches_tracked | 形状: torch.Size([]) | 设备: cuda:0
#  - 缓冲区: layer1.0.bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0......
# 模块名称:layer1.0, 模块类型:Bottleneck
#  - 参数:conv1.weight | 形状:torch.Size([64, 64, 1, 1]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.weight | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32......
#  - 缓冲区: bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.running_var | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.num_batches_tracked | 形状: torch.Size([]) | 设备: cuda:0......
#模块名称:layer1.0.conv1, 模块类型:Conv2d
#  - 参数:weight | 形状:torch.Size([64, 64, 1, 1]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#模块名称:layer1.0.bn1, 模块类型:BatchNorm2d
......

计算参数量和计算量

过滤原则

在计算模型计算量(FLOPs)时,过滤掉 BatchNorm2d、Sequential 和 Bottleneck 等非关键层是常见的需求

层类型是否过滤原因
BatchNorm2d✅ 过滤计算量极小(仅逐通道缩放),可忽略
Sequential✅ 过滤容器层(实际计算在子层)
Bottleneck✅ 过滤复合层(计算量已包含在子层中)
Conv2d/Linear❌ 保留核心计算层
ReLU/Pooling⚠️ 可选通常忽略(或单独统计)

profile方法

from thop import profilemodel = models.resnet50(weights=None).cuda()  # 不加载预训练权重以减少下载时间
input = torch.randn(1, 3, 224, 224).cuda()
flops, params = profile(model, inputs=(input,))
print(f"FLOPs: {flops / 1e9:.2f} G")  # 输出: ~4.11 GFLOPs
print(f"Params: {params / 1e6:.2f} M")  # 输出: ~25.56 Million

get_model_complexity_info方法

from ptflops import get_model_complexity_infomacs, params = get_model_complexity_info(self.net,(3, 1280, 1280),  # (channels, height, width)as_strings=True,print_per_layer_stat=True,  # 打印每层计算量verbose=True,
)
print(f"FLOPs: {macs}")
print(f"Params: {params}")# Warning: module IntermediateLayerGetter,FPN,SSH,ClassHead,BboxHead,LandmarkHead,RetinaFace,DataParallel is treated as a zero-op.
# DataParallel(
#  426.61 k, 100.000% Params, 4.07 GMac, 99.943% MACs, 
#  (module): RetinaFace(
#    426.61 k, 100.000% Params, 4.07 GMac, 99.943% MACs, 
#    (body): IntermediateLayerGetter(
#      213.07 k, 49.946% Params, 1.45 GMac, 35.733% MACs, 
#      (stage1): Sequential(
#        10.13 k, 2.374% Params, 642.25 MMac, 15.774% MACs, 
#        (0): Sequential(
#          232, 0.054% Params, 98.3 MMac, 2.414% MACs, 
#          (0): Conv2d(216, 0.051% Params, 88.47 MMac, 2.173% MACs, 3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#          (1): BatchNorm2d(16, 0.004% Params, 6.55 MMac, 0.161% MACs, 8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#          (2): LeakyReLU(0, 0.000% Params, 3.28 MMac, 0.080% MACs, negative_slope=0.1, inplace=True)
#        )
#        (1): Sequential()
#  ......
# )
# FLOPs: 4.07 GMac
# Params: 426.61 k

FlopCountAnalysis方法

from fvcore.nn import FlopCountAnalysisflops = FlopCountAnalysis(self.net, image)
flops = flops.unsupported_ops_warnings(False)  # 忽略不支持的操作警告# 计算总 FLOPs
print(flops.by_module())  # 打印每个模块的 FLOPs
total_flops = flops.total()
print(f"Total FLOPs: {total_flops / 1e9:.2f} G")# 打印每一层的 FLOPs,返回字典 {模块名: FLOPs}
print(flops.by_module())# 打印按模块分组的 FLOPs
print(flops.by_module_and_operator())  # 更详细的统计
http://www.lryc.cn/news/2392415.html

相关文章:

  • 榕壹云物品回收系统实战案例:基于ThinkPHP+MySQL+UniApp的二手物品回收小程序开发与优化
  • 《软件工程》第 9 章 - 软件详细设计
  • WebVm:无需安装,一款可以在浏览器运行的 Linux 来了
  • 王树森推荐系统公开课 排序06:粗排模型
  • go并发编程| channel入门
  • PH热榜 | 2025-05-29
  • 详解GPU
  • WPF【11_10】WPF实战-重构与美化(配置Material UI框架)
  • (自用)Java学习-5.16(取消收藏,批量操作,修改密码,用户更新,上传头像)
  • 【Node.js】部署与运维
  • 【Java Web】速通JavaScript
  • TDengine 运维——巡检工具(安装前预配置)
  • C#索引器详解:让对象像数组一样被访问
  • 机器学习课设
  • vue 如何对 div 标签 设置assets内本地背景图片
  • wsl2 docker重启后没了
  • ubuntu 22.04 配置静态IP、网关、DNS
  • RDS PostgreSQL手动删除副本集群副本的步骤
  • MySQL 自增主键重置详解:保持 ID 连续性
  • Vue Hook Store 设计模式最佳实践指南
  • 国产化Word处理控件Spire.Doc教程:通过Java简单快速的将 HTML 转换为 PDF
  • Spring AI 1.0 GA深度解析与最佳实践
  • Java求职面试:从Spring到微服务的技术挑战
  • 鸿蒙OSUniApp 开发的图文混排展示组件#三方框架 #Uniapp
  • WHAT - 学习 WebSocket 实时 Web 开发
  • 5G NTN卫星通信发展现状(截止2025年3月)
  • 【计算机网络】第2章:应用层—DNS
  • [Linux]虚拟地址到物理地址的转化
  • Linux线程入门
  • Kubernetes超详细教程,一篇文章帮助你从零开始学习k8s,从入门到实战