当前位置: 首页 > news >正文

Pytorch中的nn.Embedding()

模块的输入是一个索引列表,输出是相应的词嵌入。

Embedding.weight(Tensor)–形状模块(num_embeddings,Embedding_dim)的可学习权重,初始化自(0,1)。
也就是说,pytorch的nn.Embedding()是可以自动学习每个词向量对应的w权重的。

import torch
import torch.nn as nn
embedding = nn.Embedding(9, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1,2,4,5,6,7,8,1,1,1,6,7,5],[4,3,2,1,6,7,8,1,1,1,6,7,5]])
#这里的input可以里的数字可以表示为embedding的索引.索引数据的shape是没有限制的,但是input中的数值不能超过nn.Embedding(9,3)中的9的.
a = embedding(input)
print(a)

http://www.lryc.cn/news/333170.html

相关文章:

  • WebSocketServer后端配置,精简版
  • Python程序设计 多重循环(二)
  • 前端面试题--CSS系列(一)
  • VSCode好用插件
  • Vue3:对ref、reactive的一个性能优化API
  • Python 用pygame简简单单实现一个打砖块
  • 软考113-上午题-【计算机网络】-IPv6、无线网络、Windows命令
  • 深入浅出 -- 系统架构之负载均衡Nginx资源压缩
  • 基于jsp+Spring boot+mybatis的图书管理系统设计和实现
  • Pytorch转onnx
  • 苍穹外卖——项目搭建
  • 云原生架构(微服务、容器云、DevOps、不可变基础设施、声明式API、Serverless、Service Mesh)
  • 基于重写ribbon负载实现灰度发布
  • 无端科技一面(生死狙击项目组 战斗客户端 40min)
  • idea开发 java web 高校学籍管理系统bootstrap框架web结构java编程计算机网页
  • linux之文件系统、inode和动静态库制作和发布
  • C++IO类,输入输出缓冲区,流状态
  • 机器学习笔记 - 文字转语音技术路线简述以及相关工具不完全清单
  • 阿里云4核8G服务器ECS通用算力型u1实例优惠价格
  • Jetson nano部署Yolov8 安装Archiconda3+创建pytorch环境(详细教程+错误解决)
  • Node.JS多线程PromisePool之promise-pool库实现
  • 【C++】红黑树讲解及实现
  • security如何不拦截websocket
  • Unity类银河恶魔城学习记录12-3 p125 Limit Inventory Slots源代码
  • 【智能排班系统】雪花算法生成分布式ID
  • sass中的导入与部分导入
  • 工业组态 物联网组态 组态编辑器 web组态 组态插件 编辑器
  • git可视化工具
  • 基于单片机电子密码锁系统设计
  • 点云从入门到精通技术详解100篇-基于点云与图像纹理的 道路识别(续)