当前位置: 首页 > news >正文

【知识蒸馏】deeplabv3 logit-based 知识蒸馏实战,对剪枝的模型进行蒸馏训练

本文将对【模型剪枝】基于DepGraph(依赖图)完成复杂模型的一键剪枝 文章中剪枝的模型进行蒸馏训练

一、逻辑蒸馏步骤

  • 加载教师模型
  • 定义蒸馏loss
  • 计算蒸馏loss
  • 正常训练

二、代码

1、加载教师模型

教师模型使用未进行剪枝,并且已经训练好的原始模型。

teacher_model = torch.load('./logs/before_prune.pth', map_location=device)

2、定义蒸馏loss

分割和分类的loss,都是用的softmax。

import torch.nn.functional as F
import torch.nn as nn
# 蒸馏温度
Tempature = 2
def KD_loss(teacher_pred, student_pred):t_p = F.softmax(teacher_pred / Tempature, dim=1)s_p = F.log_softmax(student_pred / Tempature, dim=1)return nn.KLDivLoss(reduction='mean')(s_p, t_p) * (Tempature ** 2)

3、 计算蒸馏loss

teacher_outputs = t_model(imgs)
# 蒸馏loss
soft_loss = KD_loss(teacher_outputs, outputs)
# 总loss = 蒸馏loss*alpha + 原学生模型loss*(1-alpha)
alpha = 0.9
all_loss = loss * (1 - alpha) + soft_loss * alpha

4、正常训练

all_loss.backward()

用剪枝前训练好的模型对剪枝后模型进行蒸馏训练,训练后测试效果如下:
在这里插入图片描述

http://www.lryc.cn/news/352792.html

相关文章:

  • 02.爬虫---HTTP基本原理
  • HTTP响应的基本概念
  • 链栈的存储
  • 常见网络协议及端口号
  • 几张自己绘制的UML图
  • [读论文]精读Self-Attentive Sequential Recommendation
  • HTML静态网页成品作业(HTML+CSS)——动漫海绵宝宝介绍网页(5个页面)
  • 开放式耳机2024超值推荐!教你如何选择蓝牙耳机!
  • 程序员搞副业的障碍有那些?
  • windows7的ie11降级到ie8
  • 楼房vr安全逃生模拟体验让你在虚拟环境中亲身体验火灾的紧迫与危险
  • rust 学习--所有权
  • 关于Git 的基本概念和使用方式
  • 《计算机网络微课堂》1-6 计算机体系结构
  • 大模型的灵魂解读:Anthropic AI的Claude3 Sonnet可解释性研究
  • 大模型框架:vLLM
  • SQL 使用心得【持续更新】
  • 基于Spring Boot的高校图书馆管理系统
  • python(4) : pip安装使用国内源
  • 让写书人勇敢穿越纸海的迷雾
  • ROS2学习——节点话题通信(2)
  • 【Spring Boot】深度复盘在开发搜索引擎项目中重难点的整理,以及遇到的困难和总结
  • 配置docker阿里云镜像地址
  • ICML 2024 Mamba 论文总结
  • Sass详解
  • 如何实现一个高效的排序算法?
  • Linux--10---安装JDK、MySQL
  • 【大数据】MapReduce JAVA API编程实践及适用场景介绍
  • 图像分类和文本分类(传统机器学习和深度学习)
  • 基于SpringBoot和Hutool工具包实现的验证码案例