当前位置: 首页 > news >正文

Torch同时训练多个模型

20230302

引言

在进行具体的研究时,利用Torch进行编程,考虑到是不是能够同时训练两个模型呢?!而且利用其中一个模型的输出来辅助另外一个模型进行学习。这一点,在我看来应该是很简单的,例如GAN网络同时训练这个生成器和判别器。但是实际操作中,却发现一直报错。

之前的时候利用Keras进行AAE(对抗自编码器)的编程的时候,他是把其中一个模型的参数trainable(应该是这个名字)定义为了false

分析

在帖子[1]中,基本上完整的说明了我的问题,首先是实际往后推梯度直接报错,如下图。然后提议把这个retain_graph设置好;

在这里插入图片描述

设置了之后呢,依然是会报错:
在这里插入图片描述

这个报错过程,跟我写的程序是一模一样的。另外一个帖子[2],两者给出的解答方式都是添加detach()。实际上,我理解哈,(之前最开始的时候看过计算图的相关内容,后来有点忘了),就是在第一个损失函数推完之后,这部分他的梯度已经没有了,那么再使用第一个模型中的输出变量与第二个模型进行计算的时候,这部分也会输出一部分梯度到这个第一个模型上,但是本质上,你已经不需要在进行计算了,而这个梯度可能还会遗留到后续,所以会出现这种报错。(通俗理解,可能内部细节更多)

而添加detach()之后,就是为了吧这个变量从计算图中取出来,但是不用计算梯度,见文章[3]。所以可以解决这个问题。如果这样话,其实retain_graph变量可以依然是false。具体可以看AAE这部分的代码

在这里插入图片描述

这部分核心在于最后部分计算的时候,encoded_img已经用过了,而且梯度也推完了,那么后面再次使用的时候,就需要加上detach()

参考

[1]How to train Two models simultaneously?
[2]Training multiple models at the same time
[3]pytorch .detach() .detach_() 和 .data用于切断反向传播
[4]PyTorch-GAN/implementations/aae/aae.py

http://www.lryc.cn/news/25923.html

相关文章:

  • LCR数字电桥软件下载安装教程
  • C++模板写法详解
  • 【备战面试】每日10道面试题打卡-Day2
  • “数字档案室测评”相关参考依据梳理
  • android 动态加载jar包
  • JAVA版B2B2C商城源码多商户入驻商城
  • 测试人员如何在测试环境数据库批量生成测试数据?方案分享
  • 【el】表单
  • 【Flutter入门到进阶】Flutter基础篇---布局
  • python海龟绘图
  • 【计算机网络】数据链路层
  • 使用groovy代码方式解开gradle配置文件神秘面纱
  • kafka入门到实战二(使用docker搭建kafka集群)
  • 【简化开发】lombok的使用、编译后的代码及源码
  • 在线就能用的主图设计素材,免费分享!
  • 【测绘程序设计】——计算卫星位置
  • 山东双软认证的基本条件
  • TPM 2.0实例探索3 —— LUKS磁盘加密(4)
  • Linux连接RDP远程服务工具集记录
  • 离散事件动态系统
  • 无线WiFi安全渗透与攻防(二)之打造专属字典
  • 拥抱 Spring 全新 OAuth 解决方案
  • 前端开发与vscode开发工具介绍
  • C++---最长上升子序列模型---友好城市(每日一道算法2023.3.2)
  • maven高级知识。
  • Python 之 Pandas 处理字符串和apply() 函数、applymap() 函数、map() 函数详解
  • 汇川AM402和上位机C#ModebusTcp通讯
  • 给你一个电商网站,你如何测试?功能测试及接口测试思路是什么?
  • Spring Boot 3.0系列【5】基础篇之应用配置文件
  • SQLyog图形化界面工具【超详细讲解】