当前位置: 首页 > news >正文

【深度学习】Pytorch 教程(十二):PyTorch数据结构:4、张量操作(3):张量修改操作(拆分、拓展、修改)

文章目录

  • 一、前言
  • 二、实验环境
  • 三、PyTorch数据结构
    • 1、Tensor(张量)
      • 1. 维度(Dimensions)
      • 2. 数据类型(Data Types)
      • 3. GPU加速(GPU Acceleration)
    • 2、张量的数学运算
      • 1. 向量运算
      • 2. 矩阵运算
      • 3. 向量范数、矩阵范数、与谱半径详解
      • 4. 一维卷积运算
      • 5. 二维卷积运算
      • 6. 高维张量
    • 3、张量的统计计算
    • 4、张量操作
      • 1. 张量变形
      • 2. 索引
      • 3. 切片
      • 4. 张量修改
        • a. 张量拆分
          • split
          • unbind
          • chunk
        • b. 张量扩展
          • repeat
          • cat
          • stack
        • c. 张量修改
          • 使用索引和切片进行修改
          • gather
          • scatter

一、前言

  本文将介绍PyTorch中张量的拆分(split、unbind、chunk)、拓展(repeat、cat、stack)、修改操作(使用索引和切片、gather、scatter)

二、实验环境

  本系列实验使用如下环境

conda create -n DL python==3.11
conda activate DL
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

三、PyTorch数据结构

1、Tensor(张量)

  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。

1. 维度(Dimensions)

  Tensor(张量)的维度(Dimensions)是指张量的轴数或阶数。在PyTorch中,可以使用size()方法获取张量的维度信息,使用dim()方法获取张量的轴数。

在这里插入图片描述

2. 数据类型(Data Types)

  PyTorch中的张量可以具有不同的数据类型:

  • torch.float32或torch.float:32位浮点数张量。
  • torch.float64或torch.double:64位浮点数张量。
  • torch.float16或torch.half:16位浮点数张量。
  • torch.int8:8位整数张量。
  • torch.int16或torch.short:16位整数张量。
  • torch.int32或torch.int:32位整数张量。
  • torch.int64或torch.long:64位整数张量。
  • torch.bool:布尔张量,存储True或False。

【深度学习】Pytorch 系列教程(一):PyTorch数据结构:1、Tensor(张量)及其维度(Dimensions)、数据类型(Data Types)

3. GPU加速(GPU Acceleration)

【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)

2、张量的数学运算

  PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。

1. 向量运算

【深度学习】Pytorch 系列教程(三):PyTorch数据结构:2、张量的数学运算(1):向量运算(加减乘除、数乘、内积、外积、范数、广播机制)

2. 矩阵运算

【深度学习】Pytorch 系列教程(四):PyTorch数据结构:2、张量的数学运算(2):矩阵运算及其数学原理(基础运算、转置、行列式、迹、伴随矩阵、逆、特征值和特征向量)

3. 向量范数、矩阵范数、与谱半径详解

【深度学习】Pytorch 系列教程(五):PyTorch数据结构:2、张量的数学运算(3):向量范数(0、1、2、p、无穷)、矩阵范数(弗罗贝尼乌斯、列和、行和、谱范数、核范数)与谱半径详解

4. 一维卷积运算

【深度学习】Pytorch 系列教程(六):PyTorch数据结构:2、张量的数学运算(4):一维卷积及其数学原理(步长stride、零填充pad;宽卷积、窄卷积、等宽卷积;卷积运算与互相关运算)

5. 二维卷积运算

【深度学习】Pytorch 系列教程(七):PyTorch数据结构:2、张量的数学运算(5):二维卷积及其数学原理

6. 高维张量

【深度学习】pytorch教程(八):PyTorch数据结构:2、张量的数学运算(6):高维张量:乘法、卷积(conv2d~ 四维张量;conv3d~五维张量)

3、张量的统计计算

【深度学习】Pytorch教程(九):PyTorch数据结构:3、张量的统计计算详解

4、张量操作

1. 张量变形

【深度学习】Pytorch教程(十):PyTorch数据结构:4、张量操作(1):张量变形操作

2. 索引

3. 切片

【深度学习】Pytorch 教程(十一):PyTorch数据结构:4、张量操作(2):索引和切片操作

4. 张量修改

a. 张量拆分
split

  沿指定维度将张量拆分为多个张量

import torch# 创建一个张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y1, y2 = x.split(2, dim=1)
print(y1)  
print(y2)  
unbind

  沿指定维度对张量进行拆分,返回拆分后的张量列表

import torchx = torch.tensor([[1, 2, 3], [4, 5, 6]])y1, y2 = x.unbind(dim=0)
print(y1) 
print(y2)  

在这里插入图片描述

chunk

  沿指定维度将张量均匀分割为多个张量

import torch# 创建一个张量
x = torch.tensor([[1, 2, 3, 4, 5, 6]])# 沿指定维度均匀分割为多个张量
y = x.chunk(3, dim=1)
for chunk in y:print(chunk) 

在这里插入图片描述

b. 张量扩展
repeat

  复制张量中的元素进行重复操作

import torchx = torch.tensor([[1, 2, 3], [4, 5, 6]])# 重复操作
y = x.repeat(1, 2)
print(y)
z = x.repeat(2, 2)
print(z)

在这里插入图片描述

cat

  沿指定维度对多个张量进行拼接

import torchx1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]])# 在指定维度上进行拼接
y = torch.cat((x1, x2), dim=0)
print(y)  
stack

  沿新的维度对多个张量进行堆叠

import torch# 创建两个张量
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]])# 在新维度上进行堆叠
y = torch.stack((x1, x2), dim=0)
print(y)  

在这里插入图片描述

c. 张量修改
使用索引和切片进行修改

  可以使用索引和切片操作来修改张量中的特定元素或子集

import torchx = torch.tensor([[1, 2, 3], [4, 5, 6]])
x[0, 1] = 9  # 修改第0行、第1列的元素为9
print(x)
  • 输出:
tensor([[1, 9, 3],[4, 5, 6]])
gather

  按指定索引从输入张量中收集指定维度的值

import torchx = torch.tensor([[1, 2, 3], [4, 5, 6]])# 按索引收集值
indices = torch.tensor([[0, 0, 1], [1, 0, 0]])
y = torch.gather(x, 1, indices)
print(y) 
tensor([[1, 1, 2],[5, 4, 4]])
scatter

  将值按指定索引散射到新张量中

import torchx = torch.zeros(2, 4)# 按索引散射值
indices = torch.tensor([[0, 1], [2, 3]])
values = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
y = x.scatter(1, indices, values)
print(y)
tensor([[1., 2., 0., 0.],[0., 0., 3., 4.]])
http://www.lryc.cn/news/306090.html

相关文章:

  • 适合新手博主站长使用的免费响应式WordPress博客主题JianYue
  • FPGA OSERDESE2
  • 如何卸载Erlang以及RabbitMQ
  • ros自定义action记录
  • 挑战30天学完Python:Day18 正则表达式
  • 力扣● 343. 整数拆分 ● 96.不同的二叉搜索树
  • 游戏同步+游戏中的网络模块
  • 【03】逆序数组
  • 基于Prony算法的系统参数辨识matlab仿真
  • 创建第一个React项目
  • Redis篇之Redis持久化的实现
  • dpdk环境搭建和工作原理
  • 接口测试实战--自动化测试流程
  • babylonjs中文文档
  • WordPress使用
  • IDEA 2021.3激活
  • 进度条小程序
  • K8S安装部署
  • AI大模型与小模型之间的“脱胎”与“反哺”(第一篇)
  • C#学习总结
  • 计算机网络-网络互联
  • 免费的ChatGPT网站( 7个 )
  • Opencv3.2 ubuntu20.04安装过程
  • OpenGL ES (OpenGL) Compute Shader 计算着色器是怎么用的?
  • Python爬虫进阶:爬取在线电视剧信息与高级检索
  • Floor报错原理详解+sql唯一约束性
  • Arduino中安装ESP32网络抽风无法下载 暴力解决办法 python
  • Linux基础命令—系统服务
  • qt-动画圆圈等待-LED数字
  • SpringBoot3整合Swagger3,访问出现404错误问题(未解决)