当前位置: 首页 > article >正文

解决size mismatch for embedding.embed_dict.userid.weight

文章目录

  • 一、问题描述
  • 二、解决方法
  • 三、其他问题
  • Reference

一、问题描述

导入之前训练好的模型权重后使用模型预测时如题报错size mismatch for embedding.embed_dict.userid.weight

state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

二、解决方法

是因为导入的模型权重(之前训练好、保存的)的维度和当前定义的model的权重维度不同,所以我选择修改下当前定义的model,即将自己返回如下beat_sparse_features等的dataloader,其读取的数据换成之前模型训练的数据,使得模型定义后的model的模型权重和导入的权重一致。

model = DeepFM(deep_features=beat_dense_features + beat_sparse_features,fm_features=beat_sparse_features,mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"},
)

当然如果根据大家的实际情况改动,如很多时候实例化模型时改变实参即可。

三、其他问题

可能还有其他情况也会报这个错,如导入预训练模型进行微调,首先加载预训练模型权重:

model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
model.load_state_dict(pretrained_dict)
model.fc = torch.nn.Linear(512, 5) # 512为原始fc的数目,5是自己任务的分类数

由于分类类别不一致,报错size mismatch for fc.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([x]).,这里可以选择不加载fc层:

model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)} # 将'fc'这一层的权重选择不加载即可。
model_dict.update(pretrained_dict) # 更新权重
model.load_state_dict(model_dict)

可能还有其他情况,如NLP词表维度不一致等等,后面遇到再更新该帖。如有不对之处,恳请大佬们指正!

Reference

[1] 解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias
[2] torch 封装文本数据预处理、训练、评估、预测过程
[3] 关于Pytorch加载模型参数的避坑指南

http://www.lryc.cn/news/2420921.html

相关文章:

  • 单片机——LCD1602
  • 移动测试之-流量测试方案
  • Visual Studio 2008 试用版评估期已结束的解决方法
  • 一步步优化JVM七:其他
  • 无法启动计算机上的服务msdtc,MSDTC服务无法启动解决方法
  • 分享116个ASP搜索链接源码,总有一款适合您
  • Hello C++
  • 纳什均衡定义、举例、分类
  • 开启游戏别样体验:《下一站江湖2》风灵月影六十项修改器使用手册
  • ubuntu9.10 软件推荐
  • Oracle DB Time 解读
  • 收集一些有质感、有内涵的网站 (转载)
  • 实时监控系统介绍
  • MapInfo是一种流行的地理信息系统(GIS)软件,它提供了丰富的功能和工具,用于处理、分析和可视化地理空间数据
  • CAN总线学习笔记 | CAN基础知识介绍
  • 2024年最全在线查询默认密码网站--分享_hawel-lutuo默认密码(1),分析网络安全未来几年的发展前景
  • java计算机毕业设计电商网站在线客服(附源码+springboot+开题+论文+部署)
  • 递归和迭代_深究递归和迭代的区别、优缺点及实例对比
  • 网络层 IPV4报文格式
  • 中国网站广告联盟大集合
  • 5.秒杀模块-基于redis缓存商品秒杀信息
  • ‘真三国无双5’完美存档修改
  • 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)
  • gitgitlab 修改本地分支名称和远程分支名称
  • 初探Spark-使用大数据分析2000W行数据
  • 博客屋网址导航自适应主题php源码
  • 驱动python_光驱驱动下载_万能光驱驱动(万能DVD光驱CD光驱驱动) 2018 官方版_极速下载站...
  • MFC框架机制详解
  • 【C语言经典例题100解答】
  • web自动化测试_web自动化测试工具和框架有哪些?