当前位置: 首页 > news >正文

EMA训练微调

就是取前几个epoch的weight的平均值,可以缓解微调时的灾难性遗忘(因为新数据引导,模型权重逐渐,偏离训练时学到的数据分布,忘记之前学好的先验知识)
在这里插入图片描述

class EMA():def __init__(self, model, decay):self.model = modelself.decay = decay  # decay rateself.shadow = {}  # old weightself.backup = {}  # new weightdef register(self):  # deep copy weight for initfor name, param in self.model.named_parameters():if param.requires_grad:self.shadow[name] = param.data.clone()def update(self):  # ema:average weight for trainfor name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadownew_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]self.shadow[name] = new_average.clone()def apply_shadow(self):  # load old weight for eval beginfor name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadowself.backup[name] = param.dataparam.data = self.shadow[name]def restore(self):  # load new weight for eval endfor name, param in self.model.named_parameters():if param.requires_grad:assert name in self.backupparam.data = self.backup[name]self.backup = {}# 初始化
ema = EMA(model, 0.999)
ema.register()# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()ema.update()# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():ema.apply_shadow()# evaluateema.restore()
http://www.lryc.cn/news/248249.html

相关文章:

  • Kafka集群部署详细教程
  • 交叉编译
  • 数据结构与算法之递归: LeetCode 46. 全排列 (Typescript版)
  • SQL中 JOIN 的两种连接类型:内连接(自然连接、自连接、交叉连接)、外连接(左外连接、右外连接、全外连接)
  • 微信小程序记住密码,让登录解放双手
  • 国内划片机行业四大企业之博捷芯:技术驱动,领跑未来
  • 后端整合Swagger+Knife4j接口文档
  • k8s中批量处理Pod应用的Job和CronJob控制器介绍
  • UE5 范围内随机生成
  • 杂记 | 使用Docker安装并配置MongoDB以支持事务(单副本,并解决了证书文件错误的问题)
  • css三角,鼠标样式,溢出文字
  • 远程桌面访问MATLAB 2018B,提示License Manger Error -103,终极解决方案
  • Jmeter基础和概念
  • 【Linux 带宽限速】trickle,限制docker 上传速度
  • MindStudio学习记录三:推理应用开发 acl mindx sdk
  • 【RT-DETR改进】SIoU、GIoU、CIoU、DIoU、AlphaIoU等二十余种损失函数
  • 【Linux】EVIOCGBIT
  • 鸿蒙4.0开发笔记之ArkTS装饰器语法基础@Extend扩展组件样式与stateStyles多态样式(十一)
  • 5V摄像机镜头驱动IC GC6208,可用于摄像机,机器人等产品中可替代AN41908
  • PHP echo和print 语句
  • ThinkPHP6.1 多应用模式的一些事儿
  • redis-cluster集群模式
  • 带你用uniapp从零开发一个仿小米商场_10. 首页开发
  • 常使用的定时任务
  • 【人工智能Ⅰ】实验2:遗传算法
  • Hadoop集群升级(3.1.3 -> 3.2.4)
  • (一)基于高尔夫优化算法GOA求解无人机三维路径规划研究(MATLAB)
  • ESP32-Web-Server编程-建立第一个网页
  • csgo/steam游戏搬砖项目的五大认知误区
  • ASCII sorting