当前位置: 首页 > news >正文

yolov8蒸馏(附代码-免费)

首先蒸馏是什么?

模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。

为什么需要蒸馏?

  1. 在不增加模型计算量和参数量的情况下提升精度,也即是可以无损提高精度。
  2. 配合剪枝一起使用,可以尽量达到无损降低模型参数量、计算量,提高FPS的情况下,还能保持模型精度没有下降甚至上升,这是改进网络结构无法达到的高度。
  3. 论文中的保底手段,因为剪枝和蒸馏的特殊性,其都不会增加参数量和计算量,可以在最后一个点上大幅度增加实验和工作量,因为本身蒸馏也需要做大量实验。

目录

一.代码前提

(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s

(2)本文使用的蒸馏方法包括mgd,cwd

(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。

二.蒸馏步骤

(1) 训练教师模型

(2) 训练学生模型

(3) 蒸馏训练

三.模型剪枝+蒸馏

(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝

(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。

(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:


一.代码前提

(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s

(2)本文使用的蒸馏方法包括mgd,cwd

(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。

本文代码已经上传到GitHub,链接:yolov8_蒸馏

使用不妨加个关注,后续还会加入Vit(vision transformer),替换loss等提升精度的方法。

二.蒸馏步骤

(1) 训练教师模型

打开文件中train.py,替换模型文件位置。开始训练,达到理想目标就停止。

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model = YOLO("yolov8s.pt")model.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

(2) 训练学生模型

打开文件中train.py,替换模型文件位置。我这边使用的是剪枝后的yolov8s模型,具体轻量化剪枝步骤可见本文最后。

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

(3) 蒸馏训练

打开文件中train_distillation.py,替换老师与学生模型文件位置。两种蒸馏方法可以选择:cwd和mgd。

import os
from ultralytics import YOLO
import torchos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_t = YOLO('runs/detect/yolov8s/weights/best.pt')  # the teacher modelmodel_s = YOLO('runs/detect/prune/weights/best.pt')  # the student model"""Attributes:Distillation: the distillation modelloss_type: mgd, cwdamp: Automatic Mixed Precision"""model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd', amp=False, imgsz=640, epochs=100,batch=20, device=0, workers=0, lr0=0.001)if __name__ == '__main__':main()

现在先不进行训练,打开文件夹yolo_project_distillation\ultralytics\engine\trainer.py

在类FeatureLoss中,函数forward大概162行处打一个断点,进行调试。代码位置:

    def forward(self, y_s, y_t):assert len(y_s) == len(y_t)tea_feats = []stu_feats = []for idx, (s, t) in enumerate(zip(y_s, y_t)):# change ---if self.distiller == 'cwd':s = self.align_module[idx](s)s = self.norm[idx](s)else:s = self.norm1[idx](s)t = self.norm[idx](t)tea_feats.append(t)stu_feats.append(s)loss = self.feature_loss(stu_feats, tea_feats)return self.loss_weight * loss

调试运行,查看变量中学生模型y_s和老师模型y_t的张量大小。把通道数记下来,写在类Distillation_loss的

        channels_s = [256, 480, 256, 64, 143, 229][-le:]channels_t = [256, 512, 256, 128, 256, 512][-le:]

这边总共有六个,刚好对应模型的六个层的通道数。

替换完成后,应该就可以进行训练了。训练不好的话,再来评论区找我吧。

三.模型剪枝+蒸馏

(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝

(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy# Load a model
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Save model address
res_dir = "./runs/detect/prune/weights/prune.pt"
# Pruning rate
factor = 0.75yolo.info()
model = yolo.model
ws = []
bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)# print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())# keepws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)# print(n / len(gamma) * 100)# scale = len(idxs) / nconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(m1, m2):if isinstance(m1, C2f):  # C2f as a top convm1 = m1.cv2if not isinstance(m2, list):  # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i + 1])detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])for name, p in yolo.model.named_parameters():p.requires_grad = True#yolo.val(workers=0)  # 剪枝模型进行验证 yolo.val(workers=0)
yolo.info()
# yolo.export(format="onnx")  # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # 剪枝后直接训练微调
ckpt = {'epoch': -1,'best_fitness': None,'model': yolo.ckpt['ema'],'ema': None,'updates': None,'optimizer': None,'train_args': yolo.ckpt["train_args"],  # save as dict'date': None,'version': '8.0.142'}torch.save(yolo.ckpt, res_dir)

(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():# model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

------------------------------------------over!!!!!!!!!!!!!!!!!------------------------------

http://www.lryc.cn/news/414332.html

相关文章:

  • Flink-StarRocks详解:第五部分查询数据湖(第55天)
  • 【MySQL】常用数据类型
  • 创建第一个rust tauri项目
  • 【课程总结】day19(中):Transformer架构及注意力机制了解
  • 4.4 标准正交基和格拉姆-施密特正交化
  • spring事务的8种失效的场景,7种传播行为
  • 进程的虚拟内存地址(C++程序的内存分区)
  • 英特尔移除超线程与AMD多线程性能对比
  • 定期自动巡检,及时发现机房运维管理中的潜在问题
  • 八股文(一)
  • 灵茶八题 - 子数组 ^w^
  • git clone private repo
  • vue3+ts+pinia+vant-项目搭建
  • 自动化测试概念篇
  • Mojo值的生命周期(Life of a value)详解
  • java对接kimi详细说明,附完整项目
  • 鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频
  • django集成pytest进行自动化单元测试实战
  • 48天笔试训练错题——day40
  • LabVIEW在DCS中的优势
  • 英特尔:从硅谷创业到全球科技巨头
  • 生物计算与纳米技术:交汇前沿的科学领域
  • C#中栈和队列
  • 技战法丨攻防演练防御——纵深、联动、诱捕(可搬运、可cv)
  • 1、 window平台opencv下载编译, 基于cmake和QT工具链
  • C++20三向比较运算符详解
  • 监听机制与耗电量
  • C++ //练习 16.29 修改你的Blob类,用你自己的shared_ptr代替标准库中的版本。
  • 【Mode Management】CanNm处于PBS状态下接收到一帧诊断报文DCM会响应吗
  • 【C++】模版:范式编程、函数模板、类模板