国内网站开发公司,建设论坛网站自学,wordpress的搭建环境搭建,网站做收录是什么意思之前跑了一下mmdetection 3.x自带的一些算法, 但是具体的代码细节总是看了就忘, 所以想做一些笔记, 方便初学者参考. 其实比较不能忍的是, 官网的文档还是空的… 这次想写其中的数据流是如何运作的, 包括从读取数据集的样本与真值, 到数据增强, 再到模型的forward当中. 0. MMDe…
之前跑了一下mmdetection 3.x自带的一些算法, 但是具体的代码细节总是看了就忘, 所以想做一些笔记, 方便初学者参考. 其实比较不能忍的是, 官网的文档还是空的… 这次想写其中的数据流是如何运作的, 包括从读取数据集的样本与真值, 到数据增强, 再到模型的forward当中. 0. MMDetection整体组成部分
让我们首先回顾一下C的标准模板库(STL)是怎样设计的. STL的三个核心组件是容器, 算法与迭代器. 容器, 例如vector, queue等等, 他们是负责存储数据的, 算法是负责进行一些操作, 例如排序, 查找等等. 而迭代器是容器与算法之间的桥梁, 也就是算法可以通过迭代器去访问容器, 使得算法可以独立于容器的类型进行操作. 三个部分相辅相成, 就达到了泛型编程的理念.
再让我们回顾一下一套深度学习的代码包含什么部分. 从大的方面来说, 需要有数据的读取与增强(DataLoader), 模型的定义, 损失函数的计算, 负责梯度传播的优化器, 在验证(测试)集上的评估等. 同理, MMDetection也是按照这种方式来的, 并且每个部分接口相通, 就可以实现更广义的模型定义和训练方式.
在mmengine/registry/__init__.py中, 我们可以看到, MMEngine(或者说MMDetection)总体有这些类型的模块:
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS,INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS,MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,WEIGHT_INITIALIZERS)那么以上这么多模块可以分成几类, 分别负责什么呢? 按照我个人的理解, MMDetection的整体组成部分可以表示为下图: 为了节省空间, 优化器相关并未画出 1. 认识config文件
mmdetection设计的核心思想是通过字典来配置整个的训练过程和模型定义, 这些字典放在一个.py的config文件中. 一般来说config文件最重要的就是数据加载train_dataloader, val_dataloader和test_dataloader, 模型定义(model)和训练与测试过程(train_pipeline, test_pipeline). 除此之外, 还有一些训练, 测试配置(train_cfg, test_cfg)等等. 具体config的例子可以参照官网Learn about configs.
需要注意的是, mmdetection中字典定义class的方式, 往往是键type表示类的名字, 之后的其他键都是类初始化需要的参数. 例如, 如果我想自定义一个模型, 叫做MyModel, 定义在当前目录下的./models/my_model.py中, 定义方式如下: from mmdet.registry import MODELS # 自定义模型, 需要在模型库中注册, 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel # 一个模型基类MODELS.register_module() # 装饰器 在模型库中注册
class MyModel(BaseMOTModel):def __init__(self, arg1..., arg2..., arg3...):...def loss(self, inputs, data_samples): # 前向传播, inputs是输入tensor, data_samples是包含标签的列表...如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子: # 必须将自定义类的py文件导入 这样可以自动register自定义模型 否则模型初始化时找不到custom_imports dict(imports[models.my_model],allow_failed_importsFalse)# 现在就可以愉快的传参了
modelsdict(typeMyModel, arg11, arg2[16, 128], arg3dict(channel256), ...
)同样, 我们可以自定义DataLoader, Loss, 等等.
此外, dict是可以嵌套的, 例如mmdetection将检测模型分成了backbone, neck和head三部分, 那么如果我们又自定义了一个Head, 叫MyHead: from mmdet.registry import MODELS # 自定义模型, 需要在模型库中注册, 初始化时才能找到定义
from mmengine.model import BaseModule # 一个模型基类MODELS.register_module() # 装饰器 在模型库中注册
class MyHead(BaseModule):def __init__(self, arg4...):...
这样, 如果MyModel的前向传播过程中需要一个head, 则代码大致是这个样子: from mmdet.registry import MODELS # 自定义模型, 需要在模型库中注册, 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel # 一个模型基类MODELS.register_module() # 装饰器 在模型库中注册
class MyModel(BaseMOTModel):def __init__(self, arg1..., arg2..., arg3...,head...):self.head MODELS.build(head) # 建立Head的模型, 类型是nn.Module...def loss(self, inputs, data_samples): # 前向传播, inputs是输入tensor, data_samples是包含标签的列表... # 一些其他过程ret self.head(inputs) # forward... # 后处理
配置文件中对应更改为:
如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子: custom_imports dict(imports[models.my_model, 自定义HEAD所在的py文件],allow_failed_importsFalse)modelsdict(typeMyModel, arg11, arg2[16, 128], arg3dict(channel256), headdict( # 定义headtypeMyHead,arg4256,...)...
)篇幅所限, 自定义损失函数, 数据增强之类的就不一一列举了.
2. 数据流
我们接下来以检测与跟踪任务为例, 看看数据到底是如何被读入的. 我们以训练过程说明.
在训练过程中, 我们会初始化一个RUNNER类, 其读入我们的config文件并依次完成各种(模型, 数据加载, 优化器, 钩子等等)的初始化. 我们以官方提供的train.py为例:
runner Runner.from_cfg(cfg)from_cfg()是一个类方法(classmethod), 在其中我们实例化了Runner类.
随后, 我们调用Runner的train()方法进行训练. 首先, 我们实例化训练循环: self._train_loop self.build_train_loop(self._train_loop) # type: ignore
训练循环就属于LOOP类型.
在这里, 我们以最常用的EpochBasedTrainLoop为例. 在EpochBasedTrainLoop的初始化函数中, 根据config文件中的train_dataloader字典实例化出torch的DataLoader类(): data_loader DataLoader(datasetdataset,samplersampler if batch_sampler is None else None,batch_samplerbatch_sampler,collate_fncollate_fn,worker_init_fninit_fn,**dataloader_cfg)return data_loader
当然, 我们知道torch的DataLoader类在调用的时候, 会调用到dataset(类别是torch.utils.data.Dataset)的__getitem__方法. 因此, 我们从__getitem__入手来探索数据流.
在MMDetection的设计中, 数据集的类都是继承于MMengine中的BaseDataset, 其中的__getitem__是这样写的: def __getitem__(self, idx: int) - dict:if not self._fully_initialized:print_log(Please call full_init() method manually to accelerate the speed.,loggercurrent,levellogging.WARNING)self.full_init()if self.test_mode:data self.prepare_data(idx)if data is None:raise Exception(Test time pipline should not get None data_sample)return datafor _ in range(self.max_refetch 1):data self.prepare_data(idx)# Broken images or random augmentations may cause the returned data# to be Noneif data is None:idx self._rand_another()continuereturn dataraise Exception(fCannot find valid image after {self.max_refetch}! Please check your image path and pipeline)
我们可以看到, 在__getitem__中最核心的是self.prepare_data(idx). 按照这种思路一级一级向上查找, 我们就可以总结出如下图的数据读取流程: 其中, 数据增强pipeline是一系列类型为TRANSFORMS类的列表, 再每经过一次数据增强时, 字典都会被更新.
我们以较为常用的随机便宜(RandomShift)来说, 其是这样定义的: TRANSFORMS.register_module()
class RandomShift(BaseTransform):def __init__(self,...autocast_box_type()def transform(self, results: dict) - dict: # transform方法, 更新字典, 图像与对应的边界框等都需要被更新Transform function to random shift images, bounding boxes.Args:results (dict): Result dict from loading pipeline.Returns:dict: Shift results.if self._random_prob() self.prob:img_shape results[img].shape[:2]random_shift_x random.randint(-self.max_shift_px,self.max_shift_px)random_shift_y random.randint(-self.max_shift_px,self.max_shift_px)new_x max(0, random_shift_x)ori_x max(0, -random_shift_x)new_y max(0, random_shift_y)ori_y max(0, -random_shift_y)# TODO: support mask and semantic segmentation maps.bboxes results[gt_bboxes].clone()bboxes.translate_([random_shift_x, random_shift_y])# clip borderbboxes.clip_(img_shape)# remove invalid bboxesvalid_inds (bboxes.widths self.filter_thr_px).numpy() (bboxes.heights self.filter_thr_px).numpy()# If the shift does not contain any gt-bbox area, skip this# image.if not valid_inds.any():return resultsbboxes bboxes[valid_inds]results[gt_bboxes] bboxesresults[gt_bboxes_labels] results[gt_bboxes_labels][valid_inds]if results.get(gt_ignore_flags, None) is not None:results[gt_ignore_flags] \results[gt_ignore_flags][valid_inds]# shift imgimg results[img]new_img np.zeros_like(img)img_h, img_w img.shape[:2]new_h img_h - np.abs(random_shift_y)new_w img_w - np.abs(random_shift_x)new_img[new_y:new_y new_h, new_x:new_x new_w] \ img[ori_y:ori_y new_h, ori_x:ori_x new_w]results[img] new_imgreturn results需要注意的是, 经过pipeline后, 字典最终会被更新成如下形式:
dict {inputs: torch.Tensor, data_samples: DetDataSample或TrackDataSample等}
其中inputs键对应的值就是转换为tensor的图片, 而data_samples键对应的值是表示样本的类, 在检测任务中, 是DetDataSample, 跟踪任务中, 是TrackDataSample. DetDataSample类有许多成员, 包括该样本(图片)的目标的边界框真值, 分割真值等: class DetDataSample(BaseDataElement):A data structure interface of MMDetection. They are used as interfacesbetween different components.The attributes in DetDataSample are divided into several parts:- proposals(InstanceData): Region proposals used in two-stagedetectors.- gt_instances(InstanceData): Ground truth of instance annotations.- pred_instances(InstanceData): Instances of detection predictions.- pred_track_instances(InstanceData): Instances of trackingpredictions.- ignored_instances(InstanceData): Instances to be ignored duringtraining/testing.- gt_panoptic_seg(PixelData): Ground truth of panopticsegmentation.- pred_panoptic_seg(PixelData): Prediction of panopticsegmentation.- gt_sem_seg(PixelData): Ground truth of semantic segmentation.- pred_sem_seg(PixelData): Prediction of semantic segmentation.
以上过程可以借用MMEngine文档里的一个图说明: 最终, 模型的forward, loss, predict等方法都是接收inputs: torch.Tensor与data_samples作为输入, 例如: def loss(self, inputs: Tensor, data_samples: TrackSampleList,**kwargs) - Union[dict, tuple]: