当前位置: 首页 > news >正文

Pytorch cat()与stack()函数详解

torch.cat()

cat为concatenate的缩写,意思为拼接,torch.cat()函数一般是用于张量拼接使用的

cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor:

可以看到cat()函数的参数,常用的参数为,第一个参数:可以选择元组或者列表,内部包含需要拼接的张量,需要按照顺序排列,第二个参数为dim,用于指定需要拼接的维度

import torch
import numpy as npdata1 = torch.randint(0, 10, [2, 3, 4])
data2 = torch.randint(0, 10, [2, 3, 4])print(data1)
print(data2)
print("-" * 20)print(torch.cat([data1, data2], dim=0))
print(torch.cat([data1, data2], dim=1))
print(torch.cat([data1, data2], dim=2))
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9]]])
# tensor([[[3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[2, 6, 7, 5],
#          [9, 1, 0, 1],
#          [0, 6, 4, 4]]])
# --------------------
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9]],
# 
#         [[3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[2, 6, 7, 5],
#          [9, 1, 0, 1],
#          [0, 6, 4, 4]]])
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8],
#          [3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9],
#          [2, 6, 7, 5],
#          [9, 1, 0, 1],
#          [0, 6, 4, 4]]])
# tensor([[[9, 4, 0, 0, 3, 2, 6, 3],
#          [3, 3, 7, 6, 8, 3, 1, 1],
#          [6, 1, 0, 8, 0, 9, 2, 5]],
# 
#         [[9, 1, 1, 2, 2, 6, 7, 5],
#          [1, 0, 6, 4, 9, 1, 0, 1],
#          [7, 9, 3, 9, 0, 6, 4, 4]]])

上述代码演示了拼接维度为0,1,2的时候的结果,可以看出cat()并不会影响张量的维度,如上述的三维张量拼接,若dim为0则按块(后两位张量组成的二维张量)进行拼接,若dim为1则按行拼接,若dim为2则按列拼接

torch.stack()

stack为堆叠、栈的意思

stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: 

可以看到stack()和cat()的用法几乎一致,都是用于堆叠张量组成的列表或元组,以及堆叠的维度dim

import torch
import numpy as npdata1 = torch.randint(0, 10, [2, 3, 4])
data2 = torch.randint(0, 10, [2, 3, 4])print(data1)
print(data2)
print("-" * 20)data3 = torch.stack([data1, data2], dim=0)
data4 = torch.stack([data1, data2], dim=1)
data5 = torch.stack([data1, data2], dim=2)
data6 = torch.stack([data1, data2], dim=3)
print(data3.shape)
print(data3)
print(data4.shape)
print(data4)
print(data5.shape)
print(data5)
print(data6.shape)
print(data6)# tensor([[[1, 6, 6, 1],
#          [3, 1, 8, 2],
#          [0, 4, 7, 3]],
# 
#         [[4, 7, 5, 6],
#          [5, 4, 0, 2],
#          [8, 0, 3, 0]]])
# tensor([[[5, 2, 7, 2],
#          [7, 4, 2, 0],
#          [8, 5, 5, 9]],
# 
#         [[7, 1, 5, 6],
#          [3, 5, 4, 7],
#          [1, 0, 8, 8]]])
# --------------------
# torch.Size([2, 2, 3, 4])
# tensor([[[[1, 6, 6, 1],
#           [3, 1, 8, 2],
#           [0, 4, 7, 3]],
# 
#          [[4, 7, 5, 6],
#           [5, 4, 0, 2],
#           [8, 0, 3, 0]]],
# 
# 
#         [[[5, 2, 7, 2],
#           [7, 4, 2, 0],
#           [8, 5, 5, 9]],
# 
#          [[7, 1, 5, 6],
#           [3, 5, 4, 7],
#           [1, 0, 8, 8]]]])
# torch.Size([2, 2, 3, 4])
# tensor([[[[1, 6, 6, 1],
#           [3, 1, 8, 2],
#           [0, 4, 7, 3]],
# 
#          [[5, 2, 7, 2],
#           [7, 4, 2, 0],
#           [8, 5, 5, 9]]],
# 
# 
#         [[[4, 7, 5, 6],
#           [5, 4, 0, 2],
#           [8, 0, 3, 0]],
# 
#          [[7, 1, 5, 6],
#           [3, 5, 4, 7],
#           [1, 0, 8, 8]]]])
# torch.Size([2, 3, 2, 4])
# tensor([[[[1, 6, 6, 1],
#           [5, 2, 7, 2]],
# 
#          [[3, 1, 8, 2],
#           [7, 4, 2, 0]],
# 
#          [[0, 4, 7, 3],
#           [8, 5, 5, 9]]],
# 
# 
#         [[[4, 7, 5, 6],
#           [7, 1, 5, 6]],
# 
#          [[5, 4, 0, 2],
#           [3, 5, 4, 7]],
# 
#          [[8, 0, 3, 0],
#           [1, 0, 8, 8]]]])
# torch.Size([2, 3, 4, 2])
# tensor([[[[1, 5],
#           [6, 2],
#           [6, 7],
#           [1, 2]],
# 
#          [[3, 7],
#           [1, 4],
#           [8, 2],
#           [2, 0]],
# 
#          [[0, 8],
#           [4, 5],
#           [7, 5],
#           [3, 9]]],
# 
# 
#         [[[4, 7],
#           [7, 1],
#           [5, 5],
#           [6, 6]],
# 
#          [[5, 3],
#           [4, 5],
#           [0, 4],
#           [2, 7]],
# 
#          [[8, 1],
#           [0, 0],
#           [3, 8],
#           [0, 8]]]])

可以看到dim设置为几,就会按第几个维度进行堆叠拼接,dim为0则是整体堆叠后升维,dim为1则是按第二个维度也就是后两维张量为一个整体进行两个张量对应堆叠拼接,dim为2为按后两维中的行进行堆叠拼接,dim为3也就是按两个张量的单个值进行对应堆叠拼接

stack()随着维度增加,理解会较为复杂,具体可见代码和结果演示

注意,cat()和stack()中的dim参数也可以使用负索引,即从-1开始进行维度索引

http://www.lryc.cn/news/428675.html

相关文章:

  • A. X(质因数分解+并查集)
  • 自动化测试中如何应对网页弹窗的挑战!
  • Redission
  • 负载均衡详解
  • Swift与UIKit:构建卓越用户界面的艺术
  • Spring 中ClassPathXmlApplicationContext
  • Springboot邮件发送:如何配置SMTP服务器?
  • 二叉树--堆
  • 【K8s】专题十二(2):Kubernetes 存储之 PersistentVolume
  • python3多个图片合成一个pdf文件,生产使用验证过
  • Stable Diffusion赋能“黑神话”——助力悟空走进AI奇幻世界
  • 微信小程序登陆
  • SQL - 存储过程
  • RabbitMQ环境搭建
  • 多视点抓取(Multi-View Grasping)
  • 【人工智能】对智元机器人发布的远征A1所应用的AI前沿技术进行详细分析,基于此整理一份学习教程。
  • 影刀RPA--如何获取网页当页数据?
  • Bean对象生命周期流程图
  • 24/8/17算法笔记 策略梯度reinforce算法
  • 【Linux学习】Linux开发工具——vim
  • 【2025校招】4399 NLP算法工程师笔试题
  • 数据库原理--关系1
  • 【人工智能】AI工程化是将人工智能技术转化为实际应用、创造实际价值的关键步骤
  • 《C语言实现各种排序算法》
  • 【888题竞赛篇】第五题,2023ICPC澳门-传送(Teleportation)
  • javascript写一个页码器-SAAS本地化及未来之窗行业应用跨平台架构
  • 微信小程序如何自定义一个组件
  • 【数学建模备赛】Ep05:斯皮尔曼spearman相关系数
  • MATLAB进行神经网络建模的案例
  • 每天一个数据分析题(四百八十九)- 主成分分析与因子分析