torch.cat((A,B),dim=1)解析
官方说明torch.cat
引用自:Pytorch中的torch.cat()函数
torch.cat(tensors, dim=0, *, out=None) → Tensor
# 连接给定维数的给定序列的序列张量。所有张量要么具有相同的形状(除了连接维度),要么为空。
示例
输入:
import torch
a = torch.Tensor(2,3) # (2行,3列)
b = torch.Tensor(2,3)
print (a)
print (b)
输出:
tensor([[8.9082e-39, 1.0194e-38, 9.1837e-39],[8.4490e-39, 9.6429e-39, 8.4490e-39]])
tensor([[-2.0541e-05, 5.0727e-43, -2.0541e-05],[ 5.0727e-43, -2.1039e-05, 5.0727e-43]])
输入:
print(torch.cat([a,b], dim= 0))
# 1. torch.cat((x,y),dim=0) :张量 X,Y按照列堆起来
输出:
tensor([[ 8.9082e-39, 1.0194e-38, 9.1837e-39],[ 8.4490e-39, 9.6429e-39, 8.4490e-39],[-2.0541e-05, 5.0727e-43, -2.0541e-05],[ 5.0727e-43, -2.1039e-05, 5.0727e-43]])
输入:
print(torch.cat([a,b], dim=-1))
# 2. torch.cat((x,y),dim=1) :张量 X,Y按照行并排起来
输出:
tensor([[ 8.9082e-39, 1.0194e-38, 9.1837e-39, -2.0541e-05, 5.0727e-43,-2.0541e-05],[ 8.4490e-39, 9.6429e-39, 8.4490e-39, 5.0727e-43, -2.1039e-05,5.0727e-43]])
总结:
torch.cat((x,y),dim=0)
:张量 X,Y按照列堆起来
torch.cat((x,y),dim=1)
:张量 X,Y按照行并排起来