当前位置: 首页 > article >正文

Pytorch中一些重要的经典操作和简单讲解

Pytorch中一些重要的经典操作和简单讲解

形状变换操作

reshape() / view()

import torchx = torch.randn(2, 3, 4)
print(f"原始形状: {x.shape}")# reshape可以处理非连续张量
y = x.reshape(6, 4)
print(f"reshape后: {y.shape}")# view要求张量在内存中连续
z = x.view(2, 12)
print(f"view后: {z.shape}")

transpose() / permute()

# transpose交换两个维度
x = torch.randn(2, 3, 4)
y = x.transpose(0, 2)  # 交换第0和第2维
print(f"transpose后: {y.shape}")  # torch.Size([4, 3, 2])# permute重新排列所有维度
z = x.permute(2, 0, 1)  # 将维度重排为 (4, 2, 3)
print(f"permute后: {z.shape}")

拼接和分割操作

cat() / stack()

# cat在现有维度上拼接
x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)# 在第0维拼接
cat_dim0 = torch.cat([x1, x2], dim=0)  # (4, 3)
# 在第1维拼接
cat_dim1 = torch.cat([x1, x2], dim=1)  # (2, 6)# stack创建新维度并拼接
stacked = torch.stack([x1, x2], dim=0)  # (2, 2, 3)

chunk() / split()

x = torch.randn(6, 4)# chunk均匀分割
chunks = torch.chunk(x, 3, dim=0)  # 分成3块,每块(2, 4)# split按指定大小分割
splits = torch.split(x, 2, dim=0)  # 每块大小为2
splits_uneven = torch.split(x, [1, 2, 3], dim=0)  # 不均匀分割

索引和选择操作

gather() / scatter()

# gather根据索引收集元素
x = torch.randn(3, 4)
indices = torch.tensor([[0, 1], [2, 3], [1, 0]])
gathered = torch.gather(x, 1, indices)  # (3, 2)# scatter根据索引分散元素
src = torch.randn(3, 2)
scattered = torch.zeros(3, 4).scatter_(1, indices, src)

masked_select() / where()

x = torch.randn(3, 4)
mask = x > 0# 选择满足条件的元素
selected = torch.masked_select(x, mask)# 条件选择
y = torch.randn(3, 4)
result = torch.where(mask, x, y)  # mask为True选x,否则选y

数学运算操作

clamp() / clip()

x = torch.randn(3, 4)# 限制数值范围
clamped = torch.clamp(x, min=-1, max=1)
# 等价于
clipped = torch.clip(x, -1, 1)

norm() / normalize()

x = torch.randn(3, 4)# 计算范数
l2_norm = torch.norm(x, p=2, dim=1)  # L2范数
l1_norm = torch.norm(x, p=1, dim=1)  # L1范数# 归一化
normalized = torch.nn.functional.normalize(x, p=2, dim=1)

统计运算操作

mean() / sum() / std()

x = torch.randn(3, 4, 5)# 各种统计量
mean_all = x.mean()  # 全局均值
mean_dim = x.mean(dim=1)  # 沿第1维求均值
sum_keepdim = x.sum(dim=1, keepdim=True)  # 保持维度# 最值操作
max_val, max_idx = torch.max(x, dim=1)
min_val, min_idx = torch.min(x, dim=1)

广播和重复操作

expand() / repeat()

x = torch.randn(1, 3)# expand不复制数据,只是改变视图
expanded = x.expand(4, 3)  # (4, 3)# repeat实际复制数据
repeated = x.repeat(4, 2)  # (4, 6)

tile() / repeat_interleave()

x = torch.tensor([1, 2, 3])# tile像numpy的tile
tiled = x.tile(2, 3)  # 重复2次每行,3次每列# repeat_interleave每个元素重复
interleaved = x.repeat_interleave(2)  # [1, 1, 2, 2, 3, 3]

类型转换操作

to() / type() / cast()

x = torch.randn(3, 4)# 类型转换
x_int = x.to(torch.int32)
x_float = x.type(torch.float64)
x_cuda = x.to('cuda')  # 移到GPU(如果可用)# 设备转换
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x_device = x.to(device)

在深度学习领域,这类张量运算操作具有极高的应用频率,尤其在数据预处理、模型架构构建及推理后处理等关键环节中不可或缺。熟练掌握此类算子的应用逻辑,能够显著优化张量数据的处理流程,提升深度学习任务的执行效率与工程实现效能。

http://www.lryc.cn/news/2397828.html

相关文章:

  • 【容器docker】启动容器kibana报错:“message“:“Error: Cannot find module ‘./logs‘
  • 基于bp神经网络的adp算法
  • C#里与嵌入式系统W5500网络通讯(4)
  • Spring boot集成milvus(spring ai)
  • Visual Studio+SQL Server数据挖掘
  • maven项目编译时复制xml到classes目录方案
  • 通过阿里云服务发送邮件
  • Vad-R1:通过从感知到认知的思维链进行视频异常推理
  • 黑马Java面试笔记之MySQL篇(事务)
  • 群辉(synology)NAS老机器连接出现网页端可以进入,但是本地访问输入一样的账号密码是出现错误时解决方案
  • C++多重继承详解与实战解析
  • 【深度学习】实验四 卷积神经网络CNN
  • 实现一个免费可用的文生图的MCP Server
  • 无公网ip远程桌面连接不了怎么办?内网计算机让外网访问方法和问题分析
  • 【手搓一个原生全局loading组件解决页面闪烁问题】
  • CSS基础巩固-基础-选择
  • 一种在SQL Server中传递多行数据的方法
  • 【Docker 从入门到实战全攻略(一):核心概念 + 命令详解 + 部署案例】
  • github 提交失败,连接不上
  • 系统架构设计师(一):计算机系统基础知识
  • VMware安装Ubuntu全攻略
  • 清理 pycharm 无效解释器
  • 精益数据分析(92/126):指标基准化——如何判断你的数据表现是否足够优秀
  • 手机如何压缩文件为 RAR 格式:详细教程与工具推荐
  • Elasticsearch集群管理的相关工具介绍
  • 基于多尺度卷积和扩张卷积-LSTM的多变量时间序列预测
  • Java 注解式限流教程(使用 Redis + AOP)
  • C# XAML 基础:构建现代 Windows 应用程序的 UI 语言
  • Linux运维笔记:服务器感染 netools 病毒案例
  • (面试)获取View宽高的几种方式