当前位置: 首页 > news >正文

网络剪枝——network-slimming 项目复现

目录

文章目录

  • 目录
  • 网络剪枝——network-slimming 项目复现
    • clone 存储库
    • Baseline
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Sparsity
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Prune
      • vgg
        • 命令
        • 结果
      • resnet
        • 命令
        • 结果
      • densenet
        • 命令
        • 结果
    • Fine-tune
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • 模型大小计算脚本 param_counter.py
    • 结果汇总
      • CIFAR10

网络剪枝——network-slimming 项目复现

  • 【GiHnub】:Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)
  • 【作者复现项目】:
  • 通过百度网盘分享的文件:network-slimming-regin.zip
    链接:https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwd=t1z2
    提取码:t1z2
    • 仅以 CIFAR-10 为例,CIFAR-100 同理.
    • 提供中文README_zh-CN.md.
    • 包含 CIFAR-10/100 数据集data.cifar10data.cifar100.
    • 解决了 main.py 运行报错问题.
    • 加入了计算训练后模型的 Parameters 大小脚本param_counter.py.

clone 存储库

注:若 clone 作者复现项目,则忽略这一步,直接进入下一步;若想自行从头复现,则 clone 以下存储库.

  • 链接:https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwd=ud89
    提取码:ud89


Baseline

vgg

训练
  • 【命令】:
python main.py --dataset cifar10 --arch vgg --depth 19

  • 这个报错通常出现在使用 Python 的multiprocessing库来创建进程时,尤其是在 Windows 操作系统上. 在 Windows 上,Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同,它使用 “spawn” 来启动新进程,这意味着每个子进程都会从头开始执行脚本. 因此,如果在脚本顶层级别启动进程(而不是在受保护的if __name__ == '__main__':块中),每个子进程都会尝试再次启动子进程,从而导致无限递归和上述错误.
  • 为了解决这个问题,应 确保多进程代码(即main.py)位于if __name__ == '__main__':保护块内.
# 导入部分
...def main():...if __name__ == '__main__':main()
  • 再次运行命令,又报错:

  • 这个报错通常发生在尝试直接索引一个0维的张量(tensor)时. 在 PyTorch 中,0 维张量是一个单一值的张量,但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值,需要使用.item()方法.
  • 为了解决这个问题,应该 使用.item()方法来替换所有.data[0]的用法
# 在 train 函数中
if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# 在 test 函数中
for data, target in test_loader:if args.cuda:data, target = data.cuda(), target.cuda()data, target = Variable(data), Variable(target)output = model(data)test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch losspred = output.data.max(1, keepdim=True)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum()test_loss /= len(test_loader.dataset)
  • 再次运行命令就正常运行了:

结果
  • Terminal

  • 在 ./logs 生成文件checkpoint.pth.tarmodel_best.pth.tar

resnet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch densenet --depth 40
结果


Sparsity

vgg

训练
  • 【命令】:
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
结果

resnet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
结果


Prune

vgg

命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes

  • main.py同理,为了解决这个问题,应 确保多进程代码位于if __name__ == '__main__':保护块内
# 导入部分
...def main():...if __name__ == '__main__':main()
  • 之后就可以正常运行了.

结果
  • Terminal

  • 在./prunes生成文件prune.txtpruned.pth.tar

  • prune.txt中我们可以看到 Number of parametersTest accuracy

resnet

命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes
结果

densenet

命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes
结果


Fine-tune

vgg

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160
结果

resnet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160
结果

densenet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160
结果


模型大小计算脚本 param_counter.py

  • 【路径】:./script/param_counter.py
import torchdef load_model(model_path):model = torch.load(model_path, map_location=torch.device('cpu'))return modeldef count_parameters(model_state_dict):total_params = sum(p.numel() for p in model_state_dict.values())return total_paramsdef get_model_parameters(model_path):# 加载模型状态字典model = load_model(model_path)# 模型状态字典存储在 'state_dict' 键下model_state_dict = model['state_dict'] if 'state_dict' in model else model# 计算参数总数total_params = count_parameters(model_state_dict)return total_params
  • main.py中:
from script.param_counter import get_model_parametersdef main():...# 计算 Parametersmodel_path = 'logs/model_best.pth.tar'total_params = get_model_parameters(model_path)print(f'Total parameters in the model: {total_params}')

结果汇总

注:与原项目结果略有差别.

CIFAR10

CIFAR10-VggBaselineSparsity(1e-4)Prune(70%)Fine-tune-160(70%)
Top1 Accuracy(%)93.7293.6033.9893.75
Parameters20.05M20.05M2.22M2.23M
CIFAR10-Resnet-164BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.9995.0094.5995.27
Parameters1.74M1.74M1.46M1.49M
CIFAR10-Densenet-40BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.1594.3794.1494.48
Parameters1.09M1.09M0.70M0.72M
http://www.lryc.cn/news/421739.html

相关文章:

  • Spring 懒加载的实际应用
  • PyQT 串口改动每次点开时更新串口信息
  • 三级_网络技术_19_路由器的配置及使用
  • 【STM32 Blue Pill编程】-STM32CubeIDE开发环境搭建与点亮LED
  • 【数据结构】六、图:4.图的遍历(深度优先算法DFS、广度优先算法BFS)
  • 29、号外!号外!ERA5再分析数据下载方式更新啦
  • 智能识别,2024年SD卡数据恢复软件的智能进化
  • 浙大数据结构慕课课后题(04-树5 Root of AVL Tree)
  • Golang | Leetcode Golang题解之第331题验证二叉树的前序序列化
  • zdppy+vue3+onlyoffice文档管理系统项目实战 20240812上课笔记
  • 怎么将mov视频转换成mp4?将mov视频转换成mp4的方法
  • 大数据技术——实战项目:广告数仓(第五部分)
  • 计算机毕业设计 家电销售展示平台 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试
  • C# 根据MySQL数据库中数据,批量删除OSS上的垃圾文件
  • Vue3+Element-plus+setup使用vuemap/vue-amap实现高德地图API相关操作
  • Windows配置开机直达桌面并跳过锁屏登录界面在 Windows 10 中添加在启动时自动运行的应用
  • pythonUI自动化007::pytest的组成以及运行
  • 开放式耳机哪个品牌好用又实惠?五大口碑精品分享
  • 代码随想录算法训练营day39||动态规划07:多重背包+打家劫舍
  • WebSocket革新:用PHP实现实时Web通信
  • Python教程(十三):常用内置模块详解
  • Linux 下的进程状态
  • 【设计模式】六大基本原则
  • Selenium网页的滚动
  • 图算法系列1: 图算法的分类有哪些?(上)
  • 零样本学习——从多语言语料库数据中对未学习语言进行语音识别的创新技术
  • ViewStub的原理
  • 十一、Spring AOP
  • 【网络】IP的路径选择——路由控制
  • Unity动画模块 之 2D IK(反向动力学)