自动跳转手机网站代码,让别人做网站要注意什么,江都住房和建设局网站,云南哪几个建网站公司还记得这篇文章吗#xff1f;迁移学习|代码实现
在这篇文章中#xff0c;我们知道了在构建模型时#xff0c;可以借助一些非常有名的模型#xff0c;这些模型在ImageNet数据集上早已经得到了检验。
同时torchvision模块也提供了预训练好的模型。我们只需稍作修改#xf…还记得这篇文章吗迁移学习|代码实现
在这篇文章中我们知道了在构建模型时可以借助一些非常有名的模型这些模型在ImageNet数据集上早已经得到了检验。
同时torchvision模块也提供了预训练好的模型。我们只需稍作修改便可运用到自己的实际任务中 我们仍然按照这个步骤开始我们的模型的训练 准备一个可迭代的数据集 定义一个神经网络 将数据集输入到神经网络进行处理 计算损失 通过梯度下降算法更新参数
import torch import torchvisionimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom torchvision import models
数据集准备
cifar10_train torchvision.datasets.CIFAR10( root cifar10/, train True, download True)cifar10_testtorchvision.datasets.CIFAR10( root cifar10/, train False, download True)
transform transforms.Compose([ transforms.ToTensor(), transforms.Resize((224,224)) ])cifar2_train[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader torch.utils.data.DataLoader(cifar2_train, batch_size64,shuffleTrue)test_loader torch.utils.data.DataLoader(cifar2_test, batch_size64,shuffleTrue)数据集使用CIFAR-10数据集中的猫和狗。
CIFAR-10数据集类别
种类 标签 plane 0 car 1 bird 2 cat 3 deer 4 dog 5 frog 6 horse 7 ship 8 truck 9 可以看到其中cat和dog的标签分别为3和5
借助
[3,5].index(label)
我们可以将cat标签变为0dog标签变为1从而回到二分类问题。 举个例子 [3,5].index(3)0 [3,5].index(5)1 定义模型
参考这篇文章迁移学习|代码实现
#网络搭建networkmodels.resnet18(pretrainedTrue)
for param in network.parameters(): param.requires_gradFalse
network.fcnn.Linear(512,2)#损失函数criterionnn.CrossEntropyLoss()#优化器optimizeroptim.SGD(network.fc.parameters(),lr0.01,momentum0.9)
devicetorch.device(cuda if torch.cuda.is_available() else cpu)networknetwork.to(device) 训练模型
for epoch in range(10): total_loss 0 total_correct 0 for batch in train_loader: # Get batch images, labels batch imagesimages.to(device) labelslabels.to(device) optimizer.zero_grad() #告诉优化器把梯度属性中权重的梯度归零否则pytorch会累积梯度 preds network(images) loss criterion(preds, labels) loss.backward() optimizer.step() total_loss loss.item() _,prelabelstorch.max(preds,dim1) total_correct int((prelabelslabels).sum()) accuracy total_correct/len(cifar2_train) print(Epoch:%d , Loss:%f , Accuracy:%f %(epoch,total_loss,accuracy)) Epoch:0 , Loss:78.549439 , Accuracy:0.788900 Epoch:1 , Loss:77.828066 , Accuracy:0.801500 Epoch:2 , Loss:66.151785 , Accuracy:0.828100 Epoch:3 , Loss:76.204446 , Accuracy:0.816800 Epoch:4 , Loss:68.886606 , Accuracy:0.828100 Epoch:5 , Loss:71.129405 , Accuracy:0.821200 Epoch:6 , Loss:66.096364 , Accuracy:0.829900 Epoch:7 , Loss:65.504227 , Accuracy:0.827700 Epoch:8 , Loss:76.303878 , Accuracy:0.817100 Epoch:9 , Loss:70.546953 , Accuracy:0.820700 测试模型
correct0total0network.eval()with torch.no_grad(): for batch in test_loader: imgs,labelsbatch imgsimgs.cuda() labelslabels.cuda() predsnetwork(imgs) _,prelabelstorch.max(preds,dim1) #print(prelabels.size()) totaltotallabels.size(0) correctcorrectint((prelabelslabels).sum()) #print(total) accuracycorrect/total print(Accuracy: ,accuracy)
Accuracy: 0.8025
这里使用的预训练模型是resnet18,我们也可以使用VGG16模型同时记得改变最后一个全连接层的输出参数使得其满足我们自己的任务。 除了预训练模型之外我们还可以对一些超参数进行调整使最后的效果变得更好