网站跳出率因素,建设单位应该关注的网站,千锋教育地址,网站建设 网页开发一、导包
import torch
import torchvision
import torchvision.transforms as transforms
二、下载数据集
2.1 代码展示
# 定义数据加载进来后的初始化操作#xff1a;
transform transforms.Compose([# 张量转换#xff1a;transforms.ToTensor(),# 归一化操作#x…一、导包
import torch
import torchvision
import torchvision.transforms as transforms
二、下载数据集
2.1 代码展示
# 定义数据加载进来后的初始化操作
transform transforms.Compose([# 张量转换transforms.ToTensor(),# 归一化操作transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])trainset torchvision.datasets.CIFAR10(root./data,trainTrue,downloadTrue,transformtransform)
testset torchvision.datasets.CIFAR10(root./data,trainFalse,downloadTrue,transformtransform)
trainloader torch.utils.data.DataLoader(trainset,batch_size4,shuffleTrue,num_workers0)
testloader torch.utils.data.DataLoader(testset,batch_size4,shuffleFalse,num_workers0) 2.2 数据集介绍与下载方式 1数据集介绍 CIFAR10数据集共60000个样本其中有50000个训练样本和10000个测试样本。每个样本都是一张32*32像素的RGB图像彩色图像每个图像分为3个通道R通道、G通道与B通道。 CIFAR10数据集用来进行监督学习训练每个样本包含一个标签值其中有10类物体标签值按照0~9来区分分别是飞机 airplane 、汽车 automobile 、鸟 bird 、猫 cat 、鹿 deer 、狗 dog 、青蛙 frog 、马 horse 、船 ship 和卡车 truck 。 CIFAR10数据集的内容如下图所示。 官网介绍链接CIFAR-10 and CIFAR-100 datasets (toronto.edu) 2下载方式 ①下载文件下载地址https://pan.baidu.com/s/1Nh28RyfwPNNfe_sS8NBNUA 提取码1h4x ②将下载好的文件重命名为cifar-10-batches-py.tar.gz ③将文件保存至相应地址下即可
2.3 transforms.Compose transforms.Compose相当于将所有需要的操作进行打包 transforms.ToTensor完成张量转换pytorch处理的都是tensor数据需要将读入的图片转换为tensor其中tensor比普通图片的三通道多了一个通道—batch transforms.Normalize归一化操作对这一批次的数据进行归一可以加速网络的收敛、放置梯度消失与梯度爆炸。
2.4 Dataset Dataset是指定义好数据的格式和数据变换的形式完成一些初始化的变化然后送给网络相当于将数据读入进去。 torchvision.datasets.CIFAR10调用数据集第一个参数为数据集加载的地址、第二个参数为是否是训练数据或测试数据训练数据为True测试数据为False、第三个为download-指该数据集是否本地下载最后一个参数为要做哪些变化transform是指数据变换格式torchvision会将我们需要的数据进行格式变换。
2.4 Dataloader Dataloader负责用iterative迭代的方式不断读入批次数据一批次一批次将数据进行打包送入网络进行训练、学习、测试。 torch.utils.data.DataLoader第一个参数为数据第二个参数为batch_size(代表Dataloader一次从这么多数据中拿多少个数据走)第三个参数为是否将数据打乱训练的时候将数据打乱会让数据变得复杂测试的时候不需要变得复杂第四个参数为需要几线程进行读取数据对于windows默认为0就可以
三、定义元组 定义元组进行类别名的中文转换
classes (airplane,automobile,bird,car,deer,dog,frog,horse,ship,truck)
四、定义显示函数与运行数据加载器
4.1 代码展示
import matplotlib.pyplot as plt
import numpy as np # 用这个包中的根据将tensor数据转换成矩阵数据def imshow(img):img img / 2 0.5# tensor数据转换为numpynpimg img.numpy()# 使用transpose进行数据转换-通道转换plt.imshow(np.transpose(npimg,(1,2,0)))plt.show()dataiter iter(trainloader)
images,labels dataiter.next()imshow(torchvision.utils.make_grid(images))print(labels)
print(labels[0],classes[labels[0]])
print( .join(classes[labels[j]] for j in range(4))) 4.2 定义显示函数 tensor[batch,channel,H,W]而正常显示图片的顺序为H、W、channel因此需要定义显示函数通过反归一化才能变成正常的图片去显示。
4.3 定义迭代器 iter(trainloader) 定义迭代器读一次迭代的数据(batch_size4所以迭代器一次会读取四张图片) torchvision.utils.make_grid将多张图片拼接为一张图片。
参考003 第一个分类任务1_哔哩哔哩_bilibili