当前位置: 首页 > news >正文

tensorflow 模型计算中,预测错误;权重参数加载

tensorflow 模型计算中,预测错误;权重参数加载

tensorflow 模型计算主要代码(正确代码)

linear1_kernel_initializer = tf.constant_initializer(numpy.transpose(data["linear1.weight"]))
linear1_bias_initializer = tf.constant_initializer(numpy.transpose(data["linear1.bias"]))
linear1 = layers.Dense(units=400, activation=tf.nn.relu, kernel_initializer=linear1_kernel_initializer, use_bias=True, bias_initializer=linear1_bias_initializer, input_shape=(48,))
linear2_kernel_initializer = tf.constant_initializer(numpy.transpose(data["linear2.weight"]))
linear2_bias_initializer = tf.constant_initializer(numpy.transpose(data["linear2.bias"]))
linear2 = layers.Dense(units=400, activation=tf.nn.relu, kernel_initializer=linear2_kernel_initializer, use_bias=True, bias_initializer=linear2_bias_initializer)
linear3_kernel_initializer = tf.constant_initializer(numpy.transpose(data["linear3.weight"]))
linear3_bias_initializer = tf.constant_initializer(numpy.transpose(data["linear3.bias"]))
linear3 = layers.Dense(units=2, activation=None, kernel_initializer=linear3_kernel_initializer, use_bias=True, bias_initializer=linear3_bias_initializer)
model = tf.keras.Sequential([linear1, linear2, linear3])input = numpy.ones((2, 48), dtype=float)
predict = model.predict(input)
print(predict[0:100,:])

原本权重参数采用以下代码

linear1_kernel_initializer = tf.constant_initializer(data["linear1.weight"])
linear1_bias_initializer = tf.constant_initializer((data["linear1.bias"])

但模型预测值与Matlab计算值有误。后经过测试定位到 layers.Dense 此处,然后创建 layers.Dense时设置use_bias=False参数,不去考虑偏差参数。改变初始权重参数方式:

input_size = 2
units_p = 3
data = numpy.array([1, 1, 2, 2, 2, 3], dtype=float)
linear1_kernel_initializer = tf.constant_initializer(data)
linear1 = layers.Dense(units=units_p, activation=None, kernel_initializer=linear1_kernel_initializer, use_bias=False, input_shape=(input_size,))
#变化data
data = numpy.array([1, 2, 3, 1, 2, 3], dtype=float)
#或者
data = numpy.array([1, 2, 3, 1, 2, 3], dtype=float).reshape(3, 2)

通过这样的方式,才发现 linear1_kernel_initializer = tf.constant_initializer(data) 中的 data 有问题,通过对预测结果的分析,发现 tf.constant_initializer() 会将传递过来的数据拉成一维,再根据 units不同层 来变更数据矩阵大小,所以传入tf.constant_initializer()的数据只要总大小是对的就可以传入,而不需要shape一致。
所以,既然之前的数据预测结果有误,那就是数据排列有误,将 data 数据进行矩阵转置 再 传入到tf.constant_initializer() 函数中
问题成功解决。
同时我想说明的是,pytorchtorch.nn.LinearW x + btensorflowlayers.Densex W + b

tensorflow这种情况可以形象的表达为 流动的关系,input -> HL1 -> HL2 -> output(HL1为隐藏层1)

input 卷上 W1 + b1 => HL1结果
HL1结果 卷上 W2 + b2 => HL2结果
HL2结果 卷上 W3 + b3 => outpu

http://www.lryc.cn/news/125605.html

相关文章:

  • Jay17 2023.8.14日报 即 留校集训阶段性总结
  • 【C语言】小游戏-扫雷(清屏+递归展开+标记)
  • 云服务 Ubuntu 20.04 版本 使用 Nginx 部署静态网页
  • 无后效性
  • Kubernetes系列-删除deployment和pod
  • kotlin字符串方法
  • ubuntu篇---配置FTP服务,本机和docker安装
  • SpringBoot中properties、yml、yaml的优先级
  • SHELL 基础 SHELL注释 及 执行SHELL脚本的四种方法
  • 【Spring】深入探索 Spring AOP:概念、使用与实现原理解析
  • LocalDate介绍和使用
  • 三、使用注解形式开发 Spring MVC程序
  • 【Go】常见的四个内存泄漏问题
  • 【LeetCode-简单】剑指 Offer 29. 顺时针打印矩阵(详解)
  • TOMCAT基础
  • 自动化集装箱码头建设指南
  • 为什么要用redis
  • QT qmake解析
  • 【TypeScript】this指向,this内置组件
  • MySQL 深度分页优化
  • 如何在CSS中水平居中一个元素?
  • 生信豆芽菜-ESTIMATE预测免疫评分
  • 分享一颗能用在TYPE-C接口取电协议芯片LDR6328Q,方便好用
  • 【java】Java与SQLite3数据库类型之间对应关系
  • ELK常见部署架构以及出现的问题及解决方案
  • windows使用vscode配置java开发环境
  • centos系统kubeadm安装K8S_v1.27.x容器使用docker(K8S_v1.24版本以后依然使用docker容器管理)
  • 如何使用索引加速 SQL 查询 [Python 版]
  • Oracle 开发篇+Java通过DRCP访问Oracle数据库
  • 在安装 ONLYOFFICE 协作空间社区版时如何使用额外脚本参数