当前位置: 首页 > news >正文

深度学习:nn.Linear

nn.Linear 是 PyTorch 中的一个线性层(全连接层),用于将输入张量从一个维度空间映射到另一个维度空间。具体来说,nn.Linear 执行以下操作:

output=input×weightT+bias

其中:
input 是输入张量。
weight 是权重矩阵。
bias 是偏置项(如果 bias=True)。

  • 具体作用:
    输入维度:
    假设键(key)的维度为 key_size,即每个键是一个形状为 (key_size,) 的向量。
    输出维度:
    通过 nn.Linear(key_size, num_hiddens),键被映射到一个新的维度空间,即每个键被转换为一个形状为 (num_hiddens,) 的向量。
    权重矩阵:
    nn.Linear 会自动创建一个形状为 (key_size, num_hiddens) 的权重矩阵 W_k。
    这个权重矩阵将在训练过程中通过反向传播进行优化,以学习如何将键从 key_size 维度映射到 num_hiddens 维度。

  • 示例

     - import torch
    import torch.nn as nn# 假设 key_size = 64, num_hiddens = 128
    key_size = 64
    num_hiddens = 128# 定义线性层 W_k
    W_k = nn.Linear(key_size, num_hiddens, bias=False)# 假设 K 的形状为 (batch_size, sequence_length, key_size)
    batch_size = 2
    sequence_length = 5
    K = torch.randn(batch_size, sequence_length, key_size)# 应用线性变换
    K_transformed = W_k(K)print(K_transformed.shape)
    

    输出为torch.Size([2, 5, 128])
    解释:
    输入:键张量 K 的形状为 (2, 5, 64),表示批量大小为 2,序列长度为 5,每个键的维度为 64。
    输出:经过线性变换后,K_transformed 的形状为 (2, 5, 128),表示每个键被映射到了 128 维的隐藏层空间。

http://www.lryc.cn/news/483281.html

相关文章:

  • 大数据新视界 -- 大数据大厂之 Impala 性能提升:高级执行计划优化实战案例(下)(18/30)
  • 常用的Anaconda Prompt命令行指令
  • 如何低成本、零代码开发、5分钟内打造一个企业AI智能客服?
  • 全网最全最新最细的MYSQL5.7下载安装图文教程
  • NoSQL数据库与关系型数据库的主要区别
  • ubuntu24.04安装matlab失败
  • Oracle 11g rac 集群节点的修复过程
  • c++:string(一)
  • github和Visual Studio
  • django框架-settings.py文件的配置说明
  • 【C语言】缺陷管理流程
  • 基于深度学习的猫狗识别
  • java组件安全
  • 【MongoDB】MongoDB的核心-索引原理及索引优化、及查询聚合优化实战案例(超详细)
  • qt QProcess详解
  • 软件测试面试2024最新热点问题
  • 10款录屏工具推荐,聊聊我的使用心得!!!!
  • VMware+Ubuntu+finalshell连接
  • autodl+modelscope推理stable-diffusion-3.5-large
  • 深度学习之 LSTM
  • LeetCode 3242.设计相邻元素求和服务:哈希表
  • 【AliCloud】ack + ack-secret-manager + kms 敏感数据安全存储
  • 探索JavaScript的强大功能:从基础到高级应用
  • 新增支持Elasticsearch数据源,支持自定义在线地图风格,DataEase开源BI工具v2.10.2 LTS发布
  • Spark的容错机制
  • YOLOv8改进 | 利用YOLOv8进行视频划定区域目标统计计数
  • 基于yolov8、yolov5的番茄成熟度检测识别系统(含UI界面、训练好的模型、Python代码、数据集)
  • wafw00f源码详细解析
  • 什么是crm?3000字详细解析
  • WEB3.0介绍