佛山企业网站制作,洛阳网络科技有限公司排名,网站开发中的paml,做网站运用的技术timm 视觉库中的 create_model 函数详解
最近一年 Vision Transformer 及其相关改进的工作层出不穷#xff0c;在他们开源的代码中#xff0c;大部分都用到了这样一个库#xff1a;timm。各位炼丹师应该已经想必已经对其无比熟悉了#xff0c;本文将介绍其中最关键的函数之…timm 视觉库中的 create_model 函数详解
最近一年 Vision Transformer 及其相关改进的工作层出不穷在他们开源的代码中大部分都用到了这样一个库timm。各位炼丹师应该已经想必已经对其无比熟悉了本文将介绍其中最关键的函数之一create_model 函数。
timm简介
PyTorchImageModels简称timm是一个巨大的PyTorch代码集合包括了一系列
image modelslayersutilitiesoptimizersschedulersdata-loaders / augmentationstraining / validation scripts
旨在将各种 SOTA 模型、图像实用工具、常用的优化器、训练策略等视觉相关常用函数的整合在一起并具有复现ImageNet训练结果的能力。 源码https://github.com/rwightman/pytorch-image-models 文档https://fastai.github.io/timmdocs/ create_model 函数的使用及常用参数
本小节先介绍 create_model 函数及常用的参数 **kwargs。
顾名思义create_model 函数是用来创建一个网络模型如 ResNet、ViT 等timm 库本身可供直接调用的模型已有接近400个用户也可以自己实现一些模型并注册进 timm 这一部分内容将在下一小节着重介绍供自己调用。
model_name
我们首先来看最简单地用法直接传入模型名称 model_name
import timm
# 创建 resnet-34
model timm.create_model(resnet34)
# 创建 efficientnet-b0
model timm.create_model(efficientnet_b0)我们可以通过 list_models 函数来查看已经可以直接创建、有预训练参数的模型列表
all_pretrained_models_available timm.list_models(pretrainedTrue)
print(all_pretrained_models_available)
print(len(all_pretrained_models_available))输出
[..., vit_large_patch16_384, vit_large_patch32_224_in21k, vit_large_patch32_384, vit_small_patch16_224, wide_resnet50_2, wide_resnet101_2, xception, xception41, xception65, xception71]
452如果没有设置 pretrainedTrue 的话有将会输出612即有预训练权重参数的模型有452个没有预训练参数只有模型结构的共有612个。
pretrained
如果我们传入 pretrainedTrue那么 timm 会从对应的 URL 下载模型权重参数并载入模型只有当第一次即本地还没有对应模型参数时会去下载之后会直接从本地加载模型权重参数。
model timm.create_model(resnet34, pretrainedTrue)输出
Downloading: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pthfeatures_only、out_indices
create_mode 函数还支持 features_onlyTrue 参数此时函数将返回部分网络该网络提取每一步最深一层的特征图。还可以使用 out_indices[…] 参数指定层的索引以提取中间层特征。
# 创建一个 (1, 3, 224, 224) 形状的张量
x torch.randn(1, 3, 224, 224)
model timm.create_model(resnet34)
preds model(x)
print(preds shape: {}.format(preds.shape))all_feature_extractor timm.create_model(resnet34, features_onlyTrue)
all_features all_feature_extractor(x)
print(All {} Features: .format(len(all_features)))
for i in range(len(all_features)):print(feature {} shape: {}.format(i, all_features[i].shape))out_indices [2, 3, 4]
selected_feature_extractor timm.create_model(resnet34, features_onlyTrue, out_indicesout_indices)
selected_features selected_feature_extractor(x)
print(Selected Features: )
for i in range(len(out_indices)):print(feature {} shape: {}.format(out_indices[i], selected_features[i].shape))我们以一个 (1, 3, 224, 224) 形状的张量为输入在视觉任务中图像输入张量总是类似的形状。上面例程展示了创建完整模型 model创建完整特征提取器 all_feature_extractor和创建某几层特征提取器 selected_feature_extractor 的具体输出。
可以结合下面 ResNet34 的结构图来理解图中不同的颜色表示不同的 layer根据下图分析各层的卷积操作计算各层最后一个卷积的输入并与上面例程的输出附在图后验证是否一致。 输出
preds shape: torch.Size([1, 1000])
All 5 Features:
feature 0 shape: torch.Size([1, 64, 112, 112])
feature 1 shape: torch.Size([1, 64, 56, 56])
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])
Selected Features:
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])这样我们就可以通过 timm_model 函数及其 features_only 、out_indices 参数将预训练模型方便地转换为自己想要的特征提取器。
接下来我们来看一下这些特征提取器究竟是什么类型
import timm
feature_extractor timm.create_model(resnet34, features_onlyTrue, out_indices[3])print(type:, type(feature_extractor))
print(len: , len(feature_extractor))
for item in feature_extractor:print(item)输出
type: class timm.models.features.FeatureListNet
len: 7
conv1
bn1
act1
maxpool
layer1
layer2
layer3可以看到feature_extractor 其实也是一个神经网络在 timm 中称为 FeatureListNet而我们通过 out_indices 参数来指定截取到哪一层特征。
需要注意的是ViT 模型并不支持 features_only 选项0.4.12版本。
extractor timm.create_model(vit_base_patch16_224, features_onlyTrue)输出
RuntimeError: features_only not implemented for Vision Transformer models.create_model 函数究竟做了什么
registry
在了解了 create_model 函数的基本使用之后我们来深入探索一下 create_model 函数的源码看一下究竟是怎样实现从模型到特征提取器的转换的。
create_model 主体只有 50 行左右的代码因此所有这些神奇的事情是在其他地方完成的。我们知道 timm.list_models() 函数中的每一个模型名字str实际上都是一个函数。以下代码可以测试这一点
import timm
import random
from timm.models import registrym timm.list_models()[-1]
print(m)
registry.is_model(m)输出
xception71
True实际上在 timm 内部有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说我们可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。
constuctor_fn registry.model_entrypoint(m)
print(constuctor_fn)输出
function timm.models.xception_aligned.xception71(pretrainedFalse, **kwargs)也有可能输出
function xception71 at 0x7fc0cba0eca0一样的。
如我们所见在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的timm 中的每一个模型都有着一个这样的构造函数。事实上内部的 _model_entrypoints 字典大概长这个样子
_model_entrypoints{
cspresnet50:function timm.models.cspnet.cspresnet50(pretrainedFalse, **kwargs),cspresnet50d: function timm.models.cspnet.cspresnet50d(pretrainedFalse, **kwargs),
cspresnet50w: function timm.models.cspnet.cspresnet50w(pretrainedFalse, **kwargs),
cspresnext50: function timm.models.cspnet.cspresnext50(pretrainedFalse, **kwargs),
cspresnext50_iabn: function timm.models.cspnet.cspresnext50_iabn(pretrainedFalse, **kwargs),
cspdarknet53: function timm.models.cspnet.cspdarknet53(pretrainedFalse, **kwargs),
cspdarknet53_iabn: function timm.models.cspnet.cspdarknet53_iabn(pretrainedFalse, **kwargs),
darknet53: function timm.models.cspnet.darknet53(pretrainedFalse, **kwargs),
densenet121: function timm.models.densenet.densenet121(pretrainedFalse, **kwargs),
densenetblur121d: function timm.models.densenet.densenetblur121d(pretrainedFalse, **kwargs),
densenet121d: function timm.models.densenet.densenet121d(pretrainedFalse, **kwargs),
densenet169: function timm.models.densenet.densenet169(pretrainedFalse, **kwargs),
densenet201: function timm.models.densenet.densenet201(pretrainedFalse, **kwargs),
densenet161: function timm.models.densenet.densenet161(pretrainedFalse, **kwargs),
densenet264: function timm.models.densenet.densenet264(pretrainedFalse, **kwargs),}所以说在 timm 对应的模块中每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此实际上我们有两种方式来创建一个 resnet34 模型
import timm
from timm.models.resnet import resnet34# 使用 create_model
m timm.create_model(resnet34)# 直接调用构造函数
m resnet34()但使用上我们无须调用构造函数。所用模型都可以通过 create_model 函数来将创建。
Register model
resnet34 构造函数的源码如下
register_model
def resnet34(pretrainedFalse, **kwargs):Constructs a ResNet-34 model.model_args dict(blockBasicBlock, layers[3, 4, 6, 3], **kwargs)return _create_resnet(resnet34, pretrained, **model_args)我们会发现 timm 中的每个模型都有一个 register_model 装饰器。最开始 _model_entrypoints 是一个空字典。我们是通过 register_model 装饰器来不断地像其中添加模型名称和它对应的构造函数。该装饰器的定义如下
def register_model(fn):# lookup containing modulemod sys.modules[fn.__module__]module_name_split fn.__module__.split(.)module_name module_name_split[-1] if len(module_name_split) else # add model to __all__ in modulemodel_name fn.__name__if hasattr(mod, __all__):mod.__all__.append(model_name)else:mod.__all__ [model_name]# add entries to registry dict/sets_model_entrypoints[model_name] fn_model_to_module[model_name] module_name_module_to_models[module_name].add(model_name)has_pretrained False # check if model has a pretrained url to allow filtering on thisif hasattr(mod, default_cfgs) and model_name in mod.default_cfgs:# this will catch all models that have entrypoint matching cfg key, but miss any aliasing# entrypoints or non-matching comboshas_pretrained url in mod.default_cfgs[model_name] and http in mod.default_cfgs[model_name][url]if has_pretrained:_model_has_pretrained.add(model_name)return fn我们可以看到 register_model 函数完成了一些比较基础的步骤但这里需要指出的是这一句
_model_entrypoints[model_name] fn它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.__name__。所以说 resnet34 函数上的装饰器 register_model 在 _model_entrypoints 中创建一个新的条目像这样
{#8217;resnet34#8217;: function timm.models.resnet.resnet34(pretrainedFalse, **kwargs)}我们同样可以看到在 resnet34 构造函数的源码中在设置完一些 model_args 之后它会随后调用 _create_resnet 函数。让我们再来看一下该函数的源码
def _create_resnet(variant, pretrainedFalse, **kwargs):return build_model_with_cfg(ResNet, variant, default_cfgdefault_cfgs[variant], pretrainedpretrained, **kwargs)所以在 _create_resnet 函数之中会再调用 build_model_with_cfg 函数并将一个构造器类 ResNet 、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。
default config
timm 中所有的模型都有一个默认的配置包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。
resnet34 的默认配置如下
{url: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth,
num_classes: 1000,
input_size: (3, 224, 224),
pool_size: (7, 7),
crop_pct: 0.875,
interpolation: bilinear,
mean: (0.485, 0.456, 0.406),
std: (0.229, 0.224, 0.225),
first_conv: conv1,
classifier: fc}此默认配置与其他参数如构造函数类和一些模型参数一起传递给 build_model_with_cfg 函数。
build model with config
这个 build_model_with_cfg 函数负责
真正地实例化一个模型类来创建一个模型若 prunedTrue对模型进行剪枝若 pretrainedTrue加载预训练模型参数若 features_onlyTrue将模型转换为特征提取器
看一下该函数的源码
def build_model_with_cfg(model_cls: Callable,variant: str,pretrained: bool,default_cfg: dict,model_cfg: dict None,feature_cfg: dict None,pretrained_strict: bool True,pretrained_filter_fn: Callable None,pretrained_custom_load: bool False,**kwargs):pruned kwargs.pop(pruned, False)features Falsefeature_cfg feature_cfg or {}if kwargs.pop(features_only, False):features Truefeature_cfg.setdefault(out_indices, (0, 1, 2, 3, 4))if out_indices in kwargs:feature_cfg[out_indices] kwargs.pop(out_indices)model model_cls(**kwargs) if model_cfg is None else model_cls(cfgmodel_cfg, **kwargs)model.default_cfg deepcopy(default_cfg)if pruned:model adapt_model_from_file(model, variant)# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for featsnum_classes_pretrained 0 if features else getattr(model, num_classes, kwargs.get(num_classes, 1000))if pretrained:if pretrained_custom_load:load_custom_pretrained(model)else:load_pretrained(model,num_classesnum_classes_pretrained, in_chanskwargs.get(in_chans, 3),filter_fnpretrained_filter_fn, strictpretrained_strict)if features:feature_cls FeatureListNetif feature_cls in feature_cfg:feature_cls feature_cfg.pop(feature_cls)if isinstance(feature_cls, str):feature_cls feature_cls.lower()if hook in feature_cls:feature_cls FeatureHookNetelse:assert False, fUnknown feature class {feature_cls}model feature_cls(model, **feature_cfg)model.default_cfg default_cfg_for_features(default_cfg) # add back default_cfgreturn model我们可以看到模型在这一步被创建出来model model_cls(**kwargs)。本文将不再深入到 pruned 和 adapt_model_from_file 内部查看。
总结
通过本文我们已经完全了解了 create_model 函数我们了解到
每个模型有不同的构造函数可以传入不同的参数 _model_entrypoints 字典包括了所有的模型名称及其对应的构造函数build_with_model_cfg 函数接收模型构造器类和其中的一些具体参数真正地实例化一个模型load_pretrained 会加载预训练参数FeatureListNet 类可以将模型转换为特征提取器
Ref
https://github.com/rwightman/pytorch-image-models
https://fastai.github.io/timmdocs/
https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor
https://fastai.github.io/timmdocs/tutorial_feature_extractor
https://zhuanlan.zhihu.com/p/404107277