当前位置: 首页 > news >正文

tensor

😉如果您想用jupyter notebook跑我的笔记,可以在下面获取ipynb版本
😊麻烦给个免费的star😘
❤️主包也更建议这种形式,上面的笔记也更加全面,每一步都有直观的输出

文章目录

  • 📚 PyTorch张量操作与自动微分全面指南
    • 🛠 1. 准备工作
      • 1.1 导入PyTorch并检查环境
      • 1.2 设置计算设备
    • 📦 1.1 张量初始化
      • 1.1.1 从Python数据类型创建
      • 1.1.2 数据类型处理
      • 1.1.3 从NumPy数组创建
      • 1.1.4 特殊张量创建
      • 1.1.5 张量属性查看
      • 1.1.6 张量设备转移
    • ➗ 1.2 张量运算
      • 1.2.1 基本运算
      • 1.2.2 张量转换
      • 1.2.3 形状操作
    • 🔁 1.3 张量自动微分
      • 1.3.1 梯度计算核心概念
      • 1.3.2 梯度计算实战
      • 1.3.3 梯度控制技巧
    • 💎 核心要点总结

📚 PyTorch张量操作与自动微分全面指南

PyTorch作为深度学习领域的主流框架,掌握其核心数据结构张量(Tensor)的操作至关重要。本文将全面解析PyTorch张量的创建、运算和自动微分机制,助你快速上手PyTorch开发!🚀

🛠 1. 准备工作

1.1 导入PyTorch并检查环境

# 导入torch库
import torch# 查看torch版本
print(torch.__version__)  # 输出: 2.6.0+cu124# 检查CUDA是否可用
torch.cuda.is_available()  # 输出: True

1.2 设置计算设备

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)  # 输出: cuda

📦 1.1 张量初始化

1.1.1 从Python数据类型创建

data = [[1,2],[3,4]]
x = torch.tensor(data)
print(x)
# 输出: 
# tensor([[1, 2],
#         [3, 4]])

1.1.2 数据类型处理

# 查看数据类型
print(x.dtype)  # 输出: torch.int64# 指定数据类型创建
x1 = torch.FloatTensor([[1,2],[3,4]])  # 等价于 torch.tensor(..., dtype=torch.float32)
x2 = torch.LongTensor([[1,2],[3,4]])# 强制类型转换
x1 = x1.type(torch.long)  # 转换为int64
x1 = x1.float()           # 转换为float32

💡 关键提示torch.float32torch.int64是最重要的数据类型!模型输入通常是float32,分类标签通常是int64

1.1.3 从NumPy数组创建

import numpy as np
data = np.array([[1,2],[4,3]])
x3 = torch.from_numpy(data)
print(x3)
# 输出:
# tensor([[1, 2],
#         [4, 3]])

1.1.4 特殊张量创建

# 创建随机张量
rand_tensor = torch.rand(2,3)    # 0~1均匀分布
randn_tensor = torch.randn(2,3)  # 标准正态分布# 创建全1/全0张量
ones_tensor = torch.ones(2,3)
zeros_tensor = torch.zeros(2,3)# 继承形状创建
y1 = torch.rand_like(x.float())

1.1.5 张量属性查看

t = x
print(t.shape)      # 形状: torch.Size([2, 2])
print(t.size())     # 大小: torch.Size([2, 2])
print(t.dtype)      # 数据类型: torch.int64
print(t.device)     # 存储设备: cpu

1.1.6 张量设备转移

x_gpu = x.to(device)  # 转移到GPU
print(x_gpu.device)   # 输出: cuda:0

➗ 1.2 张量运算

1.2.1 基本运算

# 就地加法 (会改变原张量)
t1 = torch.randn(2,3)
t2 = torch.ones(2,3)
t2.add_(t1)# 矩阵乘法
result1 = t1 @ t2.T
result2 = t1.matmul(t2.T)

1.2.2 张量转换

# 单元素张量转标量
t3 = t1.sum()
scalar = t3.item()  # 输出: 2.1050655841827393# 张量转NumPy数组
numpy_array = t1.numpy()

1.2.3 形状操作

t = torch.randn(4,6)# 重塑形状
t1 = t.view(3,8)     # 显式指定形状
t2 = t.view(-1,1)    # 自动推断维度
t3 = t.view(1,4,6)   # 增加维度# 维度压缩与扩展
t4 = torch.ones(1,4,6)
t5 = torch.squeeze(t4)    # 移除大小为1的维度
t6 = torch.unsqueeze(t5,0)  # 在指定位置添加维度

🔁 1.3 张量自动微分

1.3.1 梯度计算核心概念

PyTorch的自动微分系统基于以下三个关键属性:

  • requires_grad:是否跟踪梯度(默认为False)
  • grad:存储计算得到的梯度
  • grad_fn:指向生成此张量的运算方法

1.3.2 梯度计算实战

# 创建需要跟踪梯度的张量
t = torch.ones(2,2, requires_grad=True)
print(t.requires_grad)  # 输出: True# 构建计算图
y = t + 5
x = y * 2
out = x.mean()# 反向传播计算梯度
out.backward()# 查看梯度
print(t.grad)
# 输出:
# tensor([[0.5000, 0.5000],
#         [0.5000, 0.5000]])

1.3.3 梯度控制技巧

# 临时禁用梯度跟踪
with torch.no_grad():print((t + 2).requires_grad)  # 输出: False# 获取不跟踪梯度的张量副本
detached_t = t.detach()
print(detached_t.requires_grad)  # 输出: False# 永久关闭梯度跟踪
t.requires_grad_(False)
print(t.requires_grad)  # 输出: False

💎 核心要点总结

  1. 张量创建

    • 使用torch.tensor()从Python数据创建
    • 使用torch.from_numpy()从NumPy数组转换
    • 掌握torch.rand(), torch.randn(), torch.ones(), torch.zeros()等创建方法
  2. 数据类型管理

    • 重点掌握float32(模型输入)和int64(分类标签)
    • 使用.float().long()快速转换类型
  3. 设备转移

    • 使用.to(device)在CPU/GPU间移动张量
    • 始终检查tensor.device确保张量位置正确
  4. 自动微分

    • 设置requires_grad=True启用梯度跟踪
    • 使用.backward()自动计算梯度
    • 使用with torch.no_grad():上下文管理器禁用梯度计算
  5. 性能优化

    • 尽量使用就地操作(如add_())减少内存开销
    • 合理使用view()进行形状重塑
    • 及时使用detach()分离不需要的计算图

掌握这些PyTorch张量核心操作,你已为深度学习模型开发打下坚实基础!🎯 下一步可以探索神经网络模块和优化器的使用,开启模型训练之旅!

http://www.lryc.cn/news/585386.html

相关文章:

  • Word表格默认格式修改成三线表,一劳永逸,提高生产力!
  • 上位机知识篇---高效下载安装方法
  • 05 rk3568 debian11 root用户 声音服务PulseAudio不正常
  • PyTorch 与 Spring AI 集成实战
  • 2025Nginx最新版讲解/面试
  • 【yolo】模型训练参数解读
  • 七、gateway服务创建
  • WPS、Word加载项开发流程(免费最简版本)
  • [Meetily后端框架] 多模型-Pydantic AI 代理-统一抽象 | SQLite管理
  • VLLM部署DeepSeek-LLM-7B-Chat 模型
  • Lecture #19 : Multi-Version Concurrency Control
  • Jenkins 版本升级与插件问题深度复盘:从 2.443 到 2.504.3 及功能恢复全解析
  • FPGA实现SDI转LVDS视频发送,基于GTX+OSERDES2原语架构,提供2套工程源码和技术支持
  • Java进阶---并发编程
  • 【C/C++ shared_ptr 和 unique_ptr可以互换吗?】
  • 【AI News | 20250710】每日AI进展
  • 一个中层管理者应该看什么书籍?
  • 使用Python将目录中的JPG图片按后缀数字从小到大顺序纵向拼接,很适合老师发的零散图片拼接一个图片
  • 谷歌独立站是什么?谷歌独立站建站引流完全指南
  • HarmonyOS基础概念
  • Python中类静态方法:@classmethod/@staticmethod详解和实战示例
  • C#中的设计模式:构建更加优雅的代码
  • 链接代理后无法访问网络
  • C++入门基础篇(二)
  • HandyJSON使用详情
  • 使用Spring Boot和PageHelper实现数据分页
  • Excel快捷键
  • 20250710-2-Kubernetes 集群部署、配置和验证-网络组件存在的意义?_笔记
  • leetcode:377. 组合总和 Ⅳ[完全背包]
  • 代账行业数字化破局:从“知道”到“做到”,三步走稳赢!