当前位置: 首页 > news >正文

把组合损失中的权重设置为可学习参数

目前的需求是:有一个模型,准备使用组合损失,其中有2个或者多个损失函数。准备对其进行加权并线性叠加。但想让这些权重进行自我学习,更新迭代成最优加权组合。

目录

1、构建组合损失类

2、调用组合损失类

3、为其构建优化器

4、梯度归零

5、跟新优化器参数

6、结果展示


1、构建组合损失类

每项损失函数可以定义在init里面,这样的话就只需要模型的输出和训练目标。我这里没有这样设置,选择把每项损失值传过来进行线性加权叠加。

# 定义组合损失函数---------------------------------------START
class CombinedLoss(nn.Module):def __init__(self):super(CombinedLoss, self).__init__()# 定义损失函数权重作为可训练参数self.w_adv = nn.Parameter(torch.ones(1, requires_grad=True))  # 对抗损失的权重,初始值为0.2 self.w_con = nn.Parameter(torch.ones(1, requires_grad=True))  # 内容感知损失的权重,初始值为0.2self.w_mse = nn.Parameter(torch.ones(1, requires_grad=True))  # 均方误差损失的权重,初始值为0.2self.w_s3im = nn.Parameter(torch.ones(1, requires_grad=True))  # 随机结构相似性损失的权重,初始值为0.2self.w_gui = nn.Parameter(torch.ones(1, requires_grad=True))  # 边缘引导损失的权重,初始值为0.2def forward(self, loss_adv, loss_con, loss_mse, loss_s3im, loss_gui):return self.w_adv*loss_adv + self.w_con*loss_con + self.w_mse*loss_mse + self.w_s3im*loss_s3im + self.w_gui*loss_gui

2、调用组合损失类

在计算组合损失之前,需要初始化类对象。

combinedloss = Loss.CombinedLoss()unet_loss = self.combinedloss(loss_adv = unet_gan_loss, loss_con = gen_content_loss, loss_mse = unet_criterion, loss_s3im = s3im_loss, loss_gui = guid_loss)

3、为其构建优化器

最好单独构建优化器,这样我们可以设置与总损失不用的学习率。避免学习率过大导致梯度消失。

self.lr_weight_optimizer = optim.Adam(self.combinedloss.parameters(),lr = 1e-4,betas=(0.9, 0.999))

4、梯度归零

在每次计算总损失之前,需要把每个优化器的梯度归零

self.lr_weight_optimizer.zero_grad()

5、跟新优化器参数

在总损失反向传播之后,需要对优化器的参数进行更新

self.lr_weight_optimizer.step()

6、结果展示

每个权重都会自动更新。 

http://www.lryc.cn/news/328155.html

相关文章:

  • 用Bat启动jar程序
  • 网站维护页404源码
  • jmeter链路压测
  • 香港服务器怎么看是CN2 GT线路还是CN2 GIA线路?
  • CrossOver软件2024免费 最新版本详细介绍 CrossOver软件好用吗 Mac电脑玩Windows游戏
  • harbor api v2.0
  • Vue 表单数据双向绑定 v-mode
  • tab切换组件,可横向自适应滑动
  • 设计模式---单例模式
  • HarmonyOS 应用开发之启动/停止本地PageAbility
  • BaseDao封装增删改查
  • Redis入门到实战-第十三弹
  • 深度学习InputStreamReader类
  • 2023年后端面试总结
  • axios实现前后端通信报错Unsupported Media
  • 网络套接字补充——TCP网络编程
  • Nginx-记
  • JS面试题:call,apply,bind区别
  • Charles抓包配置代理手机连接
  • NA555、NE555、SA555和SE555系列精密定时器
  • 黑马鸿蒙笔记2
  • 微信小程序uniapp+vue3+ts+pinia的环境搭建
  • MongoDB聚合运算符:$let
  • HarmonyOS像素转换-如何使用像素单位设置组件的尺寸。
  • 【前端面试3+1】05v-if和v-show的区别、v-if和v-for能同时使用吗、Vuex是什么?【合并两个有序链表】
  • Unity WebRequest 变得简单
  • vue 窗口内容滚动到底部
  • 代码随想录算法训练营Day38|LC509 斐波那契数列LC70 爬楼梯LC746 使用最小花费爬楼梯
  • Qt5.14.2 大神的拖放艺术,优雅而强大的交互体验
  • python3将exe 转支持库错误 AssertionError: None does not smell like code