当前位置: 首页 > news >正文

PyTorch中加载模型权重 A匹配B|A不匹配B

在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况:
在这里插入图片描述

未修改网络,A与B一致

很简单,直接.load_state_dict()

net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))

修改了网络,A与B不一致

[pytorch官方文档](Search — PyTorch master documentation):

load_state_dict(state_dict, strict=True)

将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。

state_dict是包含参数和持久缓冲区的字典,可以看出 strict默认为True,所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的

load_state_dict()函数有两个返回值:

missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表

方法一:

将strict改为false,加载键值相同的部分。

model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

但是此时还存在一种情况:键值相同但shape不同,故应进行if…in…的判断:

ANet = torch.load('ANet.pt')  # 加载预训练权重模型(.pt文件)参数
#现成的模型的话,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()  
model = Model() # 创建模型
model_dict = model.state_dict() # 得到模型的参数字典# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)

方法二:

1.将权重导入原模型,之后在加载后的原模型基础上进行修改。
2.修改权重文件参数,再进行导入
适用于改动不大的模型

http://www.lryc.cn/news/110940.html

相关文章:

  • @FeignClient指定多个url实现负载均衡
  • vue diff 双端比较算法
  • 初识React: 基础(概念 特点 高效原因 虚拟DOM JSX语法 组件)
  • 自监督去噪:Neighbor2Neighbor原理分析与总结
  • 简单工厂模式(Simple Factory)
  • Agent:OpenAI的下一步,亚马逊云科技站在第5层
  • JMeter 4.x 简单使用
  • 深入NLTK:Python自然语言处理库高级教程
  • React 用来解析html 标签的方法
  • 【C++】做一个飞机空战小游戏(五)——getch()控制两个飞机图标移动(控制光标位置)
  • Flask 是什么?Flask框架详解及实践指南
  • C. Mark and His Unfinished Essay - 思维
  • Java的变量与常量
  • C# Blazor 学习笔记(6):热重置问题解决
  • 一百四十六、Xmanager——Xmanager5连接Xshell7并控制服务器桌面
  • 用Rust实现23种设计模式之 模板方法模式
  • python与深度学习(十三):CNN和IKUN模型
  • 题目:2283.判断一个数的数字计数是否等于数位的值
  • 任务14、无缝衔接,MidJourney瓷砖(Tile)参数制作精良贴图
  • 【uniapp APP如何优化】
  • uni-app——下拉框多选
  • 从excel中提取嵌入式图片的解决方法
  • python socket 网络编程的基本功
  • 【element-ui】form表单初始化页面如何取消自动校验rules
  • git 公钥密钥 生成与查看
  • 数据标注对新零售的意义及人工智能在新零售领域的应用?
  • 命令模式-请求发送者与接收者解耦
  • 【雕爷学编程】Arduino动手做(186)---WeMos ESP32开发板
  • 3、JSON数据的处理
  • 8月5日上课内容 nginx的优化和防盗链