当前位置: 首页 > article >正文

pytorch小记(十九):深入理解 PyTorch 的 `torch.randint()` 与 `.long()` 转换

pytorch小记(十九):深入理解 PyTorch 的 `torch.randint` 与 `.long` 转换

    • 一、`torch.randint()` 基本概念
      • 示例:生成一个二维随机整型张量
    • 二、为什么需要调用 `.long()`
    • 三、典型场景示例
      • 1. 随机索引采样
      • 2. 伪标签生成
      • 3. 直接在 GPU 上生成 LongTensor
    • 四、`.long()` 的几种等价写法
    • 五、小结


在使用 PyTorch 进行深度学习建模或数据处理时,常常需要生成随机整数张量作为索引、伪标签或其它用途。本文将深入讲解 PyTorch 中的 torch.randint() 函数,以及为什么/如何结合 .long() 方法将张量转换为 64 位整型(LongTensor)。文末还会给出多种典型场景的实战示例,帮助你在项目中快速上手。


一、torch.randint() 基本概念

torch.randint() 用来在指定范围内均匀随机生成整数张量。它的函数签名如下:

torch.randint(low: int = 0,high: int,size: Tuple[int, ...],*,dtype: torch.dtype = torch.int64,layout: torch.layout = torch.strided,device: Optional[torch.device] = None,requires_grad: bool = False
) → Tensor
  • low:随机整数的下界(包含),默认为 0。
  • high:随机整数的上界(不包含),必须指定。
  • size:输出张量的形状,例如 (batch_size,)(2, 3)(B, C, H, W)
  • dtype:输出张量的数据类型,默认是 torch.int64(LongTensor)。
  • device:生成张量所在设备,如 'cpu' 或者 'cuda'

示例:生成一个二维随机整型张量

import torch# 在 [0, 10) 范围内,生成 2×3 的随机整数张量
x = torch.randint(0, 10, (2, 3))
print(x)
# 可能输出:
# tensor([[2, 7, 1],
#         [5, 0, 9]])
print(x.dtype)   # torch.int64 (默认 LongTensor)

二、为什么需要调用 .long()

虽然 torch.randint 默认即可生成 torch.int64 的张量,但在以下场景中,我们仍常见到 .long() 的调用:

  1. 确保索引类型
    PyTorch 中,张量索引用的必须是 LongTensor(torch.int64)。如果手动指定了其它整型(如 torch.int32torch.uint8),则需要 .long() 转换:

    idx32 = torch.randint(0, 100, (16,), dtype=torch.int32)
    print(idx32.dtype)  # torch.int32idx64 = idx32.long()
    print(idx64.dtype)  # torch.int64
    # 这样才能用 idx64 在其它张量上进行索引
    
  2. 满足损失函数要求
    例如 torch.nn.CrossEntropyLoss 要求标签(targets)是 LongTensor:

    num_classes = 10
    batch_size = 32labels = torch.randint(0, num_classes, (batch_size,))  # 默认就是 int64
    # labels = labels.long()  # 如果你不确定 dtype,可以显式调用logits = torch.randn(batch_size, num_classes)
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits, labels)
    
  3. 统一数据类型
    在复杂模型或数据管道中,手动控制 dtype 能避免莫名的类型不一致错误。显式地在生成后调用 .long(),可以给下游代码带来更好的可读性和健壮性。


三、典型场景示例

1. 随机索引采样

在自定义采样、数据重排或分批时,需要一组随机索引:

import torchnum_samples = 1000
batch_size = 64# 生成 [0, num_samples) 范围内,大小为 batch_size 的随机索引
indices = torch.randint(0, num_samples, (batch_size,)).long()# 假设 data 是一个形状为 [num_samples, ...] 的张量
data = torch.randn(num_samples, 3, 224, 224)
batch = data[indices]  # 用 long 类型索引

2. 伪标签生成

在无监督或对抗训练中,有时需要生成伪标签(fake labels):

import torch
import torch.nn as nnnum_classes = 5
batch_size = 16# 随机生成伪标签
fake_labels = torch.randint(0, num_classes, (batch_size,)).long()# 用 CrossEntropyLoss 计算损失
logits = torch.randn(batch_size, num_classes, requires_grad=True)
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, fake_labels)
loss.backward()

3. 直接在 GPU 上生成 LongTensor

如果希望生成的随机张量直接存放在 GPU 上,同样可以指定 device,并明确 dtype:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, num_classes = 32, 10# 一步到位生成 GPU 上的 LongTensor
labels = torch.randint(0, num_classes, (batch_size,),device=device, dtype=torch.int64)
print(labels.device, labels.dtype)  # cuda:0 torch.int64

四、.long() 的几种等价写法

  • tensor.long()
  • tensor.to(torch.int64)
  • tensor.type(torch.int64)

它们的效果相同,大家可根据个人或团队习惯任选其一。通常推荐使用 .long(),因为更简洁。


五、小结

  • torch.randint(low, high, size):生成位于 [low, high) 的均匀随机整数张量,默认 dtype 是 torch.int64

  • .long():将任意整型或浮点型张量转换为 torch.int64(LongTensor),常用于索引、标签或保证数据类型一致。

  • 典型用途

    1. 随机采样索引
    2. 生成分类伪标签
    3. 在 GPU 上直接生成 long 型张量
  • 最佳实践:在不确定 dtype 时显式调用 .long(),或通过 dtype=torch.int64device='cuda' 一次性完成生成。

http://www.lryc.cn/news/2379205.html

相关文章:

  • 深入解析Spring Boot与微服务架构:从入门到实践
  • 【交互 / 差分约束】
  • 宝塔面板部署前后端项目SpringBoot+Vue2
  • 现代生活健康养生新视角
  • 鸿蒙Next API17新特性学习之如何使用新增鼠标轴事件
  • 多模态大语言模型arxiv论文略读(八十一)
  • 3.4/Q2,Charls最新文章解读
  • 通过觅思文档项目实现Obsidian文章浏览器在线访问
  • Python列表全面解析:从入门到精通
  • 5月18总结
  • 赋予AI更强的“思考”能力
  • Linux Bash | Capture Output / Recall
  • 2025/5/18
  • 基于Quicker构建从截图到公网图像链接获取的自动化流程
  • LeetCode算 法 实 战 - - - 双 指 针 与 移 除 元 素、快 慢 指 针 与 删 除 有 序 数 组 中 的 重 复 项
  • uniapp自定义日历计划写法(vue2)
  • Java IO框架
  • 数据库2——查询
  • Mamba LLM 架构简介:机器学习的新范式
  • Android 性能优化入门(一)—— 数据结构优化
  • 数据库中的锁机制
  • 【网络入侵检测】基于Suricata源码分析运行模式(Runmode)
  • AI日报 - 2025年05月19日
  • Spring源码主线全链路拆解:从启动到关闭的完整生命周期
  • Linux常用命令(十四)
  • 规则联动引擎GoRules初探
  • 基于OpenCV中的图像拼接方法详解
  • AI大模型学习二十六、使用 Dify + awesome-digital-human-live2d + ollama + ChatTTS打造数字人
  • HTML-3.2 表格的跨行跨列(课表制作实例)
  • Spring Cloud Sentinel 快速入门与生产实践指南