如何防止网站被注入黑链,我国酒店网站建设存在的问题,软件用户界面设计,网站建设只有一个空间吗在学习VIT之前#xff0c;建议先把 Transformer 搞明白了#xff1a;【transformer】入门与理解
做了那些改进#xff1f; 看图就比较明白了#xff0c;VIT只用了Encoder的部分#xff0c;把每一个图片裁剪成若干子图#xff0c;然后把一个子图flatten一下#xff0c;…在学习VIT之前建议先把 Transformer 搞明白了【transformer】入门与理解
做了那些改进 看图就比较明白了VIT只用了Encoder的部分把每一个图片裁剪成若干子图然后把一个子图flatten一下当成nlp中的一个token处理。 值得注意的是在首个 token中嵌入了一个 class_token维度为1embed_dim768这个class_token在预测的时候比较有意思见下图 注意上图中有些细节遗漏全流程应该是先把输入进行 patch_embedding 变成 visual tokens然后和 class_token 合并最后 position_embedding。
另外需要注意的是class_token 是一个可学习的参数并不是每次输入时都需要输入的类别数值。
self.class_token nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98) #(1,1,768)代码
其实有了 Transformer 的基础后直接看代码就知道VIT是怎么做的了。
import copy
import torch
import torch.nn as nn# 所有基于nn.Module结构的模版可以删掉
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass Mlp(nn.Module):def __init__(self, embed_dim, mlp_ratio, dropout0.):super().__init__()self.fc1 nn.Linear(embed_dim, int(embed_dim * mlp_ratio)) # 中间层扩增self.fc2 nn.Linear(int(embed_dim * mlp_ratio), embed_dim)self.act nn.GELU()self.dropout nn.Dropout(dropout)def forward(self, x):# TODOx self.fc1(x)x self.act(x)x self.dropout(x)x self.fc2(x)x self.dropout(x)return xclass PatchEmbedding(nn.Module):def __init__(self, image_size224, patch_size16, in_channels3, embed_dim768, dropout0.):super().__init__()n_patches (image_size // patch_size) * (image_size // patch_size) # 196 个 patchself.patch_embedding nn.Conv2d(in_channelsin_channels, # embedding 操作后变成 torch.Size([10, 768, 14, 14])out_channelsembed_dim,kernel_sizepatch_size,stridepatch_size)self.dropout nn.Dropout(dropout)# TODO: add class tokenself.class_token nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98) #(1,1,768)# TODO: add position embeddingself.position_embedding nn.Parameter(torch.ones(1, n_patches1, embed_dim) * 0.98) #(1,1961,768)def forward(self, x): # 先把 x patch_embedding然后和 class_token 合并最后 position_embedding# [n, c, h, w]cls_tokens self.class_token.expand([x.shape[0], -1, -1]) #(10,1,768) 根据batch扩增 class_tokenx self.patch_embedding(x) # [n, embed_dim, h, w]x x.flatten(2) # torch.Size([10, 768, 196])x x.permute([0, 2, 1]) # torch.Size([10, 196, 768])x torch.concat([cls_tokens, x], axis1) # (10,1961,768)x x self.position_embeddingreturn x # torch.Size([10, 197, 768])class Attention(nn.Module):multi-head self attentiondef __init__(self, embed_dim, num_heads, qkv_biasTrue, dropout0., attention_dropout0.):super().__init__()self.num_heads num_headsself.head_dim int(embed_dim / num_heads) # 768/4192self.all_head_dim self.head_dim * num_headsself.scales self.head_dim ** -0.5self.qkv nn.Linear(embed_dim,self.all_head_dim * 3) # [768, 768*3]self.proj nn.Linear(embed_dim, embed_dim)self.dropout nn.Dropout(dropout)self.attention_dropout nn.Dropout(attention_dropout)self.softmax nn.Softmax()def transpose_multihead(self, x):# x: [N, num_patches 197, all_head_dim 768] - [N, n_heads, num_patches, head_dim]new_shape [x.shape[:-1][0], x.shape[:-1][1], self.num_heads, self.head_dim] # [10, 197, 4, 192]x x.reshape(new_shape) x x.permute([0, 2, 1, 3]) # [10, 4, 197, 192]return xdef forward(self, x): # Attention 前后输入输出维度不变都是 [10, 197, 768]B, N, _ x.shape # torch.Size([10, 197, 768])qkv self.qkv(x).chunk(3, axis-1) # 含有三个元素的列表每一个元素大小 [10, 197, 768]q, k, v map(self.transpose_multihead, qkv) # [10, 4, 197, 192]attn torch.matmul(q, k.transpose(2,3)) # [10, 4, 197, 197]attn attn * self.scalesattn self.softmax(attn)attn self.attention_dropout(attn)out torch.matmul(attn, v) # [10, 4, 197, 192]out out.permute([0, 2, 1, 3]) # [10, 197, 4, 192]out out.reshape([B, N, -1]) # [10, 197, 768]out self.proj(out) # [10, 197, 768]out self.dropout(out)return outclass EncoderModule(nn.Module):def __init__(self, embed_dim768, num_heads4, qkv_biasTrue, mlp_ratio4.0, dropout0., attention_dropout0.):super().__init__()self.attn_norm nn.LayerNorm(embed_dim)self.attn Attention(embed_dim, num_heads)self.mlp_norm nn.LayerNorm(embed_dim)self.mlp Mlp(embed_dim, mlp_ratio)def forward(self, x):h x # residualx self.attn_norm(x)x self.attn(x)x x hh x # residualx self.mlp_norm(x)x self.mlp(x)x x hreturn xclass Encoder(nn.Module):def __init__(self, embed_dim, depth):super().__init__()Module_list []for i in range(depth):encoder_Module EncoderModule()Module_list.append(encoder_Module)self.Modules nn.ModuleList(Module_list)self.norm nn.LayerNorm(embed_dim)def forward(self, x):for Module in self.Modules:x Module(x)x self.norm(x)return xclass VisualTransformer(nn.Module):def __init__(self,image_size224,patch_size16,in_channels3,num_classes1000,embed_dim768,depth3,num_heads8,):super().__init__()self.patch_embedding PatchEmbedding(image_size, patch_size, in_channels, embed_dim)self.encoder Encoder(embed_dim, depth)self.classifier nn.Linear(embed_dim, num_classes)def forward(self, x):# x:[N, C, H, W]x self.patch_embedding(x) # torch.Size([10, 197, 768])x self.encoder(x) # torch.Size([10, 197, 768])x self.classifier(x[:, 0]) # 注意这里的处理很奇妙哦参考 x torch.concat([cls_tokens, x], axis1) # (10,1961,768)return xvit VisualTransformer()
print(vit)input_data torch.randn([10,3,224,224]) # 每批次输入10张图片
print(vit(input_data).shape) # torch.Size([10, 1000])