
YOLOv8 Pruning in Practice

This article uses the torch-pruning library and experiments with three of its pruning algorithms: GroupNormPruner, BNScalePruner, and GrowingRegPruner.
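The three pruners share the same construction pattern in torch-pruning and differ mainly in the importance criterion (GrowingRegPruner additionally applies a sparsity regularizer that grows during training). The sketch below is only a minimal illustration of how to swap between them; the ResNet-18 model, input shape, and sparsity value are placeholders, and keyword arguments may vary slightly across torch-pruning versions.

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(weights=None)                 # placeholder network; the script below prunes model.model of a YOLO object
example_inputs = torch.randn(1, 3, 224, 224)   # placeholder input; YOLOv8 uses (1, 3, imgsz, imgsz)
ignored_layers = [model.fc]                    # layers that must not be pruned (the Detect head in YOLOv8)

# 1) GroupNormPruner with group L2-norm importance (the variant used in the script below)
pruner = tp.pruner.GroupNormPruner(
    model, example_inputs,
    importance=tp.importance.GroupNormImportance(),
    iterative_steps=1, ch_sparsity=0.5, ignored_layers=ignored_layers)

# 2) BNScalePruner ranks channels by BatchNorm scale (gamma) factors
# pruner = tp.pruner.BNScalePruner(
#     model, example_inputs,
#     importance=tp.importance.BNScaleImportance(),
#     iterative_steps=1, ch_sparsity=0.5, ignored_layers=ignored_layers)

# 3) GrowingRegPruner uses the same group importance plus a growing sparsity regularizer
# pruner = tp.pruner.GrowingRegPruner(
#     model, example_inputs,
#     importance=tp.importance.GroupNormImportance(),
#     iterative_steps=1, ch_sparsity=0.5, ignored_layers=ignored_layers)

pruner.step()  # physically removes the lowest-importance channels

Switching algorithms in the full script therefore only requires replacing the pruner construction inside the pruning loop.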

Installation and Usage

  1. Install the dependency:
pip install torch-pruning 
  2. Copy https://github.com/VainF/Torch-Pruning/blob/master/examples/yolov8/yolov8_pruning.py into the root directory of the YOLOv8 repository, or use my pruning script below, which slightly modifies the original so that the model is saved at each pruning stage.
# This code is adapted from Issue [#147](https://github.com/VainF/Torch-Pruning/issues/147), implemented by @Hyunseok-Kim0.
import argparse
import math
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from ultralytics import YOLO, __version__
from ultralytics.nn.modules import Detect, C2f, Conv, Bottleneck
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.engine.model import TASK_MAP
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import yaml_load, LOGGER, RANK, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import initialize_weights, de_parallel

import torch_pruning as tp


def save_pruning_performance_graph(x, y1, y2, y3):
    """
    Draw performance change graph
    Parameters
    ----------
    x : List
        Parameter numbers of all pruning steps
    y1 : List
        mAPs after fine-tuning of all pruning steps
    y2 : List
        MACs of all pruning steps
    y3 : List
        mAPs after pruning (not fine-tuned) of all pruning steps
    Returns
    -------
    """
    try:
        plt.style.use("ggplot")
    except:
        pass

    x, y1, y2, y3 = np.array(x), np.array(y1), np.array(y2), np.array(y3)
    y2_ratio = y2 / y2[0]

    # create the figure and the axis object
    fig, ax = plt.subplots(figsize=(8, 6))

    # plot the pruned mAP and recovered mAP
    ax.set_xlabel('Pruning Ratio')
    ax.set_ylabel('mAP')
    ax.plot(x, y1, label='recovered mAP')
    ax.scatter(x, y1)
    ax.plot(x, y3, color='tab:gray', label='pruned mAP')
    ax.scatter(x, y3, color='tab:gray')

    # create a second axis that shares the same x-axis
    ax2 = ax.twinx()

    # plot the second set of data
    ax2.set_ylabel('MACs')
    ax2.plot(x, y2_ratio, color='tab:orange', label='MACs')
    ax2.scatter(x, y2_ratio, color='tab:orange')

    # add a legend
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    ax.set_xlim(105, -5)
    ax.set_ylim(0, max(y1) + 0.05)
    ax2.set_ylim(0.05, 1.05)

    # calculate the highest and lowest points for each set of data
    max_y1_idx = np.argmax(y1)
    min_y1_idx = np.argmin(y1)
    max_y2_idx = np.argmax(y2)
    min_y2_idx = np.argmin(y2)
    max_y1 = y1[max_y1_idx]
    min_y1 = y1[min_y1_idx]
    max_y2 = y2_ratio[max_y2_idx]
    min_y2 = y2_ratio[min_y2_idx]

    # add text for the highest and lowest values near the points
    ax.text(x[max_y1_idx], max_y1 - 0.05, f'max mAP = {max_y1:.2f}', fontsize=10)
    ax.text(x[min_y1_idx], min_y1 + 0.02, f'min mAP = {min_y1:.2f}', fontsize=10)
    ax2.text(x[max_y2_idx], max_y2 - 0.05, f'max MACs = {max_y2 * y2[0] / 1e9:.2f}G', fontsize=10)
    ax2.text(x[min_y2_idx], min_y2 + 0.02, f'min MACs = {min_y2 * y2[0] / 1e9:.2f}G', fontsize=10)

    plt.title('Comparison of mAP and MACs with Pruning Ratio')
    plt.savefig('pruning_perf_change.png')


def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


def transfer_weights(c2f, c2f_v2):
    c2f_v2.cv2 = c2f.cv2
    c2f_v2.m = c2f.m

    state_dict = c2f.state_dict()
    state_dict_v2 = c2f_v2.state_dict()

    # Transfer cv1 weights from C2f to cv0 and cv1 in C2f_v2
    old_weight = state_dict['cv1.conv.weight']
    half_channels = old_weight.shape[0] // 2
    state_dict_v2['cv0.conv.weight'] = old_weight[:half_channels]
    state_dict_v2['cv1.conv.weight'] = old_weight[half_channels:]

    # Transfer cv1 batchnorm weights and buffers from C2f to cv0 and cv1 in C2f_v2
    for bn_key in ['weight', 'bias', 'running_mean', 'running_var']:
        old_bn = state_dict[f'cv1.bn.{bn_key}']
        state_dict_v2[f'cv0.bn.{bn_key}'] = old_bn[:half_channels]
        state_dict_v2[f'cv1.bn.{bn_key}'] = old_bn[half_channels:]

    # Transfer remaining weights and buffers
    for key in state_dict:
        if not key.startswith('cv1.'):
            state_dict_v2[key] = state_dict[key]

    # Transfer all non-method attributes
    for attr_name in dir(c2f):
        attr_value = getattr(c2f, attr_name)
        if not callable(attr_value) and '_' not in attr_name:
            setattr(c2f_v2, attr_name, attr_value)

    c2f_v2.load_state_dict(state_dict_v2)


def replace_c2f_with_c2f_v2(module):
    for name, child_module in module.named_children():
        if isinstance(child_module, C2f):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels, child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)
        else:
            replace_c2f_with_c2f_v2(child_module)


def save_model_v2(self: BaseTrainer):
    """Disabled half precision saving. originated from ultralytics/yolo/engine/trainer.py"""
    ckpt = {
        'epoch': self.epoch,
        'best_fitness': self.best_fitness,
        'model': deepcopy(de_parallel(self.model)),
        'ema': deepcopy(self.ema.ema),
        'updates': self.ema.updates,
        'optimizer': self.optimizer.state_dict(),
        'train_args': vars(self.args),  # save as dict
        'date': datetime.now().isoformat(),
        'version': __version__}

    # Save last, best and delete
    torch.save(ckpt, self.last)
    if self.best_fitness == self.fitness:
        torch.save(ckpt, self.best)
    if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
        torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
    del ckpt


def final_eval_v2(self: BaseTrainer):
    """originated from ultralytics/yolo/engine/trainer.py"""
    for f in self.last, self.best:
        if f.exists():
            strip_optimizer_v2(f)  # strip optimizers
            if f is self.best:
                LOGGER.info(f'\nValidating {f}...')
                self.metrics = self.validator(model=f)
                self.metrics.pop('fitness', None)
                self.run_callbacks('on_fit_epoch_end')


def strip_optimizer_v2(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
    """Disabled half precision saving. originated from ultralytics/yolo/utils/torch_utils.py"""
    x = torch.load(f, map_location=torch.device('cpu'))
    args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'ema', 'updates':  # keys
        x[k] = None
    for p in x['model'].parameters():
        p.requires_grad = False
    x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def train_v2(self: YOLO, pruning=False, **kwargs):
    """Disabled loading new model when pruning flag is set. originated from ultralytics/yolo/engine/model.py"""
    self._check_is_pytorch_model()
    if self.session:  # Ultralytics HUB session
        if any(kwargs):
            LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
        kwargs = self.session.train_args
    overrides = self.overrides.copy()
    overrides.update(kwargs)
    if kwargs.get('cfg'):
        LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
        overrides = yaml_load(check_yaml(kwargs['cfg']))
    overrides['mode'] = 'train'
    if not overrides.get('data'):
        raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
    if overrides.get('resume'):
        overrides['resume'] = self.ckpt_path

    self.task = overrides.get('task') or self.task
    self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)

    if not pruning:
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
    else:
        # pruning mode
        self.trainer.pruning = True
        self.trainer.model = self.model

        # replace some functions to disable half precision saving
        self.trainer.save_model = save_model_v2.__get__(self.trainer)
        self.trainer.final_eval = final_eval_v2.__get__(self.trainer)

    self.trainer.hub_session = self.session  # attach optional HUB session
    self.trainer.train()
    # Update model and cfg after training
    if RANK in (-1, 0):
        self.model, _ = attempt_load_one_weight(str(self.trainer.best))
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, 'metrics', None)


def prune(args):
    # load trained yolov8 model
    base_name = 'prune/' + str(datetime.now()) + '/'
    model = YOLO(args.model)
    model.__setattr__("train_v2", train_v2.__get__(model))
    pruning_cfg = yaml_load(check_yaml(args.cfg))
    batch_size = pruning_cfg['batch']

    # use coco128 dataset for 10 epochs fine-tuning each pruning iteration step
    # this part is only for sample code, number of epochs should be included in config file
    pruning_cfg['data'] = "./ultralytics/datasets/soccer.yaml"
    pruning_cfg['epochs'] = 4

    model.model.train()
    replace_c2f_with_c2f_v2(model.model)
    initialize_weights(model.model)  # set BN.eps, momentum, ReLU.inplace

    for name, param in model.model.named_parameters():
        param.requires_grad = True

    example_inputs = torch.randn(1, 3, pruning_cfg["imgsz"], pruning_cfg["imgsz"]).to(model.device)
    macs_list, nparams_list, map_list, pruned_map_list = [], [], [], []
    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)

    # do validation before pruning model
    pruning_cfg['name'] = base_name + f"baseline_val"
    pruning_cfg['batch'] = 128
    validation_model = deepcopy(model)
    metric = validation_model.val(**pruning_cfg)
    init_map = metric.box.map
    macs_list.append(base_macs)
    nparams_list.append(100)
    map_list.append(init_map)
    pruned_map_list.append(init_map)
    print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M, mAP={init_map: .5f}")

    # prune same ratio of filter based on initial size
    ch_sparsity = 1 - math.pow((1 - args.target_prune_rate), 1 / args.iterative_steps)

    for i in range(args.iterative_steps):

        model.model.train()
        for name, param in model.model.named_parameters():
            param.requires_grad = True

        ignored_layers = []
        unwrapped_parameters = []
        for m in model.model.modules():
            if isinstance(m, (Detect,)):
                ignored_layers.append(m)

        example_inputs = example_inputs.to(model.device)
        pruner = tp.pruner.GroupNormPruner(
            model.model,
            example_inputs,
            importance=tp.importance.GroupNormImportance(),  # L2 norm pruning,
            iterative_steps=1,
            ch_sparsity=ch_sparsity,
            ignored_layers=ignored_layers,
            unwrapped_parameters=unwrapped_parameters)

        # Test regularization
        # output = model.model(example_inputs)
        # (output[0].sum() + sum([o.sum() for o in output[1]])).backward()
        # pruner.regularize(model.model)

        pruner.step()

        # pre fine-tuning validation
        pruning_cfg['name'] = base_name + f"step_{i}_pre_val"
        pruning_cfg['batch'] = 128
        validation_model.model = deepcopy(model.model)
        metric = validation_model.val(**pruning_cfg)
        pruned_map = metric.box.map
        pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs.to(model.device))
        current_speed_up = float(macs_list[0]) / pruned_macs
        print(f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
              f"mAP={pruned_map}, speed up={current_speed_up}")

        # fine-tuning
        for name, param in model.model.named_parameters():
            param.requires_grad = True
        pruning_cfg['name'] = base_name + f"step_{i}_finetune"
        pruning_cfg['batch'] = batch_size  # restore batch size
        model.train_v2(pruning=True, **pruning_cfg)

        # post fine-tuning validation
        pruning_cfg['name'] = base_name + f"step_{i}_post_val"
        pruning_cfg['batch'] = 128
        validation_model = YOLO(model.trainer.best)
        metric = validation_model.val(**pruning_cfg)
        current_map = metric.box.map
        print(f"After fine tuning mAP={current_map}")

        macs_list.append(pruned_macs)
        nparams_list.append(pruned_nparams / base_nparams * 100)
        pruned_map_list.append(pruned_map)
        map_list.append(current_map)

        # remove pruner after single iteration
        del pruner

        model.model.zero_grad()  # Remove gradients
        save_path = 'runs/detect/' + base_name + f"step_{i}_pruned_model.pth"
        torch.save(model.model, save_path)  # without .state_dict
        print('pruned model saved in', save_path)
        # model = torch.load('model.pth') # load the pruned model

        save_pruning_performance_graph(nparams_list, map_list, macs_list, pruned_map_list)

        # if init_map - current_map > args.max_map_drop:
        #     print("Pruning early stop")
        #     break

    # model.export(format='onnx')


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='runs/detect/train/weights/last.pt', help='Pretrained pruning target model file')
    parser.add_argument('--cfg', default='default.yaml',
                        help='Pruning config file.'
                             ' This file should have same format with ultralytics/yolo/cfg/default.yaml')
    parser.add_argument('--iterative-steps', default=4, type=int, help='Total pruning iteration step')
    parser.add_argument('--target-prune-rate', default=0.2, type=float, help='Target pruning rate')
    parser.add_argument('--max-map-drop', default=1, type=float, help='Allowed maximum map drop after fine-tuning')

    args = parser.parse_args()

    prune(args)
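With the argparse defaults above, the script can be launched from the YOLOv8 root directory roughly as follows (the weight and config paths shown are just the script's defaults; point them at your own files):

python yolov8_pruning.py --model runs/detect/train/weights/last.pt --cfg default.yaml --iterative-steps 4 --target-prune-rate 0.2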
  3. Add some guards at the relevant points in the code, otherwise the script validates the model very frequently; one possible approach is sketched below.
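The original post does not show the exact change. As a sketch, assuming the frequent validations come from the step_{i}_pre_val and step_{i}_post_val blocks inside the pruning loop of the script above, a hypothetical --validate-every flag could gate them:

# hypothetical flag, not part of the original script
parser.add_argument('--validate-every', default=2, type=int,
                    help='Run pre-/post-fine-tuning validation only every N pruning steps')

# inside the for-loop of prune(), wrap the pre-fine-tuning validation block:
run_val = (i % args.validate_every == 0) or (i == args.iterative_steps - 1)
if run_val:
    pruning_cfg['name'] = base_name + f"step_{i}_pre_val"
    pruning_cfg['batch'] = 128
    validation_model.model = deepcopy(model.model)
    metric = validation_model.val(**pruning_cfg)
    pruned_map = metric.box.map
else:
    pruned_map = float('nan')  # skip this step's validation

# the step_{i}_post_val block can be gated with the same run_val condition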

Experimental results: experiments are still in progress, to be continued.
