网站排名优化公司,网站建设价格,wordpress标签关键词,c mvc网站做404文章目录一、Dataset 与 DataLoader 功能介绍抽象类Dataset的作用DataLoader 作用两者关系二、torch.utils.data.DataLoader代码示例常用参数图示num_workers设置多少合适数据加载子进程如何并行的pin_memorysampler两种sampler顺序采样 SequentialSampler随机采样 RandomSampl…
文章目录一、Dataset 与 DataLoader 功能介绍抽象类Dataset的作用DataLoader 作用两者关系二、torch.utils.data.DataLoader代码示例常用参数图示num_workers设置多少合适数据加载子进程如何并行的pin_memorysampler两种sampler顺序采样 SequentialSampler随机采样 RandomSamplersampler 与 shuffle 的互斥RandomSampler 与 shuffleTrue 的区别batch_samplerBatchSampler 与其他参数的互斥使用举例collate_fncollate_fn 函数的作用默认 collate_fn 函数自定义 collate_fn 函数一、Dataset 与 DataLoader 功能介绍
抽象类Dataset的作用
简单来说就是将原始数据可能是图片、文本、音频等各种格式整理成模型可以处理的格式为后续的数据加载和处理做准备。功能是定义数据集的基本属性和数据获取方式。
初始化数据路径在Dataset类的__init__方法中通常会初始化数据存放的路径以及一些数据预处理的操作比如指定图片数据集图片所在文件夹路径文本数据集文本文件路径等 。包含 加载数据/读取数据、预处理数据、图像增强 等一系列操作获取单个样本及其标签通过实现__getitem__方法根据给定的索引dataloader返回的返回相应的数据样本和对应的标签。例如在图片分类任务中给定索引后返回该索引对应的图片数据经过预处理如调整尺寸、归一化等以及图片的类别标签。统计样本数量通过实现__len__方法返回数据集中样本的总数方便在训练和评估过程中知道数据规模 。
DataLoader 作用
DataLoader是在Dataset的基础上提供了一种更加高效、便捷地加载数据的方式它可以将Dataset返回的单个样本按照指定的方式进行打包如组成batch、打乱顺序等操作从而满足模型训练和评估的需求。 创建数据批次指定数据打包输出规则通过batch_size参数将Dataset中的单个样本打包成一个个批次batch的数据。 collate_fn指定如何从NNN张训练集选出一个batch的Nbatch_size\frac{N}{batch\_size}batch_sizeN张图片。例如batch_size32那么DataLoader每次会从Dataset中取出32个样本组成一个batch。每次迭代返回的是 一个batch 的数据 自定义数据采样指定数据迭代读取规则 一般使用自定义的采样器Sampler实现对数据的特殊采样方式比如分层采样在类别不均衡的数据集中保证每个batch中各类别的样本比例与原始数据集相似等。dataset对象是dataloader的一个参数通过dataset让dataloader知道训练集一共多少图片从而知道共跌代多少次。 数据打乱通过shuffle参数设置是否在每个epoch开始时打乱数据顺序这样可以避免模型在训练时对数据产生特定的依赖有助于模型学习到更通用的特征提高模型的泛化能力 。 多进程加载通过num_workers参数设置多进程加载数据从而加快数据加载速度尤其是在数据量较大、数据预处理较为复杂的情况下多进程可以充分利用CPU资源减少数据加载时间避免数据加载成为训练过程中的瓶颈 。
两者关系 Dataset是数据的基础容器定义了如何获取数据集中的单个样本 而DataLoader则是Dataset的上层应用负责按照特定规则如批量处理、打乱顺序等从Dataset中高效地加载数据供模型进行训练、验证和测试等操作。 可以说Dataset是数据的来源和基本操作接口DataLoader则是为了更好地适配模型训练需求对Dataset的数据进行进一步处理和组织的工具。
二、torch.utils.data.DataLoader
torch.utils.data.DataLoader类有很多参数可查看Pytorch官方文档torch.utils.data.DataLoader
代码示例
from torch.utils.data import DataLoaderdata_loader DataLoader(dataset, batch_size1, shuffleNone, samplerNone, batch_samplerNone, num_workers0, collate_fnNone, pin_memoryFalse, drop_lastFalse,timeout0)dataset加载数据的数据集batch_size每批返回的数据量默认值是 1shuffle是否在每个 epoch 内将数据打乱顺序。默认值为Falsesampler从数据集中提取的样本序列。可以用来自定义样本的采样策略。默认值为Nonebatch_sampler与sampler类似但是一次返回一个 batch 的索引用于自定义 batch。它与batch_size、shuffle、sampler和drop_last互斥num_workers用于数据加载的子进程数。0 表示主进程加载。默认值为0collate_fn用于指定如何组合样本数据。如果为None那么将默认使用默认的组合方法drop_last如果数据集的大小不能被batch_size整除那么是否丢弃最后一个数据批次。默认值为Falsepin_memory将数据固定在内存的锁页内存中加速数据读取的速度。默认值为Falsetimeout等待 collect 一个 batch 的数据的超时时间。默认为 0表示一直等待
常用参数图示
对于常用的参数见这个数据流向的流程图 dataset是Dataset类的对象在Dataloader中有 2个作用
通过 dataset 的 __len__ 方法dataloader 可以知道数据量从而根据数据量生成相应的索引列表dataloader 会将索引传给 dataset 的 __getitem__ 方法 __getitem__ 方法会对数据进行处理并返回处理好的数据
Dataset 与 Dataloader 的内部交互细节 举例
num_workers
设置多少合适
参数 num_workers 参数用于指定加载数据的子进程的数量这些子进程可以并行地加载数据。
num_workers0(默认值) 表示只有主进程去加载 batch 数据这个可能会是一个瓶颈处理比较慢。num_workers1表示只有一个子进程加载数据主进程不参与这仍可能导致速度慢。num_workers0表示指定数量的子进程并行加载数据且主进程不参与。
增加 num_workers 可以提高加载速度但也会增加 CPU 和 内存的使用。 通常建议将 num_workers 参数设置为等于或小于 CPU 核心数以有效平衡数据加载效率和系统资源占用率。
nw min([os.cpu_count(), batch_size if batch_size 1 else 0, 8])
batch_size 16
nw min([os.cpu_count(), batch_size if batch_size 1 else 0, 8]) # number of workers
train_dataloader torch.utils.data.DataLoader(train_dataset,batch_sizebatch_size,num_workersnw,shuffleTrue,pin_memoryTrue,collate_fncollate_fn)数据加载子进程如何并行的
一个进程仅处理一个 batch 的数据假设设置 num_workers2 则 进程1 处理一个 batch 的数据进程2 处理另一个 batch 的数据。 并行工作流程
初始化创建 DataLoader 实例时通过参数 num_workers 指定并行加载的子进程数量子进程加载数据子进程独立于主进程运行每个子进程的拿着一个 batch 的索引列表并行地到 dataset 的 getitem 中预处理数据数据准备处理好的数据放入缓冲区以备主进程请求数据请求主进程在 for 循环中请求下一个 batch数据传输主进程请求数据时从缓冲区获取已经准备好的 batch循环迭代主进程不断请求数据子进程并行的处理后续的 batch 数据
pin_memory 若设置 pin_memoryTrue 数据会被加载到CPU的内存Pinned Memory中从而提高数据从 CPU 到 GPU 的传输效率。这是因为**锁定的内存pinned memory**可以更快地被复制到GPU因为它是连续的并且已经准备好被传输。 若设置 pin_memoryFalse 则数据是被存放在**可分页内存pageable memory**中当我们想要把数据从 cpu 移动到 gpu 上执行 .to(cuda) 的时候 需要先将数据从分页内存中移动到锁页内存中然后再传输到 GPU 上 参数设置建议
设置 pin_memoryTrue 节省的是 将数据从 分页内存移动到锁页内存中 的这段时间。如果你的训练完全在CPU上进行不涉及GPU那就没有必要设置 pin_memoryTrue 。因为在这种情况下数据不需要被传输到GPU因此不需要使用锁定内存来加速这一过程。可以将 pin_memory 设置为 False 以简化内存管理。
sampler
采样器sampler控制数据集索引顺序。 torch.utils.data.DataLoader 的参数 sampler 参数接收的通常是一个实现了 Sampler 接口的对象比如
sampler SequentialSampler(dataset) # 使用 SequentialSampler
dataloader DataLoader(dataset, batch_size8, samplersampler)通过 sampler 对象来控制数据集的索引顺序从而影响数据从数据集中的抽取方式。
两种sampler
第一种为pytorch 提供的可以直接使用的几种 sampler 顺序和随机比较常用。
# 顺序抽样按照数据集的顺序逐个抽取样本
torch.utils.data.sampler.SequentialSampler()# 随机抽样数据集中的样本以随机顺序被抽取
torch.utils.data.sampler.RandomSampler()# 从指定的样本索引子集内进行随机抽样
torch.utils.data.sampler.SubsetRandomSampler()# 根据样本的权重随机抽样不同样本有不同的抽样概率
torch.utils.data.sampler.WeightedRandomSampler()可以自定义 sampler比如以下是 yolov5 中自定义的 SmartDistributedSampler的sampler类
参数 sampler 有一部分功能是和参数 shuffle 是重叠的这时用shuffle简单
顺序采样 SequentialSampler 效果等价于 shuffleFalse不打乱顺序。随机采样RandomSampler 效果等价于 shuffleTrue
Pytorch 提供 sampler 参数主要是为提升灵活性支持用户更灵活地自定义设计数据加载的方式
下面我们主要介绍 SequentialSampler 和 RandomSampler 只要大家通过 SequentialSampler 、 RandomSampler 掌握了 sampler 的工作原理便可以愉快的自定义的去设计 sampler 了。
顺序采样 SequentialSampler
作用 接收一个 Dataset 对象输出数据包中样本量的顺序索引代码小测试
import torch.utils.data.sampler as sampler# 模拟真实数据
data list([17, 22, 3, 41, 8])# 实例化sampler对象
seq_sampler sampler.SequentialSampler(data_sourcedata)for index in seq_sampler:print(index: {}.format(index))seq_sampler为一个索引列表每一次迭代都返回一个索引值。
Pytorch内部源码实现
class SequentialSampler(Sampler):data_source: Sizeddef __init__(self, data_source: Sized) - None:self.data_source data_sourcedef __iter__(self) - Iterator[int]:return iter(range(len(self.data_source)))def __len__(self) - int:return len(self.data_source)__init__ 接收参数Dataset 对象__iter__ 调用len方法获取数据集大小再用range方法生成索引列表返回一个可迭代对象返回的是索引值因为 SequentialSampler 是顺序采样所以返回的索引是顺序数值序列。__len__ 返回 dataset 中数据个数
这里再给一个Sampler和DatasetDataLoader结合使用的例子
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSamplerclass myDataset(Dataset):def __init__(self, data):self.data datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 示例数据 0 到 19 的整数所以数据值和索引值一样。
data [i for i in range(20)]
dataset myDataset(data)# 使用 SequentialSampler 实例化对象
sampler SequentialSampler(dataset)# 创建 DataLoader
dataloader DataLoader(dataset, batch_size8, samplersampler)# 使用 DataLoader 迭代数据
for data in dataloader:print(data)随机采样 RandomSampler
作用 接收一个 Dataset 对象输出数据包中样本量的随机索引 可指定是否可重复
import torch.utils.data.sampler as samplerdata list([17, 22, 3, 41, 8])
seq_sampler sampler.RandomSampler(data_sourcedata)for index in seq_sampler:print(index: {}.format(index))Pytorch源码删减版本
class RandomSampler(Sampler):def __init__(self, data_source, replacementFalse, num_samplesNone):self.data_source data_sourceself.replacement replacementself._num_samples num_samplesdef num_samples(self):if self._num_samples is None:return len(self.data_source)return self._num_samplesdef __len__(self):return self.num_samplesdef __iter__(self):n len(self.data_source)if self.replacement:# 生成的随机数是可能重复的return iter(torch.randint(highn, size(self.num_samples,), dtypetorch.int64).tolist())# 生成的随机数是不重复的return iter(torch.randperm(n).tolist())__init__ 参数 data_source (Dataset)采样的 Dataset 对象replacement (bool)如果为 True则抽取的样本是有放回的。默认为 Falsenum_samples (int)抽取样本的数量默认是 len(dataset)。当 replacement 是 True 时应被实例化 __iter__ 返回一个可迭代对象返回的是索引因为 RandomSampler 是随机采样所以返回的索引是随机的数值序列当 replacementFalse 时生成的排列是无重复的__len__ 返回 dataset 中样本量
从源码中可以看到随机采样和顺序采样的区别在于生成索引时用了torch.randperm(n)方法。
举例
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSamplerclass myDataset(Dataset):def __init__(self, data):self.data datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 示例数据 0 到 19 的整数
data [i for i in range(20)]
dataset myDataset(data)# 使用 SequentialSampler
sampler RandomSampler(dataset)# 创建 DataLoader
dataloader DataLoader(dataset, batch_size8, samplersampler)# 使用 DataLoader 迭代数据
for data in dataloader:print(data)sampler 与 shuffle 的互斥
参数 sampler 与参数 shuffle 是互斥的不要同时使用 sampler 和 shuffle
当同时设置了 shuffle 与 sampler且 shuffleTrue会报错当同时设置了 shuffle 与 sampler且 shuffleFalse(就是默认值)具体逻辑按照 sampler
因为 shuffle 的默认值为 False所以代码会兼容 shuffle 等于默认值 False 的情况
RandomSampler 与 shuffleTrue 的区别
效果完全没有区别只是实现方式不一样。
shuffleTrue 的实现方式在每个 epoch 开始时将整个数据集打乱然后按照打乱后的顺序划分 batch再按照 batch_size 个数依次提取数据sampler.BatchSampler(random_sampler) 的实现方式数据不会打乱 step 1、RandomSampler 会生成随机的索引。step 2、BatchSampler 根据上面随机出来的索引生成 batch 组。step 3、拿着每个 batch 组的索引去取数据
相同点
每个 epoch 都会重新打乱都不会重复采样除非你通过参数指定了可以重复采样
其他说明 3. shuffleTrue 的性能更高一些而 BatchSampler 灵活性更高因为你可以通过 BatchSampler 设计更复杂的采样方式 4. 在 Dataloader 中使用 batch_sampler 的常见目的之一是为了兼容 DistributedSampler比如
if args.distributed:sampler_train DistributedSampler(dataset_train)sampler_val DistributedSampler(dataset_val, shuffleFalse)
else:sampler_train torch.utils.data.RandomSampler(dataset_train)sampler_val torch.utils.data.SequentialSampler(dataset_val)batch_sampler_train torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_lastTrue)data_loader_train DataLoader(dataset_train,batch_samplerbatch_sampler_train,collate_fnutils.collate_fn,)
data_loader_val DataLoader(dataset_val,args.batch_size,samplersampler_val,drop_lastFalse,collate_fnutils.collate_fn,)跑个小例子看一下两者都是随机的效果
import torch
import torch.utils.data.sampler as sampler
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self):self.data [1, 2, 3, 4, 5]def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]dataset MyDataset()#
random_sampler sampler.RandomSampler(data_sourcedataset)
batch_sampler sampler.BatchSampler(random_sampler, batch_size2, drop_lastFalse)
dataloader1 DataLoader(dataset, batch_samplerbatch_sampler)for epoch in range(3):for index, data in enumerate(dataloader1):print(index, data)
print(**30)#
dataloader2 DataLoader(dataset, batch_size2, shuffleTrue)for epoch in range(3):for index, data in enumerate(dataloader2):print(index, data)batch_sampler
torch.utils.data.DataLoaderde 的参数 batch_sample 接收的一般是 torch.utils.data.BatchSampler 对象 torch.utils.data.BatchSampler 的作用 包装另一个采样器生成一个小批量索引采样器
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)接收三个参数和DataLoader参数重叠了所以在实例化BatchSampler时指定了 batch_size和 drop_last就不需要再在DataLoader中指定如果重复指定会报错。
sampler : 其他采样器实例batch_size 批量大小drop_last为 True时如果最后一个batch 采样得到的数据个数小于batch_size则抛弃最后一个batch的数据
BatchSampler 与其他参数的互斥
如果你在 DataLoader(dataset, batch_samplerbatch_sampler) 中指定了参数 batch_sampler那么就不能再指定参数 batch_size、shuffle、sampler、和 drop_last 了他们互斥。
因为:
你在生成 torch.utils.data.sampler.BatchSampler() 的时候就已经制定过 batch_size、sampler、和 drop_last 这些参数了batch_sampler 与 shuffle 作用一致所以也互斥
比如如下代码就会报错因为在 DataLoader 中重复指定了 batch_size random_sampler sampler.RandomSampler(data_sourcedataset)batch_sampler sampler.BatchSampler(random_sampler, batch_size2, drop_lastFalse)dataloader DataLoader(dataset, batch_size2, batch_samplerbatch_sampler)使用举例
import torch.utils.data.sampler as sampler
# 用list模拟数据
data list([17, 22, 3, 41, 8])seq_sampler sampler.SequentialSampler(data_sourcedata)
batch_sampler sampler.BatchSampler(seq_sampler, 2, False )for index in batch_sampler:print(index)每次迭代获得的是一个batch的索引列表。
Pytorch源码删减版
class BatchSampler(Sampler):def __init__(self, sampler, batch_size, drop_last):、self.sampler samplerself.batch_size batch_sizeself.drop_last drop_lastdef __iter__(self):batch []for idx in self.sampler:batch.append(idx)# 如果采样个数和batch_size相等则本次采样完成if len(batch) self.batch_size:yield batchbatch []# for 结束后在不需要剔除不足batch_size的采样个数时返回当前batch if len(batch) 0 and not self.drop_last:yield batchdef __len__(self):# 在不进行剔除时数据的长度就是采样器索引的长度if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) self.batch_size - 1) // self.batch_size例子
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler, BatchSamplerclass myDataset(Dataset):def __init__(self, data):self.data datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 示例数据 # 生成 0 到 19 的整数
data [i for i in range(20)]
dataset myDataset(data)# 使用 SequentialSampler 顺序采样
sequential_sampler SequentialSampler(dataset)# 使用 BatchSampler 将 SequentialSampler 和 batch_size 结合
batch_sampler BatchSampler(sequential_sampler, batch_size8, drop_lastFalse)# 创建 DataLoader使用 BatchSampler
dataloader DataLoader(dataset, batch_samplerbatch_sampler)# 使用 DataLoader 迭代数据
for data in dataloader:print(data)collate_fn
在使用 torch.utils.data.dataset 时参数 collate_fn 接受一个函数该函数的函数名通常就为collate_fn
collate_fn 函数的作用
将多个 经过 dataset.getitem() 处理好的 样本数据组合成一个 batch 的数据。 注 更换 cifar-100 在你本地的路径
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import ostorch.manual_seed(121)
torch.cuda.manual_seed(121)label_dict {apple: 0,aquarium_fish: 1,baby: 2,bear: 3,beaver: 4,bed: 5,bee: 6,beetle: 7,bicycle: 8,bottle: 9,bowl: 10,boy: 11,bridge: 12,bus: 13,butterfly: 14,camel: 15,can: 16,castle: 17,caterpillar: 18,cattle: 19,chair: 20,chimpanzee: 21,clock: 22,cloud: 23,cockroach: 24,couch: 25,crab: 26,crocodile: 27,cup: 28,dinosaur: 29,dolphin: 30,elephant: 31,flatfish: 32,forest: 33,fox: 34,girl: 35,hamster: 36,house: 37,kangaroo: 38,keyboard: 39,lamp: 40,lawn_mower: 41,leopard: 42,lion: 43,lizard: 44,lobster: 45,man: 46,maple_tree: 47,motorcycle: 48,mountain: 49,mouse: 50,mushroom: 51,oak_tree: 52,orange: 53,orchid: 54,otter: 55,palm_tree: 56,pear: 57,pickup_truck: 58,pine_tree: 59,plain: 60,plate: 61,poppy: 62,porcupine: 63,possum: 64,rabbit: 65,raccoon: 66,ray: 67,road: 68,rocket: 69,rose: 70,sea: 71,seal: 72,shark: 73,shrew: 74,skunk: 75,skyscraper: 76,snail: 77,snake: 78,spider: 79,squirrel: 80,streetcar: 81,sunflower: 82,sweet_pepper: 83,table: 84,tank: 85,telephone: 86,television: 87,tiger: 88,tractor: 89,train: 90,trout: 91,tulip: 92,turtle: 93,wardrobe: 94,whale: 95,willow_tree: 96,wolf: 97,woman: 98,worm: 99
}def default_collate(batch):# 检查样本类型并处理if isinstance(batch[0], torch.Tensor):return torch.stack(batch, dim0)elif isinstance(batch[0], (list, tuple)):return [default_collate(samples) for samples in zip(*batch)]elif isinstance(batch[0], dict):return {key: default_collate([d[key] for d in batch]) for key in batch[0]}elif isinstance(batch[0], int):return torch.tensor(batch) # 将 int 转换为 Tensorraise TypeError(fUnsupported type: {type(batch[0])})class CustomDataset(Dataset):def __init__(self, data_folder, train, transformNone):self.data_folder data_folderself.transform transformself.file_list os.listdir(data_folder)self.train traindef __getitem__(self, idx):img_name os.path.join(self.data_folder, self.file_list[idx])original_image Image.open(img_name)label_name img_name.split(_, 1)[-1].split(.)[0]label_idx label_dict[label_name]if self.train:image self.transform(original_image)else:image self.transform(original_image)return image, label_idxdef __len__(self):return len(self.file_list)images_dir /Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train
dataset CustomDataset(images_dir, trainTrue, transformtransforms.ToTensor())data_loader DataLoader(dataset,batch_size2,shuffleTrue,collate_fndefault_collate)data_loader iter(data_loader)
image, label next(data_loader)
print(image.shape)
print(label)默认 collate_fn 函数
简易实现版本实际更复杂
def default_collate(batch):# 检查样本类型并处理# 判断batch第0个元素数据类型根据不同类型分别返回不同的打包结果。if isinstance(batch[0], torch.Tensor):return torch.stack(batch, dim0)elif isinstance(batch[0], (list, tuple)):return [default_collate(samples) for samples in zip(*batch)]elif isinstance(batch[0], dict):return {key: default_collate([d[key] for d in batch]) for key in batch[0]}elif isinstance(batch[0], int):return torch.tensor(batch) # 将 int 转换为 Tensorraise TypeError(fUnsupported type: {type(batch[0])})default_collate 函数通过递归处理不同类型的样本张量、列表、元组、字典、整数等将零散的单个样本组合成统一的批量数据格式确保批量数据能被模型正确接收和处理同时处理不同类型的数据结构。。 处理张量Tensor类型 if isinstance(batch[0], torch.Tensor):return torch.stack(batch, dim0)如果样本是 torch.Tensor如图像的像素数据则使用 torch.stack 沿着第 0 维度拼接形成一个包含批量数据的新张量。 例如32 个形状为 (3, 224, 224) 的图像张量会被拼接成 (32, 3, 224, 224) 的批量张量。 处理列表/元组list/tuple类型 elif isinstance(batch[0], (list, tuple)):return [default_collate(samples) for samples in zip(*batch)]如果样本是列表或元组如包含多个输入特征的情况则通过 zip(*batch) 按位置拆分批量数据再递归调用 default_collate 处理每个位置的子数据。 例如每个样本是 (图像张量, 标签) 的元组批量处理后会得到 (批量图像张量, 批量标签) 的元组。 处理字典dict类型 elif isinstance(batch[0], dict):return {key: default_collate([d[key] for d in batch]) for key in batch[0]}如果样本是字典如包含 {image: 图像张量, label: 标签} 的结构则按字典的键key分组对每个键对应的所有样本值递归调用 default_collate最终返回一个包含批量数据的新字典。 处理整数int类型 elif isinstance(batch[0], int):return torch.tensor(batch) # 将 int 转换为 Tensor如果样本是整数如分类任务的标签则将整个批量的整数转换为 torch.Tensor方便后续计算。 不支持的类型 raise TypeError(fUnsupported type: {type(batch[0])})若遇到上述类型之外的数据会抛出类型错误提示不支持该类型。
自定义 collate_fn 函数
常用需要自定义的场景一个 batch 中的 多张图片经过 dataset.getitem() 方法得到的图像输出尺寸不一样。可能因为 图像增强 使用 的 transforms 设计的 最后一步处理方式是范围内的随机裁剪
又因为网络要求输入数据的尺寸形式为 (batch_size channel highwidth) 为了将多张图像数据打包成一个batch 的数据形式
对比一个batch中所有图片的宽高找到最长的值。根据最大的作为标准给图像加上padding保证所有图像尺寸一致。处理得出masks数据记录每一个图片有效像素和padding像素的位置进而组成 batch 的数据形式进行返回。 Deformable-DETR/main.py有这个场景的代码实现
data_loader_train DataLoader(dataset_train, batch_samplerbatch_sampler_train,collate_fnutils.collate_fn, num_workersargs.num_workers,pin_memoryTrue)data_loader_val DataLoader(dataset_val, args.batch_size, samplersampler_val,drop_lastFalse, collate_fnutils.collate_fn, num_workersargs.num_workers,pin_memoryTrue)Deformable-DETR/util/misc.py
def collate_fn(batch):batch list(zip(*batch))batch[0] nested_tensor_from_tensor_list(batch[0])return tuple(batch)def _max_by_axis(the_list):# type: (List[List[int]]) - List[int]maxes the_list[0]for sublist in the_list[1:]:for index, item in enumerate(sublist):maxes[index] max(maxes[index], item)return maxesdef nested_tensor_from_tensor_list(tensor_list: List[Tensor]):# TODO make this more generalif tensor_list[0].ndim 3:# TODO make it support different-sized imagesmax_size _max_by_axis([list(img.shape) for img in tensor_list])# min_size tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))batch_shape [len(tensor_list)] max_sizeb, c, h, w batch_shapedtype tensor_list[0].dtypedevice tensor_list[0].devicetensor torch.zeros(batch_shape, dtypedtype, devicedevice)mask torch.ones((b, h, w), dtypetorch.bool, devicedevice)for img, pad_img, m in zip(tensor_list, tensor, mask):pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)m[: img.shape[1], :img.shape[2]] Falseelse:raise ValueError(not supported)return NestedTensor(tensor, mask)