"""PyTorch basic function quick reference.

Basic functions used in deep learning:
masked_select
squeeze, unsqueeze
expand
repeat
transpose, contiguous
permute
broadcasting
cat, stack
split, chunk
"""
import torch

# --- masked_select: pick out elements where a boolean mask is True ---
print('-------------------masked_select--------------------')
x = torch.randn(3, 4)
print(x)
# mask = x > 0
mask = x.ge(0.5)  # elementwise x >= 0.5 -> boolean mask of the same shape
# masked_select always returns a flattened 1-D tensor of the selected values
print(torch.masked_select(x, mask))
# squeeze, unsqueeze
# --- unsqueeze inserts a size-1 dim; squeeze removes size-1 dims ---
print('-------------------squeeze,unsqueeze-----------------')
# unsqueeze
print('[unsqueeze:]')
a = torch.rand(4, 1, 28, 28)
print('a.shape = %s ' % (a.shape,))
b = a.unsqueeze(0)   # prepend a new dim: (1, 4, 1, 28, 28)
print('b.shape = %s ' % (b.shape,))
c = a.unsqueeze(-1)  # append a new dim: (4, 1, 28, 28, 1)
print('c.shape = %s ' % (c.shape,))
# squeeze
print('[squeeze:]')
print('a.shape = %s ' % (a.shape,))
b = b.squeeze()      # with no dim argument, every size-1 dim is dropped
print('b.shape = %s ' % (b.shape,))
# dims whose size is not 1 cannot be squeezed (squeeze is a silent no-op)
c1 = c.squeeze(0)    # dim 0 has size 4 -> shape unchanged
print('c1.shape = %s ' % (c1.shape,))
c2 = c.squeeze(1)    # dim 1 has size 1 -> it is removed
print('c2.shape = %s ' % (c2.shape,))
# expand
# --- expand: broadcast a size-1 dim to a larger size without copying data ---
print('---------------------expand扩展-----------------------')
a = torch.rand(2, 3, 1)
print(a.shape)
b = a.expand(2, 3, 4)  # only size-1 dims can be expanded; no memory copy
print(b.shape)
# repeat
# --- repeat: tile the tensor; the argument is the repeat COUNT per dim ---
# (unlike expand, repeat actually copies the data)
print('---------------------repeat扩展-----------------------')
a = torch.rand(1, 2, 32, 1)
print(a.shape)
b = a.repeat(2, 1, 1, 1)  # dim 0 tiled twice -> (2, 2, 32, 1)
print(b.shape)
# transpose, contiguous
# --- transpose + contiguous: swapping dims and restoring the layout ---
print('----------------transpose,contiguous-----------------')
a = torch.rand(4, 3, 32, 32)
print(a.shape)
# view blurs which dim is which, so the dim order must be tracked by hand
# 1. transpose: swap two dimensions (returns a non-contiguous view)
# 2. contiguous: materialize the tensor contiguously (required before view)
# a1 flattens then re-views in the ORIGINAL dim order -> data is scrambled
a1 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32)
# a2 re-views in the TRANSPOSED order and transposes back -> data preserved
a2 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3)
print('a1.shape = %s\na2.shape = %s' % (a1.shape, a2.shape))
print('')
print('a == a1 : %s ' % torch.all(torch.eq(a, a1)))  # False: dims mis-tracked
print('a == a2 : %s ' % torch.all(torch.eq(a, a2)))  # True: round-trip correct
# permute
# --- permute: reorder ALL dims at once (transpose only swaps two) ---
print('---------------------permute-------------------------')
a = torch.rand(4, 3, 32, 32)
print('a.shape = ', end='')
print(a.shape)
b = a.permute(0, 1, 3, 2)  # swap the last two dims
print('b.shape = %s' % (b.size(),))
b = a.permute(0, 2, 3, 1)  # e.g. NCHW -> NHWC ordering
print('b.shape = %s' % (tuple(b.shape),))
# broadcasting
# --- broadcasting: size-1 / missing leading dims are auto-expanded ---
print('----------------broadcasting-------------------------')
a = torch.rand(2, 3)
print('a:')
print(a)
b = torch.rand(3)  # (3) broadcasts against (2, 3)
print('b:')
print(b)
c = a + b
print('a + b:')
print(c)
a1 = torch.rand(3, 6)
b1 = torch.rand(1)
a1 + b1            # (1) broadcasts to (3, 6); result intentionally unused
b2 = b1.expand(6)
a1 + b2            # (6) broadcasts to (3, 6); result intentionally unused
# c1 = torch.rand(3).expand(6)  # would raise: only size-1 dims can expand
c2 = torch.rand(3).repeat(2)    # repeat DOES copy: (3) -> (6)
# a1 + c1
a1 + c2            # result intentionally unused
# cat, stack
print('---------------------cat,stack------------------------')
# cat: concatenate along an existing dim
# all dims other than the concatenation dim must match
print('[cat:]')
a = torch.rand(2, 3)
b = torch.rand(2, 3)
c = torch.cat([a, b], dim=0)  # (2, 3) + (2, 3) -> (4, 3)
print('c.shape = %s' % (c.shape,))
# stack: stack along a NEW dim
# all dims must match exactly
# stack creates a new dimension
print('[stack:]')
c = torch.stack([a, b], dim=0)  # (2, 3) + (2, 3) -> (2, 2, 3)
print('c.shape = %s' % (c.shape,))
# split, chunk
print('---------------------split,chunk----------------------')
# split: divide by chunk LENGTH (an int, or a list of per-piece lengths)
print('[split:]')
a = torch.rand(2, 3, 4)
b = a.split(1, dim=0)        # pieces of length 1 along dim 0 -> 2 pieces
for i in b:
    print(i.shape)
a = torch.rand(2, 3, 4)
b = a.split([1, 2], dim=1)   # explicit lengths 1 and 2 along dim 1
for i in b:
    print(i.shape)
# chunk: divide by NUMBER of chunks
print('[chunk:]')
a = torch.rand(2, 3, 4)
b = torch.chunk(a, 1, dim=1)  # 1 chunk -> the whole tensor back
for i in b:
    print(i.shape)