当前位置: 首页 > article >正文

PyTorch中 torch.utils.data.DataLoader 的详细解析和读取点云数据示例

一、DataLoader 是什么?

torch.utils.data.DataLoader 是 PyTorch 中用于加载数据的核心接口,它支持:

  • 批量读取(batch)
  • 数据打乱(shuffle)
  • 多线程并行加载(num_workers)
  • 自动将数据打包成 batch
  • 数据预处理和增强(搭配 Dataset 使用)

二、常见参数详解

参数含义
dataset传入的 Dataset 对象(如自定义或 torchvision.datasets
batch_size每个 batch 的样本数量
shuffle是否打乱数据(通常训练集为 True)
num_workers并行加载数据的线程数(越大越快,但依机器决定)
drop_last是否丢弃最后一个不足 batch_size 的 batch
pin_memory若为 True,会将数据复制到 CUDA 的 page-locked 内存中(加速 GPU 训练)
collate_fn自定义打包 batch 的函数(可用于变长序列、图神经网络等)
sampler控制数据采样策略,不能与 shuffle 同时使用
persistent_workers若为 True,worker 在 epoch 间保持运行状态(提高效率,PyTorch 1.7+)

三、基本使用示例

搭配 Dataset 使用

from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self):self.data = [i for i in range(100)]def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]dataset = MyDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)for batch in loader:print(batch)

四、自定义 collate_fn 示例

适用于:变长数据(如文本、点云)或特殊处理需求

from torch.nn.utils.rnn import pad_sequencedef my_collate_fn(batch):# 假设每个样本是 list 或 tensor(变长)batch = [torch.tensor(item) for item in batch]padded = pad_sequence(batch, batch_first=True, padding_value=0)return paddedloader = DataLoader(dataset, batch_size=4, collate_fn=my_collate_fn)

五、使用注意事项

  1. Windows 平台注意:

    • 设置 num_workers > 0 时,必须使用:

      if __name__ == '__main__':DataLoader(...)
      
  2. 过多线程数可能导致瓶颈:

    • 通常 num_workers = cpu_count() // 2 较稳定
  3. GPU 加速:

    • 训练时推荐设置 pin_memory=True 可提高 GPU 训练数据传输效率。
  4. 不要同时设置 shuffle=Truesampler

    • 否则会报错,二者功能冲突。

六、训练中的典型使用方式

for epoch in range(num_epochs):for i, batch in enumerate(train_loader):inputs, labels = batchinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()

七、调试技巧与加速建议

场景建议
数据加载慢增加 num_workers
GPU 等数据设置 pin_memory=True
Dataset 中有耗时操作考虑预处理或使用缓存
debug 模式设置 num_workers=0,禁用多进程

八、与 TensorDataset、ImageFolder 配合

from torchvision.datasets import ImageFolder
from torchvision import transformstransform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])dataset = ImageFolder(root='your/image/folder', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

九、点云数据处理场景应用实例

点云数据处理 场景中,使用 torch.utils.data.DataLoader 时,常遇到如下需求:

  • 每帧点云大小不同(变长 Tensor)
  • 点云数据 + 标签(如语义、实例)
  • 使用 .bin.pcd.npy 等格式加载
  • 数据增强(如旋转、裁剪、噪声)
  • GPU 加速 + 批量训练

1. 点云数据 Dataset 示例(以 .npy 文件为例)

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoaderclass PointCloudDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.files = sorted([f for f in os.listdir(root_dir) if f.endswith('.npy')])self.transform = transformdef __len__(self):return len(self.files)def __getitem__(self, idx):point_cloud = np.load(os.path.join(self.root_dir, self.files[idx]))  # shape: [N, 3] or [N, 6]point_cloud = torch.tensor(point_cloud, dtype=torch.float32)if self.transform:point_cloud = self.transform(point_cloud)return point_cloud

2. 自定义 collate_fn(处理变长点云)

def collate_pointcloud_fn(batch):"""输入: List of [N_i x 3] tensors输出: - 合并后的 [B x N_max x 3] tensor- 每个样本的真实点数 list"""max_points = max(pc.shape[0] for pc in batch)padded = torch.zeros((len(batch), max_points, batch[0].shape[1]))lengths = []for i, pc in enumerate(batch):lengths.append(pc.shape[0])padded[i, :pc.shape[0], :] = pcreturn padded, torch.tensor(lengths)

3. 加载器构建示例

dataset = PointCloudDataset("/path/to/your/pointclouds")loader = DataLoader(dataset,batch_size=8,shuffle=True,num_workers=4,pin_memory=True,collate_fn=collate_pointcloud_fn
)for batch_points, batch_lengths in loader:# batch_points: [B, N_max, 3]# batch_lengths: [B]print(batch_points.shape)

4. 可选扩展功能

功能实现方法
点云旋转/缩放自定义 transform(例如随机旋转矩阵乘点云)
加载 .pcd使用 open3d, pypcd, 或 pclpy
同时加载标签在 Dataset 中返回 (point_cloud, label),修改 collate_fn
voxel downsampling使用 open3d.geometry.VoxelDownSample
GPU 加速point_cloud = point_cloud.cuda(non_blocking=True)

5. 训练循环中使用

for epoch in range(num_epochs):for batch_pc, batch_len in loader:batch_pc = batch_pc.to(device)# 可用 batch_len 做 mask 或 attention maskout = model(batch_pc)...

http://www.lryc.cn/news/2396234.html

相关文章:

  • 直线模组在手术机器人中有哪些技术挑战?
  • RK3568DAYU开发板-平台驱动开发--UART
  • ubuntu 安装 Redis 5.0.8 的完整步骤
  • 制造企业搭建AI智能生产线怎么部署?
  • 深度学习驱动的超高清图修复技术——综述
  • unix/linux source 命令,其内部结构机制
  • 【LLM】FastAPI入门教程
  • 进程同步机制-信号量机制-记录型信号量机制中的的wait和signal操作
  • gitlib 常见命令
  • Azure DevOps 管道部署系列之二IIS
  • Vue.js教学第十七章:Vue 与后端交互(一),Axios 基础
  • 人工智能浪潮下,制造企业如何借力DeepSeek实现数字化转型?
  • NodeJS全栈开发面试题讲解——P2Express / Nest 后端开发
  • 从线性代数到线性回归——机器学习视角
  • 计算机网络相关发展以及常见性能指标
  • 通义灵码:基于MCP的火车票小助手系统全流程设计与技术总结
  • 为什么建立 TCP 连接时,初始序列号不固定?
  • VBA数据库解决方案二十:Select表达式From区域Where条件Order by
  • NX753NX756美光科技闪存NX784NX785
  • 使用 pytesseract 构建一个简单 OCR demo
  • Cesium快速入门到精通系列教程三:添加物体与3D建筑物
  • git 如何解决分支合并冲突(VS code可视化解决+gitLab网页解决)
  • 【CF】Day72——Codeforces Round 890 (Div. 2) CDE1 (二分答案 | 交互 + 分治 | ⭐树上背包)
  • 单片机寄存器的四种主要类型!
  • 智能嗅探AJAX触发:机器学习在动态渲染中的创新应用
  • 【计算机网络】Linux下简单的UDP服务器(超详细)
  • Java并发编程实战 Day 3:volatile关键字与内存可见性
  • 华为OD机试真题——报文回路(2025A卷:100分)Java/python/JavaScript/C/C++/GO最佳实现
  • K8s工作流程与YAML实用指南
  • 功能丰富的PDF处理免费软件推荐