智博常州网站建设,wordpress缩略图只生成full,武安市城乡建设局网站,最全网站源码分享知识点回顾#xff1a; 预训练的概念常见的分类预训练模型图像预训练模型的发展史预训练的策略预训练代码实战#xff1a;resnet18 作业#xff1a; 尝试在cifar10对比如下其他的预训练模型#xff0c;观察差异#xff0c;尽可能和他人选择的不同尝试通过ctrl进入resnet的… 知识点回顾 预训练的概念常见的分类预训练模型图像预训练模型的发展史预训练的策略预训练代码实战resnet18 作业 尝试在cifar10对比如下其他的预训练模型观察差异尽可能和他人选择的不同尝试通过ctrl进入resnet的内部观察残差究竟是什么 一、在 CIFAR10 上对比如下其他的预训练模型
可以选择不同的预训练模型如 VGG16、Inception V3 等对比它们在 CIFAR10 数据集上的训练时间、准确率等指标。以下是使用 VGG16 的示例代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg16# 数据预处理
transform transforms.Compose([transforms.Resize((224, 224)), # Inception 和 VGG 要求输入图像大小为 224x224transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载 CIFAR10 数据集
trainset torchvision.datasets.CIFAR10(root./data, trainTrue,downloadTrue, transformtransform)
trainloader torch.utils.data.DataLoader(trainset, batch_size4,shuffleTrue, num_workers2)testset torchvision.datasets.CIFAR10(root./data, trainFalse,downloadTrue, transformtransform)
testloader torch.utils.data.DataLoader(testset, batch_size4,shuffleFalse, num_workers2)# 加载预训练的 VGG16 模型
model vgg16(pretrainedTrue)
num_ftrs model.classifier[6].in_features
model.classifier[6] nn.Linear(num_ftrs, 10) # 修改最后一层全连接层以适应 CIFAR10 的 10 个类别# 定义损失函数和优化器
criterion nn.CrossEntropyLoss()
optimizer optim.SGD(model.parameters(), lr0.001, momentum0.9)# 训练模型
device torch.device(cuda:0 if torch.cuda.is_available() else cpu)
model.to(device)for epoch in range(2): # 训练 2 个 epochrunning_loss 0.0for i, data in enumerate(trainloader, 0):inputs, labels data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs model(inputs)loss criterion(outputs, labels)loss.backward()optimizer.step()running_loss loss.item()if i % 2000 1999: # 每 2000 个 mini-batches 打印一次print(f[{epoch 1}, {i 1:5d}] loss: {running_loss / 2000:.3f})running_loss 0.0print(Finished Training) 二、尝试通过 ctrl 进入 ResNet 的内部观察残差究竟是什么
在 PyTorch 中如果你使用的是 PyCharm 等 IDE可以按住 Ctrl 键并点击 resnet18 函数进入 torchvision.models.resnet 模块。在该模块中可以找到 BasicBlock 类它实现了 ResNet 的残差块。
class BasicBlock(nn.Module):expansion 1def __init__(self, inplanes, planes, stride1, downsampleNone, groups1,base_width64, dilation1, norm_layerNone):super(BasicBlock, self).__init__()if norm_layer is None:norm_layer nn.BatchNorm2dif groups ! 1 or base_width ! 64:raise ValueError(BasicBlock only supports groups1 and base_width64)if dilation 1:raise NotImplementedError(Dilation 1 not supported in BasicBlock)# Both self.conv1 and self.downsample layers downsample the input when stride ! 1self.conv1 conv3x3(inplanes, planes, stride)self.bn1 norm_layer(planes)self.relu nn.ReLU(inplaceTrue)self.conv2 conv3x3(planes, planes)self.bn2 norm_layer(planes)self.downsample downsampleself.stride stridedef forward(self, x):identity xout self.conv1(x)out self.bn1(out)out self.relu(out)out self.conv2(out)out self.bn2(out)if self.downsample is not None:identity self.downsample(x)out identity # 这一行实现了残差连接out self.relu(out)return out
在 forward 方法中 out identity 这一行实现了残差连接。 identity 是输入的原始特征图 out 是经过两层卷积和批量归一化处理后的特征图将它们相加后再通过 ReLU 激活函数使得模型可以学习到输入和输出之间的残差信息。