当前位置: 首页 > news >正文

【Tensorflow2.0】tensorflow中的Dense函数解析

目录

  • 1 作用
  • 2 例子
  • 3 与torch.nn.Linear的区别
  • 4 参考文献

1 作用

注意此处Tensorflow版本是2.0+。
由于本人是Pytorch用户,对Tensorflow不是很熟悉,在读到用tf写的代码时就很是麻烦。如图所示,遇到了如下代码:

h = Dense(units=adj_dim, activation=None)(dec_in)

  Dense层就是全连接层,对于层方式的初始化的时候,layers.Dense(units,activation)函数一般只需要指定输出节点数Units和激活函数类型即可。输入节点数将根据第一次运算时输入的shape确定,同时输入、输出节点自动创建并初始化权值w和偏置向量b。

下面是Dense的接口

Dense(units,activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

units, 代表该层的输出维度
activation=None, 激活函数.但是默认 liner
use_bias=True, 是否使用b 直线 y=ax+b 中的 b

此处没有写 iuput 的情况, 通常会有两种写法:

1 : Dense(units,input_shape())2 : Dense(units)(x) #这里的 x 是以张量.

Dense(n)(x):=ReLU(Wx+b)Dense \ (n) \ (x):=ReLU(Wx+b)Dense (n) (x):=ReLU(Wx+b)
W 是权重函数, Dense() 会随机给 W 一个初始值。所以这里跟Pytorch的nn.linear()一样。

2 例子

# 使用第一种方法进行初始化
# 作为 Sequential 模型的第一层,需要指定输入维度。可以为 input_shape=(16,) 或者 input_dim=16,这两者是等价的。
model = Sequential()
model.add(Dense(32, input_shape=(16,)))
# 现在模型就会以尺寸为 (*, 16) 的数组作为输入,
# 其输出数组的尺寸为 (*, 32)# 在第一层之后,就不再需要指定输入的尺寸了:
model.add(Dense(32))

3 与torch.nn.Linear的区别

# Pytorch实现
trd = torch.nn.Linear(in_features = 3, out_features = 30)
y = trd(torch.ones(5, 3))
print(y.size())
# torch.Size([5, 30])# Tensorflow实现
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(30, input_shape=(5,), activation=None))
————————————————————————————————————
tfd = tf.keras.layers.Dense(30, input_shape=(3,), activation=None)
x = tfd(tf.ones(shape=(5, 3)))
print(x.shape)
# (5, 30)
上面Tensorflow的实现方式相同,但是我存在疑惑

4 参考文献

[1]dense层、激活函数、输出层设计
[2]Dense(units, activation=None,)初步
[3]深入理解 keras 中 Dense 层参数
[4]tensorflow - Tensorflow 的 tf.keras.layers.Dense 和 PyTorch 的 torch.nn.Linear 的区别?

http://www.lryc.cn/news/1247.html

相关文章:

  • PyTorch学习笔记:data.RandomSampler——数据随机采样
  • 设计模式(七)----创建型模式之建造者模式
  • DCGAN
  • 【速通版】吴恩达机器学习笔记Part3
  • 【leetcode】跳跃游戏
  • 论文投稿指南——中文核心期刊推荐(冶金工业 2)
  • 【GPLT 二阶题目集】L2-044 大众情人
  • SpringBoot整合(二)MyBatisPlus技术详解
  • 导入importk8s集群,添加node节点,rancher agent,Rancher Agent设置选项
  • C++11--右值引用与移动语义
  • Python SQLAlchemy入门教程
  • 你是真的“C”——操作符详解【下篇】+整形提升+算术转换
  • 文本匹配SimCSE模型代码详解以及训练自己的中文数据集
  • Biotin-PEG-FITC 生物素聚乙二醇荧光素;FITC-PEG-Biotin 科研用生物试剂
  • FISCO BCOS 搭建区块链,在SpringBoot中调用合约
  • 面试官:int和Integer有什么区别?
  • MFC常用技巧
  • C++ —— 多态
  • java agent设计开发概要
  • node.js笔记-模块化(commonJS规范),包与npm(Node Package Manager)
  • Linux 磁盘坏块修复处理(错误:read error: Input/output error)
  • API 面试四连杀:接口如何设计?安全如何保证?签名如何实现?防重如何实现?
  • 操作系统题目收录(六)
  • 2023年十款开源测试开发工具推荐!
  • MySQL慢查询分析和性能优化
  • C++学习笔记(四)
  • 【4】深度学习之Pytorch——如何使用张量处理时间序列数据集(共享自行车数据集)
  • mulesoft MCIA 破釜沉舟备考 2023.02.10.01
  • 干货 | PCB拼板,那几条很讲究的规则!
  • 笔试题-2023-思远半导体-数字IC设计【纯净题目版】