住房与城乡建设部网站,卖文章的网站源码,专门做动漫的网站有哪些,网站产品展示代码#x1f349; 博主微信 cvxiayixiao 还有其他专栏点击头像查询 #x1f353; 【Segment Anything Model】计算机视觉检测分割任务专栏。 #x1f351; 【公开数据集预处理】特别是医疗公开数据集的接受和预处理#xff0c;提供代码讲解。 #x1f348; 【opencv图像处理】… 博主微信 cvxiayixiao 还有其他专栏点击头像查询 【Segment Anything Model】计算机视觉检测分割任务专栏。 【公开数据集预处理】特别是医疗公开数据集的接受和预处理提供代码讲解。 【opencv图像处理】opencv代码库讲解结合图像处理知识不仅仅是调库。 文章目录 1️⃣预备知识EfficientSAM要解决的问题EfficientSAM解决此问题的创新点知识蒸馏和利用掩码图像去做预训练架构解释上半部分作用上半部分采用的方法下半部分作用EfficientSAM的结果 2️⃣EfficientSAM用于自己的数据集代码处理数据集将Efficient-SAM代码和权重拷贝到服务器或者本地去官网git下载权重愉快训练单卡多卡 1️⃣预备知识 EfficientSAM要解决的问题 sam本身架构庞大训练和推理都很慢。 EfficientSAM解决此问题的创新点 知识蒸馏和利用掩码图像去做预训练 具体来说 利用掩膜图像预训练SAMI来学习从SAM图像编码器中重构特征以有效进行视觉表示学习。这是提高EfficientSAMs效率和性能的核心策略。蒸馏到轻量级图像编码器和掩码解码器: 采用SAMI预训练的轻量级图像编码器和掩码解码器构建EfficientSAMs进一步降低模型复杂度同时保持良好的性能。在SA-1B数据集上的微调: 经过简化的架构在SA-1B上微调。还可以继续做下游任务包括图像分类、对象检测、实例分割和语义对象检测。 架构解释 SAMI预训练上半部分在ImageNet上进行而SAM微调下半部分则在SA-1B数据集上进行。 上半部分作用 上半部分是为了得到轻量级编码器因为sam本身的笨重的编码器是架构复杂推理慢的根本原因。 上半部分采用的方法 图像掩码预训练和损失重建。轻量级编码器学习重构来自于SAM VIT-H的图像编码器的特征嵌入这样做可以使得很小的编码器也能输出和SAM VIT-H一样的编码特征。这个过程算是只是蒸馏小模型学习目标是大模型的软输出一个较大的“教师”模型向一个较小的“学生”模型传递知识。 下半部分作用 上半部分训练了一个和VIT-H一样表现的模型但是他并不具备segment anything的能力meta的好多产品都是数据赋予模型能力所以还是得加大数据集之后在小模型上用大数据训练赋能。同时训练了提示解码器这里的提示可以点可以框。 EfficientSAM的结果 EfficientSAM-S 将 SAM 的推理时间减少了约 20 倍参数大小减少了约 20 倍性能略有下降为 44.4 AP sam为 46.5 AP。 同时与近期工作对比中约4AP高于MobileSAM FastSAM而且处理复杂度差不多。 2️⃣EfficientSAM用于自己的数据集代码 微调的方式太多啦比如只需要其编码器作为提特征的主干模块后面接自己的分类分割下游任务这个我也写了但是暂时不考虑公开。比如需要全部的模型架构但是自己数据集的域不同所以要全量微调。比如将sam的主干和其他主干做模型融合/特征融合。方式很多我们这里介绍最简单的也是其他一切fashion变形的基础。 我们这里介绍全量微调没加任何的fine-tune方法只是全量更新模型参数适配自己的数据集。 提示选择的一个框提示。 训练方式有单卡和多卡。 处理数据集 每个人的数据存储方式和格式都不一样前面几篇有处理成npy的代码可以参考上一篇有根据路径编写Dataset.py的代码可以参考。因为我的数据量比较大需要一个高效的存取方式没过一万长的随便存就行自己写个Dataset.py 能输出img和label张量就行。 这里给到一个npy存储格式构造Dataset的代码以供参考实际根据自己的数据集情况写一个能对上模型输入就行。 特别注意一下img label bboxes boxes_labels的维度就可以 class NpyDataset(Dataset):def __init__(self, data_root, image_size256, bbox_shift5, data_augTrue):self.data_root data_rootself.gt_path join(data_root, gts)self.img_path 
join(data_root, imgs)self.gt_path_files sorted(glob(join(self.gt_path, *.npy), recursiveTrue))self.gt_path_files [file for file in self.gt_path_filesif isfile(join(self.img_path, basename(file)))]self.image_size image_sizeself.target_length image_sizeself.bbox_shift bbox_shiftself.data_aug data_augdef __len__(self):return len(self.gt_path_files)def __getitem__(self, index):img_name basename(self.gt_path_files[index])assert img_name basename(self.gt_path_files[index]), img gt name error self.gt_path_files[index] \self.npy_files[index]img_3c np.load(join(self.img_path, img_name), r, allow_pickleTrue) # (H, W, 3)img_resize self.resize_longest_side(img_3c)# Resizingimg_resize (img_resize - img_resize.min()) / np.clip(img_resize.max() - img_resize.min(), a_min1e-8,a_maxNone) # normalize to [0, 1], (H, W, 3img_padded self.pad_image(img_resize) # (256, 256, 3)# convert the shape to (3, H, W)img_padded np.transpose(img_padded, (2, 0, 1)) # (3, 256, 256)assert np.max(img_padded) 1.0 and np.min(img_padded) 0.0, image should be normalized to [0, 1]gt np.load(self.gt_path_files[index], r, allow_pickleTrue) # multiple labels [0, 1,4,5...], (256,256)gt cv2.resize(gt,(img_resize.shape[1], img_resize.shape[0]),interpolationcv2.INTER_NEAREST).astype(np.uint8)gt self.pad_image(gt) # (256, 256)label_ids np.unique(gt)[1:]try:gt2D np.uint8(gt random.choice(label_ids.tolist())) # only one label, (256, 256)except:print(img_name, label_ids.tolist(), label_ids.tolist())gt2D np.uint8(gt np.max(gt)) # only one label, (256, 256)# add data augmentation: random fliplr and random flipudif self.data_aug:if random.random() 0.5:img_padded np.ascontiguousarray(np.flip(img_padded, axis-1))gt2D np.ascontiguousarray(np.flip(gt2D, axis-1))# print(DA with flip left right)if random.random() 0.5:img_padded np.ascontiguousarray(np.flip(img_padded, axis-2))gt2D np.ascontiguousarray(np.flip(gt2D, axis-2))# print(DA with flip upside down)gt2D np.uint8(gt2D 0)y_indices, x_indices np.where(gt2D 0)x_min, x_max 
np.min(x_indices), np.max(x_indices)y_min, y_max np.min(y_indices), np.max(y_indices)# add perturbation to bounding box coordinatesH, W gt2D.shapex_min max(0, x_min - random.randint(0, self.bbox_shift))x_max min(W, x_max random.randint(0, self.bbox_shift))y_min max(0, y_min - random.randint(0, self.bbox_shift))y_max min(H, y_max random.randint(0, self.bbox_shift))# bboxes np.array([x_min, y_min, x_max, y_max])bboxes np.array([[x_min, y_min], [x_max, y_max]])boxes_1 torch.reshape(torch.tensor(bboxes), [1, 1, -1, 2])input_label np.array([2, 3])boxes_1 torch.tensor(bboxes)[None, :] # boxes_1形状为[1, 2, 2]pts_labels torch.tensor(input_label)[None, :] # pts_labels形状为[1, 2]return {image: torch.tensor(img_padded).float(),gt2D: torch.tensor(gt2D[None, :, :]).long(),# bboxes: torch.tensor(bboxes[None, None, ...]).float(), # (B, 1, 4)bboxes: boxes_1, # efficient模型需要这样的维度 (B, 1, 2, 2)boxes_labels: pts_labels,image_name: img_name,new_size: torch.tensor(np.array([img_resize.shape[0], img_resize.shape[1]])).long(),original_size: torch.tensor(np.array([img_3c.shape[0], img_3c.shape[1]])).long()}def resize_longest_side(self, image):Expects a numpy array with shape HxWxC in uint8 format.long_side_length self.target_lengtholdh, oldw image.shape[0], image.shape[1]scale long_side_length * 1.0 / max(oldh, oldw)newh, neww oldh * scale, oldw * scaleneww, newh int(neww 0.5), int(newh 0.5)target_size (neww, newh)return cv2.resize(image, target_size, interpolationcv2.INTER_AREA)def pad_image(self, image):Expects a numpy array with shape HxWxC in uint8 format.# Padh, w image.shape[0], image.shape[1]padh self.image_size - hpadw self.image_size - wif len(image.shape) 3: ## Pad imageimage_padded np.pad(image, ((0, padh), (0, padw), (0, 0)))else: ## Pad gt maskimage_padded np.pad(image, ((0, padh), (0, padw)))return image_padded将Efficient-SAM代码和权重拷贝到服务器或者本地 新建一个Net文件夹将efficient-sam放下面 复制下面代码放Net下面 small_efficient_sam_encoder_config.py # from Dataset.Dataset import train_loader from 
Net.efficient_sam.efficient_sam_encoder import ImageEncoderViT from torch import nn, Tensorimg_size 1024 encoder_patch_size 16 encoder_depth 12 encoder_mlp_ratio 4.0 encoder_neck_dims [256, 256] decoder_max_num_input_points 6 decoder_transformer_depth 2 decoder_transformer_mlp_dim 2048 decoder_num_heads 8 decoder_upscaling_layer_dims [64, 32] num_multimask_outputs 3 iou_head_depth 3 iou_head_hidden_dim 256 activation gelu normalization_type layer_norm normalize_before_activation False small_efficient_sam_encoder ImageEncoderViT(img_sizeimg_size,patch_sizeencoder_patch_size,in_chans3,# small vitpatch_embed_dim384,normalization_typenormalization_type,depthencoder_depth,# small vitnum_heads6,mlp_ratioencoder_mlp_ratio,neck_dimsencoder_neck_dims,act_layernn.ReLU, ) tiny_efficient_sam_encoder_config.py # from Dataset.Dataset import train_loader from Net.efficient_sam.efficient_sam_encoder import ImageEncoderViT from torch import nn, Tensorimg_size 1024 encoder_patch_size 16 encoder_depth 12 encoder_mlp_ratio 4.0 encoder_neck_dims [256, 256] decoder_max_num_input_points 6 decoder_transformer_depth 2 decoder_transformer_mlp_dim 2048 decoder_num_heads 8 decoder_upscaling_layer_dims [64, 32] num_multimask_outputs 3 iou_head_depth 3 iou_head_hidden_dim 256 activation gelu normalization_type layer_norm normalize_before_activation False tiny_efficient_sam_encoder ImageEncoderViT(img_sizeimg_size,patch_sizeencoder_patch_size,in_chans3,# small vitpatch_embed_dim192,normalization_typenormalization_type,depthencoder_depth,# small vitnum_heads3,mlp_ratioencoder_mlp_ratio,neck_dimsencoder_neck_dims,act_layernn.ReLU, ) 去官网git下载权重 放到项目下面新建weights文件夹只需要下载这两个就行下载之后把zip解压 愉快训练 这里有一些参数需要根据自己的实际情况更改 单卡 # %% import os import random import monai from os import listdir, makedirs from os.path import join, exists, isfile, isdir, basename from glob import glob from tqdm import tqdm, trange from copy import deepcopy from time import time import numpy as np import torch import torch.nn as nn 
import torch.optim as optim from torch.utils.data import Dataset, DataLoader from datetime import datetimeimport cv2 import torch.nn.functional as Ffrom matplotlib import pyplot as plt import argparse# %% parser argparse.ArgumentParser() parser.add_argument(-data_root, typestr, defaulttrain_npy,helpPath to the npy data root. ) parser.add_argument(-pretrained_checkpoint, typestr, defaultlite_medsam.pth,helpPath to the pretrained Lite-MedSAM checkpoint. )parser.add_argument(-work_dir, typestr, default./workdir,helpPath to the working directory where checkpoints and logs will be saved. ) parser.add_argument(-num_epochs, typeint, default10,helpNumber of epochs to train. ) parser.add_argument(-batch_size, typeint, default4,helpBatch size. ) parser.add_argument(-num_workers, typeint, default8,helpNumber of workers for dataloader. ) parser.add_argument(-device, typestr, defaultcuda:1,helpDevice to train on. ) parser.add_argument(-bbox_shift, typeint, default5,helpPerturbation to bounding box coordinates during training. ) parser.add_argument(-lr, typefloat, default0.00005,helpLearning rate. ) parser.add_argument(-weight_decay, typefloat, default0.01,helpWeight decay. ) parser.add_argument(-iou_loss_weight, typefloat, default1.0,helpWeight of IoU loss. ) parser.add_argument(-seg_loss_weight, typefloat, default1.0,helpWeight of segmentation loss. ) parser.add_argument(-ce_loss_weight, typefloat, default1.0,helpWeight of cross entropy loss. ) parser.add_argument(--sanity_check, actionstore_true,helpWhether to do sanity check for dataloading. 
)args parser.parse_args() # %% work_dir args.work_dir data_root args.data_root medsam_lite_checkpoint args.pretrained_checkpoint num_epochs args.num_epochs batch_size args.batch_size num_workers args.num_workers device args.device bbox_shift args.bbox_shift lr args.lr weight_decay args.weight_decay iou_loss_weight args.iou_loss_weight seg_loss_weight args.seg_loss_weight ce_loss_weight args.ce_loss_weight do_sancheck args.sanity_check checkpoint args.resumemakedirs(work_dir, exist_okTrue)# %% torch.cuda.empty_cache() os.environ[OMP_NUM_THREADS] 4 # export OMP_NUM_THREADS4 os.environ[OPENBLAS_NUM_THREADS] 4 # export OPENBLAS_NUM_THREADS4 os.environ[MKL_NUM_THREADS] 6 # export MKL_NUM_THREADS6 os.environ[VECLIB_MAXIMUM_THREADS] 4 # export VECLIB_MAXIMUM_THREADS4 os.environ[NUMEXPR_NUM_THREADS] 6 # export NUMEXPR_NUM_THREADS6def show_mask(mask, ax, random_colorFalse):if random_color:color np.concatenate([np.random.random(3), np.array([0.45])], axis0)else:color np.array([251 / 255, 252 / 255, 30 / 255, 0.45])h, w mask.shape[-2:]mask_image mask.reshape(h, w, 1) * color.reshape(1, 1, -1)ax.imshow(mask_image)def show_box(box, ax):x0, y0 box[0], box[1]w, h box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolorblue, facecolor(0, 0, 0, 0), lw2))def cal_iou(result, reference):intersection torch.count_nonzero(torch.logical_and(result, reference), dim[i for i in range(1, result.ndim)])union torch.count_nonzero(torch.logical_or(result, reference), dim[i for i in range(1, result.ndim)])iou intersection.float() / union.float()return iou.unsqueeze(1)# %% class NpyDataset(Dataset):def __init__(self, data_root, image_size256, bbox_shift5, data_augTrue):self.data_root data_rootself.gt_path join(data_root, gts)self.img_path join(data_root, imgs)self.gt_path_files sorted(glob(join(self.gt_path, *.npy), recursiveTrue))self.gt_path_files [file for file in self.gt_path_filesif isfile(join(self.img_path, basename(file)))]self.image_size 
image_sizeself.target_length image_sizeself.bbox_shift bbox_shiftself.data_aug data_augdef __len__(self):return len(self.gt_path_files)def __getitem__(self, index):img_name basename(self.gt_path_files[index])assert img_name basename(self.gt_path_files[index]), img gt name error self.gt_path_files[index] \self.npy_files[index]img_3c np.load(join(self.img_path, img_name), r, allow_pickleTrue) # (H, W, 3)img_resize self.resize_longest_side(img_3c)# Resizingimg_resize (img_resize - img_resize.min()) / np.clip(img_resize.max() - img_resize.min(), a_min1e-8,a_maxNone) # normalize to [0, 1], (H, W, 3img_padded self.pad_image(img_resize) # (256, 256, 3)# convert the shape to (3, H, W)img_padded np.transpose(img_padded, (2, 0, 1)) # (3, 256, 256)assert np.max(img_padded) 1.0 and np.min(img_padded) 0.0, image should be normalized to [0, 1]gt np.load(self.gt_path_files[index], r, allow_pickleTrue) # multiple labels [0, 1,4,5...], (256,256)gt cv2.resize(gt,(img_resize.shape[1], img_resize.shape[0]),interpolationcv2.INTER_NEAREST).astype(np.uint8)gt self.pad_image(gt) # (256, 256)label_ids np.unique(gt)[1:]try:gt2D np.uint8(gt random.choice(label_ids.tolist())) # only one label, (256, 256)except:print(img_name, label_ids.tolist(), label_ids.tolist())gt2D np.uint8(gt np.max(gt)) # only one label, (256, 256)# add data augmentation: random fliplr and random flipudif self.data_aug:if random.random() 0.5:img_padded np.ascontiguousarray(np.flip(img_padded, axis-1))gt2D np.ascontiguousarray(np.flip(gt2D, axis-1))# print(DA with flip left right)if random.random() 0.5:img_padded np.ascontiguousarray(np.flip(img_padded, axis-2))gt2D np.ascontiguousarray(np.flip(gt2D, axis-2))# print(DA with flip upside down)gt2D np.uint8(gt2D 0)y_indices, x_indices np.where(gt2D 0)x_min, x_max np.min(x_indices), np.max(x_indices)y_min, y_max np.min(y_indices), np.max(y_indices)# add perturbation to bounding box coordinatesH, W gt2D.shapex_min max(0, x_min - random.randint(0, self.bbox_shift))x_max min(W, 
x_max random.randint(0, self.bbox_shift))y_min max(0, y_min - random.randint(0, self.bbox_shift))y_max min(H, y_max random.randint(0, self.bbox_shift))# bboxes np.array([x_min, y_min, x_max, y_max])bboxes np.array([[x_min, y_min], [x_max, y_max]])boxes_1 torch.reshape(torch.tensor(bboxes), [1, 1, -1, 2])input_label np.array([2, 3])boxes_1 torch.tensor(bboxes)[None, :] # boxes_1形状为[1, 2, 2]pts_labels torch.tensor(input_label)[None, :] # pts_labels形状为[1, 2]return {image: torch.tensor(img_padded).float(),gt2D: torch.tensor(gt2D[None, :, :]).long(),# bboxes: torch.tensor(bboxes[None, None, ...]).float(), # (B, 1, 4)bboxes: boxes_1, # efficient模型需要这样的维度 (B, 1, 2, 2)boxes_labels: pts_labels,image_name: img_name,new_size: torch.tensor(np.array([img_resize.shape[0], img_resize.shape[1]])).long(),original_size: torch.tensor(np.array([img_3c.shape[0], img_3c.shape[1]])).long()}def resize_longest_side(self, image):Expects a numpy array with shape HxWxC in uint8 format.long_side_length self.target_lengtholdh, oldw image.shape[0], image.shape[1]scale long_side_length * 1.0 / max(oldh, oldw)newh, neww oldh * scale, oldw * scaleneww, newh int(neww 0.5), int(newh 0.5)target_size (neww, newh)return cv2.resize(image, target_size, interpolationcv2.INTER_AREA)def pad_image(self, image):Expects a numpy array with shape HxWxC in uint8 format.# Padh, w image.shape[0], image.shape[1]padh self.image_size - hpadw self.image_size - wif len(image.shape) 3: ## Pad imageimage_padded np.pad(image, ((0, padh), (0, padw), (0, 0)))else: ## Pad gt maskimage_padded np.pad(image, ((0, padh), (0, padw)))return image_padded# %% sanity test of dataset class if do_sancheck:tr_dataset NpyDataset(data_root, data_augTrue)tr_dataloader DataLoader(tr_dataset, batch_size8, shuffleTrue)for step, batch in enumerate(tr_dataloader):# show the example_, axs plt.subplots(1, 2, figsize(10, 10))idx random.randint(0, 4)image batch[image]gt batch[gt2D]bboxes batch[bboxes]names_temp 
batch[image_name]axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())show_mask(gt[idx].cpu().squeeze().numpy(), axs[0])show_box(bboxes[idx].numpy().squeeze(), axs[0])axs[0].axis(off)# set titleaxs[0].set_title(names_temp[idx])idx random.randint(4, 7)axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())show_mask(gt[idx].cpu().squeeze().numpy(), axs[1])show_box(bboxes[idx].numpy().squeeze(), axs[1])axs[1].axis(off)# set titleaxs[1].set_title(names_temp[idx])plt.subplots_adjust(wspace0.01, hspace0)plt.savefig(join(work_dir, medsam_lite-train_bbox_prompt_sanitycheck_DA.png),bbox_inchestight,dpi300)plt.close()break# %% class MedSAM_Lite(nn.Module):def __init__(self,image_encoder,mask_decoder,prompt_encoder):super().__init__()self.image_encoder image_encoderself.mask_decoder mask_decoderself.prompt_encoder prompt_encoderdef forward(self, image, boxes):image_embedding self.image_encoder(image) # (B, 256, 64, 64)sparse_embeddings, dense_embeddings self.prompt_encoder(pointsNone,boxesboxes,masksNone,)low_res_masks, iou_predictions self.mask_decoder(image_embeddingsimage_embedding, # (B, 256, 64, 64)image_peself.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)sparse_prompt_embeddingssparse_embeddings, # (B, 2, 256)dense_prompt_embeddingsdense_embeddings, # (B, 256, 64, 64)multimask_outputFalse,) # (B, 1, 256, 256)return low_res_masks, iou_predictionstorch.no_grad()def postprocess_masks(self, masks, new_size, original_size):Do cropping and resizing# Cropmasks masks[:, :, :new_size[0], :new_size[1]]# Resizemasks F.interpolate(masks,size(original_size[0], original_size[1]),modebilinear,align_cornersFalse,)return masksfrom Net.efficient_sam.build_efficient_sam import build_efficient_sam_vits medsam_lite_model build_efficient_sam_vits() medsam_lite_model medsam_lite_model.to(device) medsam_lite_model.train()# %% print(fMedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}) # %% optimizer optim.AdamW(medsam_lite_model.parameters(),lrlr,betas(0.9, 
0.999),eps1e-08,weight_decayweight_decay, ) lr_scheduler optim.lr_scheduler.ReduceLROnPlateau(optimizer,modemin,factor0.9,patience5,cooldown0 ) seg_loss monai.losses.DiceLoss(sigmoidTrue, squared_predTrue, reductionmean) ce_loss nn.BCEWithLogitsLoss(reductionmean) iou_loss nn.MSELoss(reductionmean) # %% train_dataset NpyDataset(data_rootdata_root, data_augTrue) train_loader DataLoader(train_dataset, batch_sizebatch_size, shuffleTrue, num_workersnum_workers, pin_memoryTrue)if checkpoint and isfile(checkpoint):print(fResuming from checkpoint {checkpoint})checkpoint torch.load(checkpoint)medsam_lite_model.load_state_dict(checkpoint[model], strictTrue)optimizer.load_state_dict(checkpoint[optimizer])start_epoch checkpoint[epoch]best_loss checkpoint[loss]print(fLoaded checkpoint from epoch {start_epoch}) else:start_epoch 0best_loss 1e10 # %% train_losses [] for epoch in range(start_epoch 1, num_epochs):epoch_loss [1e10 for _ in range(len(train_loader))]epoch_start_time time()pbar tqdm(train_loader)for step, batch in enumerate(pbar):image batch[image]gt2D batch[gt2D]boxes batch[bboxes]label_box batch[boxes_labels]optimizer.zero_grad()image, gt2D, boxes, label_box image.to(device), gt2D.to(device), boxes.to(device), label_box.to(device)logits_pred, iou_pred medsam_lite_model(image, boxes, label_box)gt2D torch.unsqueeze(gt2D, 2)gt2D gt2D.repeat(1, 1, 3, 1, 1)l_seg seg_loss(logits_pred, gt2D)l_ce ce_loss(logits_pred, gt2D.float())# mask_loss l_seg l_cemask_loss seg_loss_weight * l_seg ce_loss_weight * l_ceiou_gt cal_iou(torch.sigmoid(logits_pred) 0.5, gt2D.bool())l_iou iou_loss(iou_pred, iou_gt)# loss mask_loss l_iouloss mask_loss iou_loss_weight * l_iouepoch_loss[step] loss.item()loss.backward()optimizer.step()optimizer.zero_grad()pbar.set_description(fEpoch {epoch} at {datetime.now().strftime(%Y-%m-%d %H:%M:%S)}, loss: {loss.item():.4f})epoch_end_time time()epoch_loss_reduced sum(epoch_loss) / 
len(epoch_loss)train_losses.append(epoch_loss_reduced)lr_scheduler.step(epoch_loss_reduced)model_weights medsam_lite_model.state_dict()checkpoint {model: model_weights,epoch: epoch,optimizer: optimizer.state_dict(),loss: epoch_loss_reduced,best_loss: best_loss,}torch.save(checkpoint, join(work_dir, medsam_lite_latest.pth))if epoch_loss_reduced best_loss:print(fNew best loss: {best_loss:.4f} - {epoch_loss_reduced:.4f})best_loss epoch_loss_reducedcheckpoint[best_loss] best_losstorch.save(checkpoint, join(work_dir, medsam_lite_best.pth))epoch_loss_reduced 1e10# %% plot lossplt.plot(train_losses)plt.title(Dice Binary Cross Entropy IoU Loss)plt.xlabel(Epoch)plt.ylabel(Loss)plt.savefig(join(work_dir, train_loss.png))plt.close() 多卡 我这里有其他团队成员在用卡所以此时此刻一个rank我只能用23号gpu。根据自己情况更改 # %% import os import random import monai from os import listdir, makedirs from os.path import join, isfile, basename from glob import glob from tqdm import tqdm from copy import deepcopy from time import time from shutil import copyfile import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torch import multiprocessing as mp from torch import distributed as dist from datetime import datetime import cv2 import torch.nn.functional as Ffrom matplotlib import pyplot as plt import argparse import torch# print(torch.cuda.nccl.version())torch.cuda.empty_cache() os.environ[OMP_NUM_THREADS] 4 # export OMP_NUM_THREADS4 os.environ[OPENBLAS_NUM_THREADS] 4 # export OPENBLAS_NUM_THREADS4 os.environ[MKL_NUM_THREADS] 6 # export MKL_NUM_THREADS6 os.environ[VECLIB_MAXIMUM_THREADS] 4 # export VECLIB_MAXIMUM_THREADS4 os.environ[NUMEXPR_NUM_THREADS] 1 # export NUMEXPR_NUM_THREADS6 os.environ[MASTER_ADDR] # IP of node with rank 0 os.environ[MASTER_PORT] # Port on master node os.environ[WORLD_SIZE] 2 # Total number of processes os.environ[RANK] 0 # Rank of this processdef get_args():parser argparse.ArgumentParser()parser.add_argument(-i, 
--tr_npy_path, typestr,defaulttrain_npy,helpPath to training npy files; two subfolders: gts and imgs)parser.add_argument(-task_name, typestr, defaultMedSAM-Lite)parser.add_argument(-pretrained_checkpoint, typestr,helpPath to pretrained MedSAM-Lite checkpoint)parser.add_argument(-work_dir, typestr, default./work_dir_multi)parser.add_argument(--data_aug, actionstore_true, defaultFalse,helpuse data augmentation during training)# trainparser.add_argument(-num_epochs, typeint, default1000)parser.add_argument(-batch_size, typeint, default4)parser.add_argument(-num_workers, typeint, default1)# Optimizer parametersparser.add_argument(-weight_decay, typefloat, default0.01,helpweight decay (default: 0.01))parser.add_argument(-lr, typefloat, default0.0005, metavarLR,helplearning rate (absolute lr))## Distributed training argsparser.add_argument(-world_size, typeint, default2, helpworld size)parser.add_argument(-node_rank, default0, typeint, helpNode rank)parser.add_argument(-bucket_cap_mb, typeint, default25,helpThe amount of memory in Mb that DDP will accumulate before firing off gradient communication for the bucket (need to tune))parser.add_argument(-resume, typestr, default, requiredFalse,helpResuming training from a work_dir)parser.add_argument(-init_method, typestr, defauargs parser.parse_args()return argsdef show_mask(mask, ax, random_colorFalse):if random_color:color np.concatenate([np.random.random(3), np.array([0.45])], axis0)else:color np.array([251 / 255, 252 / 255, 30 / 255, 0.45])h, w mask.shape[-2:]mask_image mask.reshape(h, w, 1) * color.reshape(1, 1, -1)ax.imshow(mask_image)def show_box(box, ax):x0, y0 box[0], box[1]w, h box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolorblue, facecolor(0, 0, 0, 0), lw2))torch.no_grad() def cal_iou(result, reference):intersection torch.count_nonzero(torch.logical_and(result, reference), dim[i for i in range(1, result.ndim)])union torch.count_nonzero(torch.logical_or(result, reference), dim[i 
for i in range(1, result.ndim)])iou intersection.float() / union.float()return iou.unsqueeze(1)def revert_sync_batchnorm(module: torch.nn.Module) - torch.nn.Module:# Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547# Original author: Kapil Yedidi (kapily)converted_module moduleif isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):# Unfortunately, SyncBatchNorm does not store the original class - if it did# we could return the one that was originally created.converted_module nn.BatchNorm2d(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats)if module.affine:with torch.no_grad():converted_module.weight module.weightconverted_module.bias module.biasconverted_module.running_mean module.running_meanconverted_module.running_var module.running_varconverted_module.num_batches_tracked module.num_batches_trackedif hasattr(module, qconfig):converted_module.qconfig module.qconfigfor name, child in module.named_children():converted_module.add_module(name, revert_sync_batchnorm(child))del modulereturn converted_moduleclass NpyDataset(Dataset):def __init__(self, data_root, image_size256, bbox_shift10, data_augTrue):self.data_root data_rootself.gt_path join(data_root, gts)self.img_path join(data_root, imgs)self.gt_path_files sorted(glob(join(self.gt_path, *.npy), recursiveTrue))self.gt_path_files [file for file in self.gt_path_files if isfile(join(self.img_path, basename(file)))]self.image_size image_sizeself.target_length image_sizeself.bbox_shift bbox_shiftself.data_aug data_augdef __len__(self):return len(self.gt_path_files)def __getitem__(self, index):img_name basename(self.gt_path_files[index])assert img_name basename(self.gt_path_files[index]), img gt name error self.gt_path_files[index] \self.npy_files[index]img_3c np.load(join(self.img_path, img_name), r, allow_pickleTrue) # (H, W, 3)# Resizing and normalizationimg_resize self.resize_longest_side(img_3c)img_resize (img_resize - 
img_resize.min()) / np.clip(img_resize.max() - img_resize.min(), a_min1e-8,a_maxNone) # normalize to [0, 1], (H, W, 3img_padded self.pad_image(img_resize) # (256, 256, 3)# convert the shape to (3, H, W)img_padded np.transpose(img_padded, (2, 0, 1)) # (3, 256, 256)assert np.max(img_padded) 1.0 and np.min(img_padded) 0.0, image should be normalized to [0, 1]gt np.load(self.gt_path_files[index], r, allow_pickleTrue) # multiple labels [0, 1,4,5...], (256,256)assert gt.max() 1, gt should have at least one labelgt cv2.resize(gt,(img_resize.shape[1], img_resize.shape[0]),interpolationcv2.INTER_NEAREST).astype(np.uint8)gt self.pad_image(gt) # (256, 256)label_ids np.unique(gt)[1:]try:gt2D np.uint8(gt random.choice(label_ids.tolist())) # only one label, (256, 256)except:print(img_name, label_ids.tolist(), label_ids.tolist())gt2D np.uint8(gt np.max(gt)) # only one label, (256, 256)# add data augmentation: random fliplr and random flipudif self.data_aug:if random.random() 0.5:img_padded np.ascontiguousarray(np.flip(img_padded, axis-1))gt2D np.ascontiguousarray(np.flip(gt2D, axis-1))# print(DA with flip left right)if random.random() 0.5:img_padded np.ascontiguousarray(np.flip(img_padded, axis-2))gt2D np.ascontiguousarray(np.flip(gt2D, axis-2))# print(DA with flip upside down)gt2D np.uint8(gt2D 0)y_indices, x_indices np.where(gt2D 0)x_min, x_max np.min(x_indices), np.max(x_indices)y_min, y_max np.min(y_indices), np.max(y_indices)# add perturbation to bounding box coordinatesH, W gt2D.shapex_min max(0, x_min - random.randint(0, self.bbox_shift))x_max min(W, x_max random.randint(0, self.bbox_shift))y_min max(0, y_min - random.randint(0, self.bbox_shift))y_max min(H, y_max random.randint(0, self.bbox_shift))# bboxes np.array([x_min, y_min, x_max, y_max])bboxes np.array([[x_min, y_min], [x_max, y_max]])input_label np.array([2, 3])# pts_labels torch.reshape(torch.tensor(input_label), [1, 1, -1])# bboxes已经是[2, 2]形状我们只需要增加一个批次维度boxes_1 torch.tensor(bboxes)[None, :] # boxes_1形状为[1, 2, 
2]# input_label是[2]形状我们也是增加一个批次维度pts_labels torch.tensor(input_label)[None, :] # pts_labels形状为[1, 2]return {image: torch.tensor(img_padded).float(),gt2D: torch.tensor(gt2D[None, :, :]).long(),# bboxes: torch.tensor(bboxes[None, None, ...]).float(), # (B, 1, 4)bboxes: boxes_1, # efficient模型需要这样的维度 (B, 1, 2, 2)boxes_labels: pts_labels,image_name: img_name,new_size: torch.tensor(np.array([img_resize.shape[0], img_resize.shape[1]])).long(),original_size: torch.tensor(np.array([img_3c.shape[0], img_3c.shape[1]])).long()}def resize_longest_side(self, image):Expects a numpy array with shape HxWxC in uint8 format.long_side_length self.target_lengtholdh, oldw image.shape[0], image.shape[1]scale long_side_length * 1.0 / max(oldh, oldw)newh, neww oldh * scale, oldw * scaleneww, newh int(neww 0.5), int(newh 0.5)target_size (neww, newh)return cv2.resize(image, target_size, interpolationcv2.INTER_AREA)def pad_image(self, image):Expects a numpy array with shape HxWxC in uint8 format.# Padh, w image.shape[0], image.shape[1]padh self.image_size - hpadw self.image_size - wif len(image.shape) 3: ## Pad imageimage_padded np.pad(image, ((0, padh), (0, padw), (0, 0)))else: ## Pad gt maskimage_padded np.pad(image, ((0, padh), (0, padw)))return image_paddeddef collate_fn(batch):Collate function for PyTorch DataLoader.batch_dict {}for key in batch[0].keys():if key image_name:batch_dict[key] [sample[key] for sample in batch]else:batch_dict[key] torch.stack([sample[key] for sample in batch], dim0)return batch_dict# %% sanity test of dataset class def sanity_check_dataset(args):print(tr_npy_path, args.tr_npy_path)tr_dataset NpyDataset(args.tr_npy_path, data_augargs.data_aug)print(len(tr_dataset), len(tr_dataset))tr_dataloader DataLoader(tr_dataset, batch_size4, shuffleTrue, collate_fncollate_fn)makedirs(args.work_dir, exist_okTrue)for step, batch in enumerate(tr_dataloader):# print(image.shape, gt.shape, bboxes.shape)# show the example_, axs plt.subplots(1, 2, figsize(10, 10))idx 
random.randint(0, 4)
        image = batch["image"]
        gt = batch["gt2D"]
        bboxes = batch["bboxes"]
        names_temp = batch["image_name"]
        axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
        show_mask(gt[idx].cpu().squeeze().numpy(), axs[0])
        show_box(bboxes[idx].numpy().squeeze(), axs[0])
        axs[0].axis("off")
        # set title
        axs[0].set_title(names_temp[idx])
        idx = random.randint(4, 7)
        axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
        show_mask(gt[idx].cpu().squeeze().numpy(), axs[1])
        show_box(bboxes[idx].numpy().squeeze(), axs[1])
        axs[1].axis("off")
        # set title
        axs[1].set_title(names_temp[idx])
        # plt.show()
        plt.subplots_adjust(wspace=0.01, hspace=0)
        plt.savefig(
            join(args.work_dir, "medsam_lite-train_bbox_prompt_sanitycheck_DA.png"),
            bbox_inches="tight",
            dpi=300,
        )
        plt.close()
        break


# %%
# Kept from the MedSAM-Lite script; main_worker below builds EfficientSAM
# directly and does not use this wrapper.
class MedSAM_Lite(nn.Module):
    def __init__(self, image_encoder, mask_decoder, prompt_encoder):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, boxes):
        image_embedding = self.image_encoder(image)  # (B, 256, 64, 64)
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=None,
            boxes=boxes,
            masks=None,
        )
        low_res_logits, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding,  # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,
        )  # (B, 1, 256, 256)
        return low_res_logits, iou_predictions

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """Do cropping and resizing."""
        # Crop
        masks = masks[:, :, :new_size[0], :new_size[1]]
        # Resize
        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )
        return masks


def main(args):
    ngpus_per_node = 2
    print("Spawning processes")
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def main_worker(gpu, ngpus_per_node, args):
    node_rank = int(args.node_rank)
    # Shift the physical device by one (cuda:1, cuda:2 instead of cuda:0, cuda:1);
    # the global rank stays node_rank * ngpus_per_node + gpu
    adjusted_gpu = gpu + 1
    rank = node_rank * ngpus_per_node + adjusted_gpu - 1
    # rank = node_rank * ngpus_per_node + gpu
    world_size = args.world_size
    # print(f"[Rank {rank}]: Use GPU: {gpu} for training")
    print(f"[Rank {rank}]: Use GPU: {adjusted_gpu} for training")
    is_main_host = rank == 0
    print("now 1")
    if is_main_host:
        print("now 2")
        run_id = datetime.now().strftime("%Y%m%d-%H%M")
        model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
        makedirs(model_save_path, exist_ok=True)
        copyfile(
            __file__, join(model_save_path, run_id + "_" + os.path.basename(__file__))
        )
    print("now 3")
    torch.cuda.set_device(adjusted_gpu)
    device = torch.device("cuda:{}".format(adjusted_gpu))
    print(device)
    print("now 4")
    dist.init_process_group(
        backend="nccl", init_method=args.init_method, rank=rank, world_size=world_size
    )
    print("now 5")
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_workers = args.num_workers
    from Net.efficient_sam.build_efficient_sam import build_efficient_sam_vits

    print("now 6")
    medsam_lite_model = build_efficient_sam_vits()
    medsam_lite_model = medsam_lite_model.to(device)
    ## Make sure there's only 2d BN layers, so that I can revert them properly
    for module in medsam_lite_model.modules():
        cls_name = module.__class__.__name__
        if "BatchNorm" in cls_name:
            assert cls_name == "BatchNorm2d"
    medsam_lite_model = nn.SyncBatchNorm.convert_sync_batchnorm(medsam_lite_model)
    medsam_lite_model = nn.parallel.DistributedDataParallel(
        medsam_lite_model,
        device_ids=[adjusted_gpu],
        output_device=adjusted_gpu,
        find_unused_parameters=True,
        bucket_cap_mb=args.bucket_cap_mb,
    )
    medsam_lite_model.train()
    # %%
    print(f"MedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}")
    # %%
    optimizer = optim.AdamW(
        medsam_lite_model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.weight_decay,
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.9, patience=5, cooldown=0
    )
    seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction="mean")
    ce_loss = nn.BCEWithLogitsLoss(reduction="mean")
    iou_loss = nn.MSELoss(reduction="mean")
    # %%
    data_root = args.tr_npy_path
    train_dataset = NpyDataset(data_root=data_root, data_aug=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,  # the DistributedSampler does the shuffling
        num_workers=num_workers,
        pin_memory=True,
        sampler=train_sampler,
        collate_fn=collate_fn,
    )
    # %%
    if os.path.exists(args.resume):
        ckpt_folders = sorted(listdir(args.resume))
        ckpt_folders = [
            f
            for f in ckpt_folders
            if (
                f.startswith(args.task_name)
                and isfile(join(args.resume, f, "medsam_lite_latest.pth"))
            )
        ]
        print("*" * 20)
        print("existing ckpts in", args.resume, ckpt_folders)
        # find the latest ckpt folders
        time_strings = [f.split(args.task_name + "-")[-1] for f in ckpt_folders]
        dates = [datetime.strptime(f, "%Y%m%d-%H%M") for f in time_strings]
        latest_date = max(dates)
        # The folders were listed under args.resume, so load from there as well
        latest_ckpt = join(
            args.resume,
            args.task_name + "-" + latest_date.strftime("%Y%m%d-%H%M"),
            "medsam_lite_latest.pth",
        )
        print("Loading from", latest_ckpt)
        checkpoint = torch.load(latest_ckpt, map_location=device)
        medsam_lite_model.module.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint["epoch"] + 1
        best_loss = checkpoint["loss"]
        print(f"Loaded checkpoint from epoch {start_epoch}")
    else:
        start_epoch = 0
        best_loss = 1e10
    train_losses = []
    epoch_times = []
    for epoch in range(start_epoch, num_epochs):
        # Reseed the sampler so every epoch gets a different shuffle across ranks
        train_sampler.set_epoch(epoch)
        epoch_loss = [1e10 for _ in range(len(train_loader))]
        epoch_start_time = time()
        pbar = tqdm(train_loader)
        for step, batch in enumerate(pbar):
            image = batch["image"]
            gt2D = batch["gt2D"]
            boxes = batch["bboxes"]
            label_box = batch["boxes_labels"]
            optimizer.zero_grad()
            image, gt2D, boxes, label_box = (
                image.to(device),
                gt2D.to(device),
                boxes.to(device),
                label_box.to(device),
            )
            # EfficientSAM takes the box prompt as points + point labels and
            # returns 3 candidate masks per prompt
            logits_pred, iou_pred = medsam_lite_model(image, boxes, label_box)
            # (B, 1, H, W) -> (B, 1, 3, H, W) to match the 3 predicted masks
            gt2D = torch.unsqueeze(gt2D, 2)
            gt2D = gt2D.repeat(1, 1, 3, 1, 1)
            l_seg = seg_loss(logits_pred, gt2D)
            l_ce = ce_loss(logits_pred, gt2D.float())
            mask_loss = l_seg + l_ce
            with torch.no_grad():
                iou_gt = cal_iou(torch.sigmoid(logits_pred) > 0.5, gt2D.bool())
            l_iou = iou_loss(iou_pred, iou_gt)
            loss = mask_loss + l_iou
            epoch_loss[step] = loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_description(
                f"[RANK {rank}] Epoch {epoch} at "
                f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}"
            )
        epoch_end_time = time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)
        # Gather every rank's per-step losses and average them
        epoch_loss_world = [None for _ in range(world_size)]
        dist.all_gather_object(epoch_loss_world, epoch_loss)
        epoch_loss_reduced = np.vstack(epoch_loss_world).mean()
        train_losses.append(epoch_loss_reduced)
        lr_scheduler.step(epoch_loss_reduced)
        if is_main_host:
            module_revert_sync_BN = revert_sync_batchnorm(
                deepcopy(medsam_lite_model.module)
            )
            weights = module_revert_sync_BN.state_dict()
            checkpoint = {
                "model": weights,
                "epoch": epoch,
                "optimizer": optimizer.state_dict(),
                "loss": epoch_loss_reduced,
                "best_loss": best_loss,
            }
            torch.save(checkpoint, join(model_save_path, "medsam_lite_latest.pth"))
        if epoch_loss_reduced < best_loss:
            print(f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}")
            best_loss = epoch_loss_reduced
            if is_main_host:
                checkpoint["best_loss"] = best_loss
                torch.save(checkpoint, join(model_save_path, "medsam_lite_best.pth"))
        dist.barrier()
        epoch_loss_reduced = 1e10
        # %% plot loss
        if is_main_host:
            fig, axes = plt.subplots(2, 1, figsize=(10, 8))
            axes[0].title.set_text("Dice + Binary Cross Entropy + IoU Loss")
            axes[0].plot(train_losses)
            axes[0].set_ylabel("Loss")
            axes[1].plot(epoch_times)
            axes[1].title.set_text("Epoch Duration")
            axes[1].set_ylabel("Duration (s)")
            axes[1].set_xlabel("Epoch")
            plt.tight_layout()
            plt.savefig(join(model_save_path, "log.png"))
            plt.close()
        dist.barrier()


# %%
if __name__ == "__main__":
    args = get_args()
    # sanity_check_dataset(args)
    main(args)
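The training loop passes `batch["bboxes"]` and `batch["boxes_labels"]` straight into EfficientSAM. As far as I can tell from the EfficientSAM code, it has no separate box input. A box is passed as two corner points labelled 2 (top-left) and 3 (bottom-right). The points are shaped `[B, num_queries, num_points, 2]` and the labels `[B, num_queries, num_points]`.

The numpy sketch below shows one way a Dataset could build those two arrays from `(x1, y1, x2, y2)` boxes, assuming one box per image. `boxes_to_point_prompts` is a hypothetical helper, not part of EfficientSAM.

```python
import numpy as np


def boxes_to_point_prompts(boxes_xyxy):
    """Turn N boxes (x1, y1, x2, y2) into EfficientSAM-style corner-point prompts.

    Returns:
        points: float32 array (N, 1, 2, 2) -> one query per image, two (x, y) corners
        labels: int64 array (N, 1, 2)      -> 2 = top-left corner, 3 = bottom-right corner
    """
    boxes = np.asarray(boxes_xyxy, dtype=np.float32).reshape(-1, 4)
    # [x1, y1, x2, y2] -> [[x1, y1], [x2, y2]], plus a num_queries axis of size 1
    points = boxes.reshape(-1, 1, 2, 2)
    # the same (2, 3) label pair for every box
    labels = np.tile(np.array([2, 3], dtype=np.int64), (boxes.shape[0], 1, 1))
    return points, labels


points, labels = boxes_to_point_prompts([[30, 40, 120, 200]])
print(points.shape, labels.shape)  # (1, 1, 2, 2) (1, 1, 2)
```

After batching, converting these with `torch.from_numpy` would give `bboxes` / `boxes_labels` tensors of the shape the forward call above is assumed to take.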
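EfficientSAM returns three candidate masks per prompt. That is why the loop unsqueezes `gt2D` at dim 2 and repeats it three times before the Dice and BCE losses.

`cal_iou` isn't defined in this part of the script, so the numpy sketch below is only a hypothetical per-mask version. It thresholds `sigmoid(logits)` at 0.5 and returns one IoU per candidate. The result is a target of shape `(B, 1, 3)` to compare against `iou_pred` with `MSELoss`.

```python
import numpy as np


def expand_gt_for_multimask(gt2d, num_predictions=3):
    """(B, 1, H, W) -> (B, 1, num_predictions, H, W), like unsqueeze(2) + repeat in the loop."""
    return np.repeat(np.expand_dims(gt2d, axis=2), num_predictions, axis=2)


def iou_targets(logits, gt, threshold=0.5, eps=1e-7):
    """Hypothetical per-mask IoU target: one value per predicted mask, e.g. shape (B, 1, 3)."""
    pred = 1.0 / (1.0 + np.exp(-logits)) > threshold  # sigmoid, then binarize
    gt = gt.astype(bool)
    inter = np.logical_and(pred, gt).sum(axis=(-2, -1))  # reduce over H, W only
    union = np.logical_or(pred, gt).sum(axis=(-2, -1))
    return inter / (union + eps)  # eps avoids 0/0 when both masks are empty


gt = np.zeros((2, 1, 4, 4), dtype=np.uint8)
gt[:, :, :2, :2] = 1                      # a 2x2 square in each 4x4 mask
gt5 = expand_gt_for_multimask(gt)
logits = np.where(gt5 > 0, 5.0, -5.0)     # confident, perfect predictions
print(gt5.shape, iou_targets(logits, gt5).shape)  # (2, 1, 3, 4, 4) (2, 1, 3)
```

If your own `cal_iou` reduces over every non-batch dimension instead of just H and W, double-check that its output shape still lines up with `iou_pred`.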
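`main_worker` does two small calculations worth isolating:

- It shifts the physical device by one (`adjusted_gpu = gpu + 1`, so the two processes land on `cuda:1` and `cuda:2`) while keeping global ranks contiguous.
- At the end of every epoch, it gathers each rank's per-step losses and averages them with `np.vstack(...).mean()`.

Below is a small numpy sketch of both. `gpu_offset` is a hypothetical parameter standing in for the hard-coded `+ 1`.

```python
import numpy as np


def rank_and_device(gpu, node_rank, ngpus_per_node, gpu_offset=1):
    """Rank arithmetic from main_worker; gpu_offset (hypothetical name) is the '+ 1'."""
    device_index = gpu + gpu_offset  # physical GPU index, e.g. cuda:1, cuda:2
    rank = node_rank * ngpus_per_node + device_index - gpu_offset  # == node_rank * n + gpu
    return rank, device_index


def reduce_epoch_loss(per_rank_losses):
    """Average per-step losses gathered from all ranks, as np.vstack(...).mean() does."""
    return float(np.vstack(per_rank_losses).mean())


for gpu in range(2):
    print(rank_and_device(gpu, node_rank=0, ngpus_per_node=2))  # (0, 1) then (1, 2)
print(reduce_epoch_loss([[1.0, 2.0], [3.0, 6.0]]))  # 3.0
```

`np.vstack` only works here because every rank produces the same number of steps. With its default `drop_last=False`, `DistributedSampler` pads the dataset so each replica receives the same number of samples.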
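The resume branch does three things:

- It lists run folders named `<task_name>-<YYYYmmdd-HHMM>`.
- It parses the timestamp suffix of each one.
- It loads `medsam_lite_latest.pth` from the newest folder.

Here is a standalone sketch of just the folder-selection step, assuming the caller has already filtered on the presence of the checkpoint file as the script does. `latest_run_folder` is a hypothetical helper. Unlike the inline version, it skips folders whose suffix isn't a timestamp, and it returns `None` when nothing matches instead of letting `max()` raise on an empty list.

```python
from datetime import datetime


def latest_run_folder(folder_names, task_name, fmt="%Y%m%d-%H%M"):
    """Return the newest '<task_name>-<timestamp>' folder name, or None if none match."""
    prefix = task_name + "-"
    stamped = []
    for name in folder_names:
        if not name.startswith(prefix):
            continue
        try:
            stamped.append((datetime.strptime(name[len(prefix):], fmt), name))
        except ValueError:
            continue  # suffix is not a timestamp, e.g. a notes folder
    return max(stamped)[1] if stamped else None


runs = ["EffSAM-20240101-0930", "EffSAM-20240315-1200", "EffSAM-notes"]
print(latest_run_folder(runs, "EffSAM"))  # EffSAM-20240315-1200
```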