当前位置: 首页 > news >正文

PyTorch Tensor 操作入门:转换、运算、维度变换

目录

1. Tensor 与 NumPy 数组的转换

1.1 Tensor 转换为 NumPy 数组

1.2 NumPy 数组转换为 Tensor

1.3 获取单个元素的值

2. Tensor 的基本运算

2.1 生成新 Tensor 的运算

2.2 覆盖原 Tensor 的运算

2.3 阿达玛积(逐元素乘法)

2.4 矩阵乘法

3. Tensor 的形状变换

3.1 view() 方法

3.2 reshape() 方法

4. 维度变换

4.1 transpose() 方法

4.2 permute() 方法

5. 完整代码示例

6. 总结


在深度学习中,PyTorch 的 Tensor 是核心数据结构,它类似于 NumPy 的数组,但可以在 GPU 上高效运行。除了创建 Tensor,PyTorch 还提供了丰富的操作方法,包括 Tensor 与 NumPy 数组的转换、基本运算、维度变换等。今天,我们就通过一个简单的代码示例,学习这些基本操作。

1. Tensor 与 NumPy 数组的转换

PyTorch 提供了非常方便的接口,用于在 Tensor 和 NumPy 数组之间进行转换。这在实际应用中非常有用,因为 NumPy 是 Python 中处理数组的标准库。

1.1 Tensor 转换为 NumPy 数组

t1 = torch.tensor([1, 2, 3, 4, 5])
n1 = t1.numpy()
print(n1)
  • t1.numpy():将 Tensor 转换为 NumPy 数组。注意,这种转换是浅拷贝,即 NumPy 数组和 Tensor 共享内存。

1.2 NumPy 数组转换为 Tensor

t2 = torch.tensor(n1)
print(t2)
  • torch.tensor(n1):将 NumPy 数组转换为 Tensor。这种转换是深拷贝,即生成一个新的 Tensor,不共享内存。

t3 = torch.from_numpy(n1)
print(t3)
  • torch.from_numpy(n1):将 NumPy 数组转换为 Tensor。这种转换是浅拷贝,即 Tensor 和 NumPy 数组共享内存。

1.3 获取单个元素的值

t4 = torch.tensor([18])
print(t4.item())
  • t4.item():当 Tensor 只有一个元素时,可以使用 item() 获取该元素的值。

2. Tensor 的基本运算

PyTorch 提供了丰富的运算操作,包括加法、减法、乘法和除法。这些运算可以分为两类:生成新 Tensor 的操作和覆盖原 Tensor 的操作。

2.1 生成新 Tensor 的运算

t1 = torch.randint(1, 10, (3, 2))
print(t1.add(1))
  • t1.add(1):对 t1 的每个元素加 1,结果生成一个新的 Tensor。

2.2 覆盖原 Tensor 的运算

print(t1.add_(1))
  • t1.add_(1):对 t1 的每个元素加 1,结果覆盖原 Tensor。

2.3 阿达玛积(逐元素乘法)

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = t1 * t2
print(t3)
  • t1 * t2:逐元素乘法,即对应位置的元素相乘。

2.4 矩阵乘法

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.matmul(t1, t2)
print(t3)
  • torch.matmul(t1, t2):矩阵乘法,符合矩阵乘法的规则。

3. Tensor 的形状变换

在深度学习中,经常需要对 Tensor 的形状进行变换,例如在卷积神经网络中调整输入数据的维度。PyTorch 提供了 view()reshape() 方法来实现这一点。

3.1 view() 方法

t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = t1.view(3, 2)
print(t2)
  • t1.view(3, 2):将 Tensor 的形状从 (2, 3) 变为 (3, 2)。注意,view() 要求 Tensor 的内存是连续的。

3.2 reshape() 方法

t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t3 = t1.reshape(3, 2)
print(t3)
  • t1.reshape(3, 2):与 view() 类似,但 reshape() 不要求内存是连续的。

4. 维度变换

在处理多维数据时,经常需要对 Tensor 的维度进行变换,例如在处理图像数据时交换通道维度。

4.1 transpose() 方法

t1 = torch.randint(1, 20, (3, 4, 5))
t2 = torch.transpose(t1, 0, 1)
print(t2)
  • torch.transpose(t1, 0, 1):交换 Tensor 的第 0 维和第 1 维。

4.2 permute() 方法

t3 = t1.permute(1, 0, 2)
print(t3)
  • t1.permute(1, 0, 2):可以同时交换多个维度,非常灵活。

5. 完整代码示例

import torchdef test01():t1 = torch.tensor([1,2,3,4,5])# numpy():将tensor转换为numpy数组,浅拷贝:如果要深拷贝,需要使用copy()# tensor():将numpy数组转换为tensor,深拷贝# from_numpy():将numpy数组转换为tensor,浅拷贝n1 = t1.numpy()print(n1)t2 = torch.tensor(n1)print(t2)t3 = torch.from_numpy(n1)print(t3)# item():当tensor只有一个元素时,使用item()获取该元素的值# t4 = torch.tensor(18)t4 = torch.tensor([18])print(t4)print(t4.item())# t5 = torch.tensor([18],device='cuda')# print(t5.item())def test02():torch.manual_seed(0)# tensor运算# add, sub , mul, div等,计算结果会生成新的tensor# add_, sub_, mul_, div_等,计算结果会覆盖原来的tensort1 = torch.randint(1 , 10, (3, 2))print(t1)print(t1.add(1))print(t1)print(t1.add_(1))print(t1)'''
阿达码积:两个矩阵对应位置相乘,得到一个新的矩阵
Cij = Aij * Bij
运算符号: mul或者*
矩阵运算:(m,p) * (p,n) = (m,n)
'''
def test03():t1 = torch.tensor([1,2],[3,4])t2 = torch.tensor([5,6],[7,8])t3 = t1 * t2print(t3)'''
view():改变tensor的形状,不改变tensor的数据,内存是连续的
reshape():改变tensor的形状,不改变tensor的数据,内存不连续
'''def test04():t1 = torch.tensor([1,2,3],[4,5,6])print(t1.is_contiguous())t2 = t1.view(3, 2)print(t2.is_contiguous())t3 = t1.t()print(t3)print(t3.is_contiguous())t4 = t3.view(2, 3)print(t4.is_contiguous())'''
维度变换
transpose():转置,交换张量的两个维度, 只能交换两个维度
permute(input,dims):维度变换,可以交换多个维度
'''
def test05():t1 = torch.randint(1, 20, (3, 4, 5))print(t1)t2 = torch.transpose(t1, 0, 1)print(t2)print(t2.is_contiguous())t3 = t1.permute(t1, (1, 0, 2))print(t3)print(t3.shape)if __name__ == '__main__':# test01()# test02()# test03()# test04()test05()

6. 总结

通过这篇文章,我们学习了 PyTorch 中 Tensor 的基本操作,包括:

  • 如何在 Tensor 和 NumPy 数组之间进行转换。

  • 如何进行基本的数学运算。

  • 如何改变 Tensor 的形状。

  • 如何对 Tensor 的维度进行变换。

这些操作是深度学习的基础,希望这篇文章能帮助你更好地理解和使用 PyTorch!

http://www.lryc.cn/news/584041.html

相关文章:

  • pytorch 神经网络
  • PyTorch自动微分:从基础到实战
  • 【Pandas】pandas DataFrame from_records
  • 【PyTorch】PyTorch中的数据预处理操作
  • 杰赛S65_中星微ZX296716免拆刷机教程解决网络错误和时钟问题
  • RocketMQ安装(Windows环境)
  • 零成本实现商品图换背景
  • 特征筛选步骤
  • 计算机视觉 之 数字图像处理基础
  • NAT技术(网络地址转换)
  • IPv4和IPv6双栈配置
  • CRT 不同会导致 fopen 地址不同
  • 飞书AI技术体系
  • Java 正则表达式白皮书:语法详解、工程实践与常用表达式库
  • OSPF协议:核心概念与配置要点解析
  • 栈题解——有效的括号【LeetCode】两种方法
  • ACL协议:核心概念与配置要点解析
  • LlamaFactory Demo
  • 强缓存和协商缓存详解
  • SQL进阶:自连接的用法
  • 深度探索:实时交互与增强现实翻译技术(第六篇)
  • 【郑大二年级信安小学期】Day9:XSS跨站攻击XSS绕过CSRF漏洞SSRF漏洞
  • 医院多部门协同构建知识库-指南库-预测模型三维网络路径研究
  • 【C++】第十四节—模版进阶(非类型模版参数+模板的特化+模版分离编译+模版总结)
  • OSPF实验以及核心原理全解
  • vue引入应用通义AI大模型-(一)前期准备整理思路
  • Vue+Element Plus 中按回车刷新页面问题排查与解决
  • Scala实现网页数据采集示例
  • AI 智能体:开启自动化协作新时代
  • 2025.07.09华为机考真题解析-第三题300分