当前位置: 首页 > news >正文

pytorch collate_fn测试用例

collate_fn 函数用于处理数据加载器(DataLoader)中的一批数据。在PyTorch中使用 DataLoader 时,通过设置collate_fn,我们可以决定如何将多个样本数据整合到一起成为一个 batch。在某些情况下,该函数需要由用户自定义以满足特定需求。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as npclass MyDataset(Dataset):def __init__(self, imgs, labels):self.imgs = imgsself.labels = labelsdef __len__(self):return len(self.imgs)def __getitem__(self, idx):img = self.imgs[idx]out_img = img.astype(np.float32)out_img = out_img.transpose(2, 0, 1) #[3, 300, 150]h,w,c  -->>  c,h,wout_label = self.labels[idx] #[4, 5] or [2, 5]return out_img, out_label#if batchsize=3
#batch is list, [3]
#batch0 tuple2  (np[3, 300, 150], np[4, 5])
#batch1 tuple2  (np[3, 300, 150], np[2, 5])
#batch2 tuple2  (np[3, 300, 150], np[4, 5])
def my_collate_fn(batch):"""Custom collate fn for dealing with batches of images that have a differentnumber of associated object annotations (bounding boxes).Arguments:batch: (tuple) A tuple of tensor images and lists of annotationsReturn:A tuple containing:1) (tensor) batch of images stacked on their 0 dim2) (list of tensors) annotations for a given image are stacked on0 dim"""targets = []imgs = []for sample in batch:imgs.append(torch.FloatTensor(sample[0]))targets.append(torch.FloatTensor(sample[1]))imgs_out = torch.stack(imgs, 0) #[3, 3, 300, 150]return imgs_out, targetsimg_data = []
label_data = []nums = 34
H=300
W=150
for _ in range(nums):random_img = np.random.randint(low=0, high=255, size=(H, W, 3))nums_target = np.random.randint(low=0, high=10)random_xyxy_label = np.random.random((nums_target, 5))img_data.append(random_img)label_data.append(random_xyxy_label)dataset = MyDataset(img_data, label_data)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)for cnt, (img, label) in enumerate(dataloader):print("==>>", cnt, ",  img shape=", img.shape)for i in range(len(label)):print("label shape=", label[i].shape)

打印如下:

==>> 0 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([8, 5])
label shape= torch.Size([2, 5])
label shape= torch.Size([5, 5])
==>> 1 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([3, 5])
label shape= torch.Size([8, 5])
label shape= torch.Size([5, 5])
==>> 2 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([7, 5])
label shape= torch.Size([1, 5])
label shape= torch.Size([8, 5])
http://www.lryc.cn/news/218647.html

相关文章:

  • 【qemu逃逸】HITB2017-babyqemu 2019数字经济-qemu
  • Docker Compose学习笔记
  • 基于树 二叉树的回溯搜索算法(DPLL)
  • 【嵌入式】适用于ESP32/ESP8266远程自动烧录工具
  • 服务器遭受攻击如何处理(记录排查)
  • 分享81个工作总结PPT,总有一款适合您
  • 什么是DITA?从百度的回答说起
  • 线扫相机DALSA软件开发套件有哪些
  • Scala集合操作
  • SQL备忘--特殊状态“未知“以及“空值NULL“的判断
  • 《Pytorch新手入门》第一节-认识Tensor
  • 【JAVA学习笔记】55 - 集合-Map接口、HashMap类、HashTable类、Properties类、TreeMap类(难点)
  • Pytorch图像模型转ONNX后出现色偏问题
  • 插值表达式 {{}}
  • 白雪公主
  • 宏观角度认识递归之合并两个有序链表
  • Leetcode-509 斐波那契数列
  • 解密 docker 容器内 DNS 解析原理
  • 故障诊断模型 | Maltab实现SVM支持向量机的故障诊断
  • 开源的网站数据分析统计平台——Matomo
  • linux入门到地狱
  • 架构”4+1“视图
  • 『精』Vue 组件如何模块化抽离Props
  • JavaScript字符串字面量详细解析与代码实例
  • Android java Handler sendMessage使用Parcelable传递实例化对象,我这里传递Bitmap 图片数据
  • CTF工具PDF隐写神器wbStego4open安装和详细使用方法
  • docker镜像使用
  • 【Git】git的下载安装与使用
  • R语言中的函数27:polynom::polynomial(), deriv(),integral(),solve()多式处理函数
  • 基于STM32CubeMX和keil采用USART/UART实现非中断以及中断方式数据回环测试借助CH340以及XCOM