当前位置: 首页 > news >正文

猪八戒网网站开发需求如何创建属于自己的网站

猪八戒网网站开发需求,如何创建属于自己的网站,百度指数1000搜索量有多少,苏州建设监督网站【深度学习总结_02】在自己的数据集微调SAM 前言 SAM (Segment Anything Model)是Meta AI开发的一种分割模型。它被认为是计算机视觉的第一个基础模型。SAM是在包含数百万图像和数十亿mask的庞大数据语料库上进行训练的#xff0c;这使得它非常强大。SAM能够为各种各样的图像…【深度学习总结_02】在自己的数据集微调SAM 前言 SAM (Segment Anything Model)是Meta AI开发的一种分割模型。它被认为是计算机视觉的第一个基础模型。SAM是在包含数百万图像和数十亿mask的庞大数据语料库上进行训练的这使得它非常强大。SAM能够为各种各样的图像生成准确的分割mask。 SAM通常在自然图像上表现优异但是在特定领域如医疗影响遥感图像等由于训练数据集缺乏这些数据SAM的效果并不是理想。因此在特定数据集上微调SAM是十分有必要的。 准备工作 1安装好segment anything git clone https://github.com/facebookresearch/segment-anything.git cd segment-anything python setup.py install2安装lightning包它是轻量级的PyTorch库用于高性能人工智能研究的轻量级PyTorch包装器。本文基于它对SAM进行微调 pip install lightning使用的数据集下载地址https://han-seg2023.grand-challenge.org/它是一个多器官的医疗影像数据集当然你也可以使用自己的数据集 步骤 1、创建配置文件 该配置文件含有SAM的哪些部分需要训练以及数据集的相关配置如数据集位置具体配置如下在config.py文件中 from box import Box config {num_devices: 1,batch_size: 6,num_workers: 4,num_epochs: 20,save_interval: 2,resume: None,out_dir: 模型权重输出地址,opt: {learning_rate: 8e-4,weight_decay: 1e-4,decay_factor: 10,steps: [60000, 86666],warmup_steps: 250,},model: {type: vit_b,checkpoint: SAM的权重地址,freeze: {image_encoder: True,prompt_encoder: True,mask_decoder: True,},},dataset: {root_dir: 数据集的根目录,sample_num: 4,target_size: 1024} } cfg Box(config)其中freeze部分决定SAM的哪些部分冷却不用训练dataset则是数据集的相关配置sample_num表示采样的point的数目target_size则是输入SAM的图片大小。 这里使用了box这个包可以通过如下命令安装 pip install python-box2、构建数据集 该部分负责在数据集加载的时候选择哪些数据进行训练这里我选择器官mandible进行训练。 同时由于该数据是3D数据对数据进行切片处理将3D数据变成2D图像该部分代码为 class HaNDataset(Dataset):def __init__(self, cfg):super().__init__()self.gt_path os.path.join(cfg.dataset.root_dir, oar_3d)self.img_path os.path.join(cfg.dataset.root_dir, ct_3d)# 文件列表self.img_file_list sorted(os.listdir(self.img_path))self.gt_file_list sorted(os.listdir(self.gt_path))# 器官类别self.category [7]self.cat2names {7 : mandible}# 数据列表含所有切片self.data_list []for i in range(len(self.img_file_list)):img_file_path os.path.join(self.img_path, self.img_file_list[i])gt_file_path os.path.join(self.gt_path, self.gt_file_list[i])img_data nib.load(img_file_path).get_fdata()gt_data nib.load(gt_file_path).get_fdata()axial_num img_data.shape[2]for a in range(axial_num):a_gt_data gt_data[:, :, a]ps_gt_data np.zeros_like(a_gt_data)for c in self.category:region (a_gt_data c)if np.sum(region) 0:self.data_list.append([i, a, c])print(fData size is:{len(self.data_list)})# 输入SAM的尺寸要是这个self.target_size cfg.dataset.target_size# 正负样本点数目self.sample_point_num cfg.dataset.sample_numdef __len__(self):return len(self.data_list)由于HaN这个数据集的数据格式是nii文件其数据的范围是0-2000而图像的数据范围是0-255因此需要将数据范围截断并重新映射。 输入SAM的图像大小应为1024*1024因此需要将其resize成目标尺寸。 除此之外由于HaN并没有提供box和point提示因此还需要从mask中自动获得相应的提示。 这些部分的实现为都在HaNDataset当中 def convert_to_three_channels(self, image):# 创建一个具有相同尺寸的3通道图像数组three_channel_image np.zeros((image.shape[ 0 ], image.shape[ 1 ], 3 ), dtypenp.uint8)# 将原始单通道图像复制到每个通道for i in range(3):three_channel_image[:, :, i] imagereturn three_channel_image def __getitem__(self, idx):data_id self.data_list[idx]f_id data_id[0]axial_id data_id[1]category_id data_id[2]name self.cat2names[category_id]img_data_path os.path.join(self.img_path, self.img_file_list[f_id])gt_data_path os.path.join(self.gt_path, self.gt_file_list[f_id])# nii文件的数据范围是0-2000和图像的范围不符img_data nib.load(img_data_path).get_fdata()# 截断对于ct图像img_data[img_data (50 1024 - 200)] (50 1024 - 200)img_data[img_data (50 1024 200)] (50 1024 200)img_data (img_data - (50 1024 - 200)) / 400.0 * 255.0img_data img_data[:, :, axial_id]img_data self.convert_to_three_channels(img_data)all_gt_data nib.load(gt_data_path).get_fdata()[:, :, axial_id]gt_data np.zeros_like(all_gt_data)gt_data[all_gt_data category_id] 1# 将image和gt变为target sizeorg_size gt_data.shapetransforms train_transforms(self.target_size, org_size[0], org_size[1])augments transforms(imageimg_data, maskgt_data)img_data, gt_data augments[image].to(torch.float32), augments[mask].to(torch.int64)# 获得box,验证时max_pixel为0bbox_data get_boxes_from_mask(gt_data, max_pixel0)[0]# 获得point提示point_coords, point_labels init_point_sampling(gt_data, self.sample_point_num)return {org_size: torch.tensor(org_size),category : name,image: img_data,label : gt_data,bbox : bbox_data,point_coords: point_coords,point_labels: point_labels}获得box和point以及resize图像的代码为 def init_point_sampling(mask, get_point1):if isinstance(mask, torch.Tensor):mask mask.numpy()# Get coordinates of black/white pixelsfg_coords np.argwhere(mask 1)[:, ::-1]bg_coords np.argwhere(mask 0)[:, ::-1]fg_size len(fg_coords)bg_size len(bg_coords)if get_point 1:if fg_size 0:index np.random.randint(fg_size)fg_coord fg_coords[index]label 1else:index np.random.randint(bg_size)fg_coord bg_coords[index]label 0return torch.as_tensor([fg_coord.tolist()], dtypetorch.float), torch.as_tensor([label], dtypetorch.int)else:num_fg get_point // 2num_bg get_point - num_fgfg_indices np.random.choice(fg_size, sizenum_fg, replaceTrue)bg_indices np.random.choice(bg_size, sizenum_bg, replaceTrue)fg_coords fg_coords[fg_indices]bg_coords bg_coords[bg_indices]coords np.concatenate([fg_coords, bg_coords], axis0)labels np.concatenate([np.ones(num_fg), np.zeros(num_bg)]).astype(int)indices np.random.permutation(get_point)coords, labels torch.as_tensor(coords[indices], dtypetorch.float), torch.as_tensor(labels[indices],dtypetorch.int)return coords, labels def get_boxes_from_mask(mask, box_num1, std0.1, max_pixel5):if isinstance(mask, torch.Tensor):mask mask.numpy()label_img label(mask)regions regionprops(label_img)# Iterate through all regions and get the bounding box coordinatesboxes [tuple(region.bbox) for region in regions]# If the generated number of boxes is greater than the number of categories,# sort them by region area and select the top n regionsif len(boxes) box_num:sorted_regions sorted(regions, keylambda x: x.area, reverseTrue)[:box_num]boxes [tuple(region.bbox) for region in sorted_regions]# If the generated number of boxes is less than the number of categories,# duplicate the existing boxeselif len(boxes) box_num:num_duplicates box_num - len(boxes)boxes [boxes[i % len(boxes)] for i in range(num_duplicates)]# Perturb each bounding box with noisenoise_boxes []for box in boxes:y0, x0, y1, x1 boxwidth, height abs(x1 - x0), abs(y1 - y0)# Calculate the standard deviation and maximum noise valuenoise_std min(width, height) * stdmax_noise min(max_pixel, int(noise_std * 5))# Add random noise to each coordinatetry:noise_x np.random.randint(-max_noise, max_noise)except:noise_x 0try:noise_y np.random.randint(-max_noise, max_noise)except:noise_y 0x0, y0 x0 noise_x, y0 noise_yx1, y1 x1 noise_x, y1 noise_ynoise_boxes.append((x0, y0, x1, y1))return torch.as_tensor(noise_boxes, dtypetorch.float) def train_transforms(img_size, ori_h, ori_w):transforms []transforms.append(A.Resize(int(img_size), int(img_size), interpolationcv2.INTER_NEAREST))transforms.append(ToTensorV2(p1.0))return A.Compose(transforms, p1.)3、构建SAM模型 因为我们已经安装好了segment anything因此可以直接调用相关模块然后组成一个生成mask的流程即可该部分代码为 import torch.nn as nn import torch.nn.functional as F from segment_anything import sam_model_registry from segment_anything import SamPredictor class Model(nn.Module):def __init__(self, cfg):super().__init__()self.cfg cfgdef setup(self):self.model sam_model_registry[self.cfg.model.type](checkpointself.cfg.model.checkpoint)self.model.train()if self.cfg.model.freeze.image_encoder:for name, param in self.model.image_encoder.named_parameters():param.requires_grad Falseif self.cfg.model.freeze.prompt_encoder:for name, param in self.model.prompt_encoder.named_parameters():param.requires_grad False# freeze mask decoder参数if self.cfg.model.freeze.mask_decoder:for name, param in self.model.mask_decoder.named_parameters():param.requires_grad Falsedef forward(self, images, bboxes, org_size, point_coords None, point_labels None):_, _, H, W images.shapeimage_embeddings self.model.image_encoder(images)pred_masks []ious []# 还要添加points,输入格式(points coords, points label): #coords:B,N,2 labels:B,N# 一个batch一个batch处理for embedding, bbox, coord, label in zip(image_embeddings, bboxes, point_coords, point_labels):bbox bbox.unsqueeze(0)coord coord.unsqueeze(0)label label.unsqueeze(0)point (coord, label)sparse_embeddings, dense_embeddings self.model.prompt_encoder(pointspoint,boxesbbox,masksNone,)low_res_masks, iou_predictions self.model.mask_decoder(image_embeddingsembedding.unsqueeze(0),image_peself.model.prompt_encoder.get_dense_pe(),sparse_prompt_embeddingssparse_embeddings,dense_prompt_embeddingsdense_embeddings,multimask_outputFalse,)masks F.interpolate(low_res_masks,(H, W),modebilinear,align_cornersFalse,)pred_masks.append(masks.squeeze(1))ious.append(iou_predictions)return pred_masks, iousdef get_predictor(self):return SamPredictor(self.model)其中setup方法决定哪些参数需要进行训练哪些不用。 4、使用数据进行训练 首先使用lightning进行配置 import lightning as L from config import cfg fabric L.Fabric(acceleratorauto,devicescfg.num_devices,strategyauto,loggers[TensorBoardLogger(cfg.out_dir, namelightning-sam)]) fabric.launch() fabric.seed_everything(1337 fabric.global_rank)然后创建模型和加载数据集代码为 with fabric.device:model Model(cfg)model.setup() train_data HaNDataset(cfg) train_loader DataLoader(train_data, batch_sizecfg.batch_size, num_workerscfg.num_workers, shuffleTrue) train_data fabric._setup_dataloader(train_loader)接着创建优化器代码为 def configure_opt(cfg: Box, model: Model):def lr_lambda(step):if step cfg.opt.warmup_steps:return step / cfg.opt.warmup_stepselif step cfg.opt.steps[0]:return 1.0elif step cfg.opt.steps[1]:return 1 / cfg.opt.decay_factorelse:return 1 / (cfg.opt.decay_factor**2)optimizer torch.optim.Adam(model.model.parameters(), lrcfg.opt.learning_rate, weight_decaycfg.opt.weight_decay)scheduler torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)return optimizer, scheduler optimizer, scheduler configure_opt(cfg, model) model, optimizer fabric.setup(model, optimizer)最后遍历数据集进行训练这里使用的损失函数有Focal lossDice loss和IoU loss代码为 def train_sam(cfg: Box,fabric: L.Fabric,model: Model,optimizer: _FabricOptimizer,scheduler: _FabricOptimizer,train_dataloader: DataLoader, ) :The SAM training loop.focal_loss FocalLoss()dice_loss DiceLoss()# 从上次中断的地方训练start_epoch 1if cfg.resume:map_location cuda:%d % fabric.global_rankcheckpoint torch.load(cfg.resume, map_location{cuda:0: map_location})start_epoch checkpoint[epoch]network checkpoint[network]opt checkpoint[optimizer]sche checkpoint[scheduler]model.model.load_state_dict(network)optimizer.load_state_dict(opt)scheduler.load_state_dict(sche)fabric.print(fresume from:{cfg.resume})for epoch in range(start_epoch, cfg.num_epochs):batch_time AverageMeter(namebatch_time)data_time AverageMeter(namedata_time)focal_losses AverageMeter(namefocal_losses)dice_losses AverageMeter(namedice_losses)iou_losses AverageMeter(nameiou_losses)total_losses AverageMeter(nametotal_losses)end time.time()# 保存模型if epoch % cfg.save_interval 0:fabric.print(fSaving checkpoint to {cfg.out_dir})state_dict model.model.state_dict()checkpoint {epoch: epoch,network: state_dict,optimizer: optimizer.state_dict(),scheduler: scheduler.state_dict()}# 多卡环境下只在rank0的gpu上保存if fabric.global_rank 0:torch.save(checkpoint, os.path.join(cfg.out_dir, fepoch-{epoch:06d}-ckpt.pth))for iter, data in enumerate(train_dataloader):data_time.update(time.time() - end)images data[image]gt_masks data[label]bboxes data[bbox]batch_size images.shape[0]pred_masks, iou_predictions model(images, bboxes, data[point_coords], data[point_labels])num_masks sum(len(pred_mask) for pred_mask in pred_masks)loss_focal torch.tensor(0., devicefabric.device)loss_dice torch.tensor(0., devicefabric.device)loss_iou torch.tensor(0., devicefabric.device)for pred_mask, gt_mask, iou_prediction in zip(pred_masks, gt_masks, iou_predictions):batch_iou calc_iou(pred_mask, gt_mask)loss_focal focal_loss(pred_mask, gt_mask, num_masks)loss_dice dice_loss(pred_mask, gt_mask, num_masks)loss_iou F.mse_loss(iou_prediction, batch_iou, reductionsum) / num_masksloss_total 20. * loss_focal loss_dice loss_iouoptimizer.zero_grad()fabric.backward(loss_total)optimizer.step()scheduler.step()batch_time.update(time.time() - end)end time.time()focal_losses.update(loss_focal.item(), batch_size)dice_losses.update(loss_dice.item(), batch_size)iou_losses.update(loss_iou.item(), batch_size)total_losses.update(loss_total.item(), batch_size)fabric.print(fEpoch: [{epoch}][{iter1}/{len(train_dataloader)}]f | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]f | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]f | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]f | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]f | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]f | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})])通过以上步骤就可以对SAM进行微调了如果是对mask decoder进行微调显存占用大概在17G左右。 参考链接 lightning-sam
http://www.zqtcl.cn/news/721277/

相关文章:

  • lng企业自建站wordpress 分页 美化
  • 手机版网站如何做新闻类网站怎么做百度推广
  • 网站开发工程师 上海合肥网站到首页排名
  • 商城网站后续费用请人代做谷歌外贸网站
  • 汽车网站有哪些3d家装效果图制作软件
  • 荆门做网站公众号的公司网站百度不收录的原因
  • 专门做羽毛球的网站福州seo网站排名
  • 网站返回503的含义是门户网站开发合同
  • 自己做网站的成本要哪些东西wordpress模板如何管理系统
  • 做一般的网站要多久wordpress写文章页面无法显示
  • 人和兽做的网站视频汽车建设网站开发流程
  • 长春市建设工程造价管理协会网站厦门网站建设费用
  • 广东建设信息公开网站怎样策划一个营销型网站
  • 魔兽做图下载网站如何经营一个购物网站
  • 深圳做网站哪个平台好一级消防工程师考试题型
  • 网站婚礼服务态网站建设论文网站设计有限公司是干嘛的
  • 邯郸网站建设效果好广西做网站的公司
  • 网站logo上传营销型网站制作方案
  • 小说网站静态模板站长工具seo综合查询adc
  • 北京响应式网站做logo那个网站
  • 如何申请免费网站空间刚察县wap网站建设公司
  • 哪里有网站推广软件免费推广seo策略方法
  • 阿里云备案网站 网站名称怎么写京icp备案查询
  • 网站开发岗位思维导图alexa排名
  • 自适应网站建设济南济南网站建设公司
  • 巴州网站建设库尔勒网站建设钟爱网络杭州微信网站制作
  • 52做网站南京市住房城乡建设门户网站
  • 网站开发精品课程贵阳市白云区官方网站
  • seo整站优化服务会计培训班一般收费多少
  • 批量网站访问检测怎么做好手机网站开发