当前位置: 首页 > news >正文

PyTorch中张量(TensorFlow)操作方法和属性汇总详解和代码示例

1、张量的操作汇总

下面是 PyTorch 中常见的 张量操作方法汇总,包括 创建、索引、变换、数学运算、广播机制、维度操作 等内容,并附上详解和代码示例,便于系统学习与实战参考。


一、张量创建(torch.tensor 等)

import torch# 标量(0维张量)
a = torch.tensor(3.14)# 1维张量
b = torch.tensor([1.0, 2.0, 3.0])# 2维张量
c = torch.tensor([[1, 2], [3, 4]])# 全 0 / 全 1
zeros = torch.zeros((2, 3))
ones = torch.ones((3, 3))# 均匀分布随机数
rand = torch.rand((2, 2))# 正态分布
normal = torch.randn(2, 3)# 固定范围整数
int_rand = torch.randint(0, 10, (3, 3))

二、张量属性(维度、形状、类型)

x = torch.rand(2, 3)print(x.shape)    # torch.Size([2, 3])
print(x.dtype)    # torch.float32
print(x.ndim)     # 2
print(x.size())   # same as x.shape

三、索引与切片(Indexing & Slicing)

x = torch.tensor([[1, 2, 3], [4, 5, 6]])print(x[0])       # 第0行 [1, 2, 3]
print(x[:, 1])    # 所有行的第1列 [2, 5]
print(x[1, 2])    # 第1行第2列 -> 6# 修改元素
x[0, 1] = 10

四、形状变换(Reshape)

x = torch.arange(12)     # [0, 1, ..., 11]
x = x.view(3, 4)         # reshape 为 (3,4)x = x.reshape(2, 6)      # reshape 新形状
x = x.flatten()          # 转为 1D
x = x.unsqueeze(0)       # 增加一维 (batch-like)
x = x.squeeze()          # 去除多余维度

五、数学运算

元素级运算(+,-,*,/)

a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])print(a + b)
print(a * b)
print(torch.exp(a))
print(torch.sqrt(b))

矩阵运算

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])# 点积/矩阵乘法
print(torch.matmul(A, B))   # 或 A @ B

六、广播机制(Broadcasting)

a = torch.tensor([[1], [2], [3]])   # shape (3,1)
b = torch.tensor([10, 20, 30])      # shape (3,)# 广播加法 => shape (3,3)
print(a + b)

七、维度操作:拼接、拆分、转置

a = torch.ones((2, 3))
b = torch.zeros((2, 3))# 拼接
cat = torch.cat([a, b], dim=0)  # shape (4, 3)# 拆分
split = torch.split(cat, 2, dim=0)  # 拆成两个 (2,3)# 转置
x = torch.rand(2, 3)
print(x.T)   # shape (3,2)

八、常用统计操作

x = torch.tensor([[1., 2.], [3., 4.]])print(x.sum())           # 所有元素求和
print(x.mean())          # 平均值
print(x.max(), x.min())  # 最大最小值
print(x.argmax(), x.argmin())  # 最大/最小索引

九、条件与掩码操作(布尔索引)

x = torch.tensor([1, 2, 3, 4, 5])
mask = x > 3
print(x[mask])  # 输出 [4, 5]

十、张量复制与共享

a = torch.tensor([1, 2, 3])
b = a.clone()     # 拷贝,不共享内存
c = a             # 共享内存

总结思维导图(简要)

张量操作:
├── 创建(tensor, zeros, randn, randint)
├── 属性(shape, dtype, ndim)
├── 索引(x[i], x[:, j])
├── 运算(+ - * /, matmul, exp, log)
├── 广播(自动扩展维度)
├── 变换(view, reshape, squeeze, unsqueeze)
├── 拼接/分割(cat, stack, split, chunk)
├── 统计分析(sum, mean, max, argmax)
├── 条件掩码(mask, where)
└── 复制(clone, detach)

2、张量属性汇总

PyTorch 中张量(Tensor)是数据的基本结构。每个张量都具有一组属性(Attributes)来描述它的维度、数据类型、设备、存储结构等信息。


张量的核心属性汇总

属性名含义说明示例值
.shape / .size()张量的形状(各维度大小)torch.Size([3, 4])
.ndim张量的维度数(维度的数量)2
.dtype数据类型torch.float32
.device存储设备:CPU 或 GPUcpu, cuda:0
.requires_grad是否记录梯度(用于自动求导)True/False
.is_leaf是否是叶子节点(可用于反向传播分析)True/False
.grad梯度值(反向传播后赋值)None 或张量
.data原始数据(无梯度跟踪)张量数据
.T转置(仅适用于二维张量)tensor.T
.storage()存储对象(高级属性)FloatStorage

示例代码:张量属性全面演示

import torch# 创建张量
a = torch.randn(3, 4, dtype=torch.float32, requires_grad=True)print("张量 a:\n", a)
print("\n== 张量属性 ==")
print("形状 shape:", a.shape)
print("维度 ndim:", a.ndim)
print("数据类型 dtype:", a.dtype)
print("所在设备 device:", a.device)
print("是否需要梯度 requires_grad:", a.requires_grad)
print("是否为叶子节点 is_leaf:", a.is_leaf)
print("梯度 grad:", a.grad)
print("原始数据 data:\n", a.data)
print("转置 T:\n", a.T)
print("存储 storage:", a.storage())

进阶示例:查看 GPU 张量属性 + 修改属性

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")b = torch.ones((2, 3), dtype=torch.float64, device=device)
print("\n张量 b 属性:")
print("b.shape:", b.shape)
print("b.dtype:", b.dtype)
print("b.device:", b.device)# 修改数据类型和设备
b2 = b.to(dtype=torch.float32, device='cpu')
print("\n修改后:")
print("b2.dtype:", b2.dtype)
print("b2.device:", b2.device)

示例:grad 与 requires_grad 的关系

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2
z = y.sum()print("z:", z)
z.backward()
print("x.grad:", x.grad)  # dz/dx = 2x

Tips

  • .shape.size() 等价,都返回 torch.Size 对象。
  • .is_leaf 为 True 的张量通常是直接创建的变量,反向传播时只保留其梯度。
  • .data 会返回张量本身的数据,但不参与自动求导

http://www.lryc.cn/news/590018.html

相关文章:

  • Postman接口
  • 【开源.NET】一个 .NET 开源美观、灵活易用、功能强大的图表库
  • GraphQL与REST在微服务接口设计中的对比分析与实践
  • Nacos 开源 MCP Router,加速 MCP 私有化部署
  • Linux开发利器:探秘开源,构建高效——基础开发工具指南(上)【包管理器/Vim】
  • 【Fastapi】Token验证与Postman模拟测试
  • HTTP REST API、WebSocket、 gRPC 和 GraphQL 应用场景和底层实现
  • IPv6
  • JavaScript进阶篇——第六章 内置构造函数与内置方法
  • qt 中英文翻译 如何配置和使用
  • AR智能巡检:电力行业数字化转型的“加速器”
  • 二分查找法
  • 力扣面试150(31/150)
  • 坐标系和相机标定介绍,张正友标定法原理,opencv标定
  • C++:现代 C++ 编程基石,C++11核心特性解析与实践
  • NLP:LSTM和GRU分享
  • NO.6数据结构树|二叉树|满二叉树|完全二叉树|顺序存储|链式存储|先序|中序|后序|层序遍历
  • 从零开始的云计算生活——番外4,使用 Keepalived 实现 MySQL 高可用
  • PyTorch 损失函数详解:从理论到实践
  • 《通信原理》学习笔记——第二章
  • Qt小组件 - 7 SQL Thread Qt访问数据库ORM
  • qt udp接收时 丢包
  • FreeRTOS学习笔记之任务调度
  • 《机器学习数学基础》补充资料:标准差与标准化
  • 《Qt信号与槽机制》详解:从基础到实践
  • Qt中实现文件(文本文件)内容对比
  • 若依框架下前后端分离项目交互流程详解
  • ScratchCard刮刮卡交互元素的实现
  • MR 处于 WIP 状态的WIP是什么
  • Django+Celery 进阶:Celery可视化监控与排错