asp网站建设中期报告,网吧手机网站模版,网站管理工作是具体应该怎么做,中国工程建设信息网官网查询文章目录致谢2 数据集的加载2.1 框架数据集的加载2.2 自定义数据集2.3 准备数据以进行数据加载器训练致谢 Pytorch自带数据集介绍_godblesstao的博客-CSDN博客_pytorch自带数据集 2 数据集的加载
与sklearn中的datasets自带数据集类似#xff0c;pytorch框架也为我们提供了数…
文章目录致谢2 数据集的加载2.1 框架数据集的加载2.2 自定义数据集2.3 准备数据以进行数据加载器训练致谢 Pytorch自带数据集介绍_godblesstao的博客-CSDN博客_pytorch自带数据集 2 数据集的加载
与sklearn中的datasets自带数据集类似pytorch框架也为我们提供了数据集以便一系列的模型测试。其数据集作为一个类继承自父类torch.utils.data.Dataset。
2.1 框架数据集的加载
让我们看看torch为我们提供了什么数据集。数据集种类如下所示 手写字符识别EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot 实物分类Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet 人脸识别CelebA 场景分类LSUN、Places365 用于object detectionSVHN、VOCDetection、COCODetection 用于semantic/instance segmentation 语义分割Cityscapes、VOCSegmentation 语义边界SBD 用于image captioningFlickr、COCOCaption 用于video classificationHMDB51、Kinetics 用于3D reconstructionPhotoTour 用于shadow detectorsSBU
以FashionMNIST数据集为例我们看一下如何加载数据集。 torch.datasets.FashionMNIST(root “data”,train True,download True,transform ToTensor()) root是存储训练/测试数据的路径train指定训练或测试数据集当布尔值为True则为训练集当布尔值为False则为测试集downloadTrue从互联网下载数据如果无法在本地获得transform指定特征转换方式target_transform指定标签转换方式 import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensordef load_data():加载数据集# 1 训练数据集的加载train_data datasets.FashionMNIST(rootdata,trainTrue,downloadTrue,transformToTensor())# 2 测试数据集的加载test_data datasets.FashionMNIST(rootdata,trainFalse,downloadTrue,transformToTensor())return train_data, test_datatrain_data, test_data load_data()
print(train_data)数据集加载完实际上是以类的形式存在的其不同于sklearn中返回的Bunch。
如果我们想要看看数据集中有啥要怎么做呢首先这个数据集是图像分类数据集说明里面含有的都是图像为此我们可以使用subplots存放这些图片。对于这些数据集我们可以像列表一样手动索引。如train_data[index]。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as pltdef load_data():加载数据集# 1 训练数据集的加载train_data datasets.FashionMNIST(rootdata,trainTrue,downloadTrue,transformToTensor())# 2 测试数据集的加载test_data datasets.FashionMNIST(rootdata,trainFalse,downloadTrue,transformToTensor())return train_data, test_datadef show_data(train_data):数据集可视化label_map {0: T_Shirt,1: Trouser,2: Pullover,3: Dress,4: Coat,5: Sandal,6: Shirt,7: Sneaker,8: Bag,9: Ankle Boot,}figure plt.figure(figsize(8, 8))cols, rows 3, 3# 从训练集中随机抽出九张图九个样本for i in range(1, cols * rows 1):# 设置索引索引取值为0到训练集的长度sample_idx torch.randint(len(train_data), size(1,)).item()# 取出对应样本的图片和标签img, label train_data[sample_idx]# 依次画于事先指定的九宫格图上figure.add_subplot(rows, cols, i)# 设置对应图片的标题plt.title(label_map[label])# 关掉坐标轴plt.axis(off)# 展示图片plt.imshow(img.squeeze(), cmapgray)# 释放画布plt.show()train_data, test_data load_data()
show_data(train_data)
out 上面用到了一个API即torch.randint() torch.randint(low0 high size generatorNone outNone dtypeNone layouttorch.strided deviceNone requires_gradFalse → Tensor 用于取随机整数返回值为张量lowint类型表明要从分布中提取的最低整数highint类型表明要从分布中提取的最高整数1size元组类型表明输出张量的形状dtype返回值张量的数据类型device返回张量所需的设备requires_grad布尔类型表明是否要对返回的张量自动求导。 如 torch.randint(3, 5, (3,))
tensor([4, 3, 4])意味生成一个一维的3元素向量其中向量中的元素取值从3-5取。 2.2 自定义数据集
如果你不想使用框架自带的数据集那么你可以自己定义一个数据集类。自定义Dataset类必须实现三个函数__ init __ 、 __ len __ 、__ getitem __。其中图像部分存储于一个文件夹中标签单独存储在CSV文件中。
在接下来的代码中让我们看看如何创建一个自定义数据集。
import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transformNone, target_transformNone):self.img_labels pd.read_csv(annotations_file)self.img_dir img_dirself.transform transformself.target_transform target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image read_image(img_path)label self.img_labels.iloc[idx, 1]if self.transform:image self.transform(image)if self.target_transform:label self.target_transform(label)return image, label对于__ init __ 函数来说包含加载图像、注释文件和两个转换的目录在这里我们不做过多讲解后面会详细介绍。
def __init__(self, annotations_file, img_dir, transformNone, target_transformNone):self.img_labels pd.read_csv(annotations_file)self.img_dir img_dirself.transform transformself.target_transform target_transform对于__ len __ 函数其功能是返回数据集中的样本数。
def __len__(self):return len(self.img_labels)对于 __ getitem __其功能是给定索引便能返回对应样本。
def __getitem__(self, idx):img_path os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image read_image(img_path)label self.img_labels.iloc[idx, 1]if self.transform:image self.transform(image)if self.target_transform:label self.target_transform(label)return image, label在自定义这一部分不用过多的去了解用着用着就会了就算不会代码也是通用需要用的时候看一下复制一下别搞得自己这么焦虑。
2.3 准备数据以进行数据加载器训练
在pytorch中数据加载的核心实际上是torch.utils.data.DataLoader类它支持对torch数据集的python可迭代换而言之DataLoader相当于你拿一个水盆而dataset相当于泉水。DataLoader可以对小批量数据集进行处理处理内容包括
地图样式和可迭代样式的数据集自定义数据集加载顺序多进程加载数据自动内存固定
其中地图样式数据集是指自定义数据集而可迭代样式数据集指的是自带数据集。其他详情对于初学者来说很不友好这里不做过多解释你可以理解为这就是个科普知识。
我们来看一下这个API吧。 torch.utils.data.DataLoader(数据集 batch_size1 shuffleFalse) 用于加载样本并且进行批处理数据集要加载的数据集batch_size整数类型表明每批要加载的样本数默认为1shuffle布尔类型表明是否要洗牌 我们利用上面的API来加载我们上面的Fashion_MNIST吧。
def load_batch_data():数据集批处理加载器train_dataloader DataLoader(train_data, batch_size64, shuffleTrue)test_dataloader DataLoader(test_data, batch_size64, shuffleTrue)return train_dataloader, test_dataloader既然已经将样本导入加载器那么我们如何从加载器中读取数据呢我们可以根据需要循环访问数据集。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoaderdef load_data():加载数据集# 1 训练数据集的加载train_data datasets.FashionMNIST(rootdata,trainTrue,downloadTrue,transformToTensor())# 2 测试数据集的加载test_data datasets.FashionMNIST(rootdata,trainFalse,downloadTrue,transformToTensor())return train_data, test_datadef show_data(train_data):数据集可视化label_map {0: T_Shirt,1: Trouser,2: Pullover,3: Dress,4: Coat,5: Sandal,6: Shirt,7: Sneaker,8: Bag,9: Ankle Boot,}figure plt.figure(figsize(8, 8))cols, rows 3, 3# 从训练集中随机抽出九张图九个样本for i in range(1, cols * rows 1):# 设置索引索引取值为0到训练集的长度sample_idx torch.randint(len(train_data), size(1,)).item()# 取出对应样本的图片和标签img, label train_data[sample_idx]# 依次画于事先指定的九宫格图上figure.add_subplot(rows, cols, i)# 设置对应图片的标题plt.title(label_map[label])# 关掉坐标轴plt.axis(off)# 展示图片plt.imshow(img.squeeze(), cmapgray)# 释放画布plt.show()def load_batch_data():数据集批处理加载器train_dataloader DataLoader(train_data, batch_size64, shuffleTrue)test_dataloader DataLoader(test_data, batch_size64, shuffleTrue)return train_dataloader, test_dataloaderdef show_batch_data():循环访问数据加载器train_dataloader, test_dataloader load_batch_data()train_feature, train_labels next(iter(train_dataloader))print(f特征大小{train_feature.size()})print(f标签大小{train_labels.size()})img train_feature[0].squeeze()label train_labels[0]plt.imshow(img, cmapgray)plt.show()print(flabel:{label})train_data, test_data load_data()
# show_data(train_data)
show_batch_data()