当前位置: 首页 > news >正文

pytorch nn.Embedding 用法和原理

nn.Embedding 是 PyTorch 中的一个模块,用于将离散的输入(通常是词或子词的索引)映射到连续的向量空间。它在自然语言处理和其他需要处理离散输入的任务中非常常用。以下是 nn.Embedding 的用法和原理。

用法

初始化 nn.Embedding
nn.Embedding 的初始化需要两个主要参数:

  1. num_embeddings:字典的大小,即输入的最大索引值 + 1。
  2. embedding_dim:每个嵌入向量的维度。

此外,还有一些可选参数,如 padding_idx、max_norm、norm_type、scale_grad_by_freq 和 sparse。

import torch
import torch.nn as nn# 创建一个 Embedding 层
num_embeddings = 10  # 词汇表大小
embedding_dim = 3    # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)

输入和输出
nn.Embedding 的输入是一个包含索引的长整型张量,输出是对应的嵌入向量。

# 示例输入
input_indices = torch.LongTensor([1, 2, 3, 4])
output_vectors = embedding_layer(input_indices)
print(output_vectors)

示例代码
以下是一个完整的示例代码,展示了如何使用 nn.Embedding 层:

import torch
import torch.nn as nn# 创建 Embedding 层
num_embeddings = 10  # 词汇表大小
embedding_dim = 3    # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)# 示例输入
input_indices = torch.LongTensor([1, 2, 3, 4])# 获取嵌入向量
output_vectors = embedding_layer(input_indices)
print("Input indices:", input_indices)
print("Output vectors:", output_vectors)

原理

nn.Embedding 层的本质是一个查找表,它将输入的每个索引映射到一个固定大小的向量。这个映射表在初始化时会随机生成,然后在训练过程中通过反向传播进行优化。
主要步骤

  1. 初始化:在初始化时,nn.Embedding 会创建一个大小为 (num_embeddings, embedding_dim)的权重矩阵。这些权重是嵌入层的参数,会在训练过程中更新。
  2. 前向传播:在前向传播过程中,nn.Embedding 层会将输入的索引映射到权重矩阵的相应行,从而得到对应的嵌入向量。
  3. 反向传播:在训练过程中,嵌入层的权重矩阵会根据损失函数的梯度进行更新。这使得嵌入向量能够捕捉到输入的语义信息。

参数解释

  • padding_idx:如果指定了 padding_idx,则该索引的嵌入向量在训练过程中不会被更新。通常用于处理填充(padding)标记。
  • max_norm:如果指定了 max_norm,则会对每个嵌入向量的范数进行约束,使其不超过 max_norm。
  • norm_type:用于指定范数的类型,默认是2范数。
  • scale_grad_by_freq:如果设置为 True,则会根据输入中每个词的频率缩放梯度。
  • sparse:如果设置为 True,则使用稀疏梯度更新,适用于大词汇表的情况。

原理解释

  1. 查找表:nn.Embedding 的核心是一个查找表,其大小为 (num_embeddings,embedding_dim),每一行代表一个词或索引的嵌入向量。
  2. 前向传播:在前向传播中,输入的索引被用来查找嵌入向量。假设输入是 [1, 2, 3],则输出是权重矩阵中第1、第2和第3行的向量。
  3. 反向传播:在反向传播中,嵌入向量的梯度会根据损失函数进行计算,并用于更新权重矩阵。

通过这种方式,嵌入向量能够在训练过程中不断调整,使得相似的输入索引(例如语义相似的词)在向量空间中更接近,从而捕捉到输入的语义信息。

总结
nn.Embedding 是 PyTorch 中处理离散输入的一个非常强大且常用的工具。通过将离散索引映射到连续向量空间,并在训练过程中优化这些向量,nn.Embedding 能够捕捉到输入的丰富语义信息。这对于自然语言处理等任务来说是非常重要的。

http://www.lryc.cn/news/388522.html

相关文章:

  • Python中常用的有7种值(数据)的类型及type()语句的用法
  • 某配送平台未授权访问和弱口令(附赠nuclei默认密码验证脚本)
  • 01.总览
  • Linux换源
  • 【高考志愿】 化学工程与技术
  • 2024上半年网络与数据安全法规政策、国标、报告合集
  • 基于SpringBoot扶农助农政策管理系统设计和实现(源码+LW+调试文档+讲解等)
  • 淘宝商铺电话怎么获取?使用爬虫工具采集
  • ModStart:开源免费的PHP企业网站开发建设管理系统
  • npm安装依赖报错——npm ERR gyp verb cli的解决方法
  • 公网环境使用Potplayer远程访问家中群晖NAS搭建的WebDAV听歌看电影
  • Forecasting from LiDAR via Future Object Detection
  • 【unity笔记】五、UI面板TextMeshPro 添加中文字体
  • 如何在Windows 11上设置默认麦克风和相机?这里有详细步骤
  • Flutter循序渐进==>数据结构(列表、映射和集合)和错误处理
  • 泛微E9开发 限制明细表列的值重复
  • magicapi导出excel
  • 【秋招突围】2024届秋招笔试-科大讯飞笔试题-03-三语言题解(Java/Cpp/Python)
  • springboot是否可以代替spring
  • 基于SpringBoot的CSGO赛事管理系统
  • 使用 Selenium 实现自动化分页处理与信息提取
  • 现代信息检索笔记(二)——布尔检索
  • 使用Python实现学生管理系统
  • 【嵌入式DIY实例】- LCD ST7735显示DHT11传感器数据
  • 基于Tools体验NLP编程的魅力
  • 强化学习-3深度学习基础
  • SOC模块LoRa-STM32WLE5有哪些值得关注
  • CSS中的display属性:布局控制的关键
  • 【Spring Boot AOP通知顺序】
  • k8s是什么