In the previous post we implemented the forward function and obtained prediction. At this point the network has produced a very large number of boxes together with their class probabilities, and we now need to filter them down to the final predicted boxes. Once you understand the format of the YOLOv3 output and the meaning of each position in it, the source code is not hard to follow. My main difficulty while reading it was that I was not familiar with PyTorch, so in this post I call out the PyTorch functions involved and include the relevant links and test code.

obj score threshold

We set an obj score threshold; only boxes whose score exceeds it are considered valid.

    conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)
    prediction = prediction*conf_mask

prediction has shape 1*boxnum*boxattr. prediction[:,:,4] has shape 1*boxnum, and each element is the value at index 4 of boxattr, that is, the obj score.

Tensor indexing in torch works much like it does in numpy. See the following code and its output. Note that torch.Tensor(1,3,10) is uninitialized, so the values differ from run to run.

    import torch

    x = torch.Tensor(1,3,10)  # Create an un-initialized Tensor of size 1x3x10
    print(x)
    print(x.shape)            # Print the shape of the Tensor
    y = x[:,:,4]
    print(y)
    print(y.shape)
    z = x[:,:,4:6]
    print(z)
    print(z.shape)
    print((y>0.5).float().unsqueeze(2))

The output is:

    tensor([[[2.5226e-18, 1.6898e-04, 1.0413e-11, 7.7198e-10, 1.0549e-08,
              4.0516e-11, 1.0681e-05, 2.9575e-18, 6.7333e+22, 1.7591e+22],
             [1.7184e+25, 4.3222e+27, 6.1972e-04, 7.2443e+22, 1.7728e+28,
              7.0367e+22, 5.9018e-10, 2.6540e-09, 1.2972e-11, 5.3370e-08],
             [2.7001e-06, 2.6801e-09, 4.1292e-05, 2.1511e+23, 3.2770e-09,
              2.5125e-18, 7.7052e+31, 1.9447e+31, 5.0207e+28, 1.1492e-38]]])
    torch.Size([1, 3, 10])
    tensor([[1.0549e-08, 1.7728e+28, 3.2770e-09]])
    torch.Size([1, 3])
    tensor([[[1.0549e-08, 4.0516e-11],
             [1.7728e+28, 7.0367e+22],
             [3.2770e-09, 2.5125e-18]]])
    torch.Size([1, 3, 2])
    tensor([[[0.],
             [1.],
             [0.]]])

Squeeze and unsqueeze: squeeze removes a dimension of size 1, unsqueeze adds one.

    t = torch.ones(2,1,2,1)    # Size 2x1x2x1
    r = torch.squeeze(t)       # Size 2x2
    r = torch.squeeze(t, 1)    # Squeeze dimension 1: Size 2x2x1

    # Un-squeeze a dimension
    x = torch.Tensor([1, 2, 3])
    r = torch.unsqueeze(x, 0)  # Size: 1x3, adds a dimension at position 0
    r = torch.unsqueeze(x, 1)  # Size: 3x1, adds a dimension at position 1

After this step, every row of prediction whose obj score does not exceed the threshold has been set to 0.

nms

    # Get the box corners (top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y)
    box_corner = prediction.new(prediction.shape)
    box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
    box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
    box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
    box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
    prediction[:,:,:4] = box_corner[:,:,:4]

In the original prediction, boxattr stores x, y, w, h, ..., which is inconvenient to work with, so we convert it to (top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y).

Next we process the predictions for each image in the batch, one at a time.

    batch_size = prediction.size(0)
    write = False

    for ind in range(batch_size):
        # image_pred.shape = boxnum*boxattr
        image_pred = prediction[ind]  # image Tensor, box_num*box_attr
        # confidence thresholding
        # NMS

        # Returns the maximum of each row and the column where it occurs.
        max_conf, max_conf_score = torch.max(image_pred[:,5:5+num_classes], 1)
        # Bring both back to the same number of dimensions as image_pred
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_score = max_conf_score.float().unsqueeze(1)
        seq = (image_pred[:,:5], max_conf, max_conf_score)
        # Concatenate along the column direction. image_pred is now boxnum*7
        image_pred = torch.cat(seq, 1)

This uses torch.max; see https://blog.csdn.net/Z_lbj/article/details/79766690

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

It returns the maximum along dimension dim. One way to remember it: the comparison runs along dimension dim. torch.max(a, 0) compares along the row direction, so it yields the maximum of each column. Suppose input is a two-dimensional matrix, rows*columns, where rows are dimension 0 and columns are dimension 1.

torch.max(a,0) returns the largest element of each column together with its index (the row index of that element within the column).
torch.max(a,1) returns the largest element of each row together with its index (the column index of that element within the row).

    c = torch.Tensor([[1,2,3],[6,5,4]])
    print(c)
    a, b = torch.max(c,1)
    print(a)
    print(b)

The output is:

    tensor([[1., 2., 3.],
            [6., 5., 4.]])
    tensor([3., 6.])
    tensor([2, 0])

torch.cat(tensors, dim=0, out=None) → Tensor

    >>> x = torch.randn(2, 3)
    >>> x
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 0)
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 1)
    tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
             -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
             -0.5790,  0.1497]])

Next we keep only the rows whose obj_score is non-zero (rows below the threshold were zeroed earlier). This excerpt, like the ones that follow, sits inside the per-image loop.

        non_zero_ind = (torch.nonzero(image_pred[:,4]))
        try:
            image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)
        except:
            continue

        # For PyTorch 0.4 compatibility
        # Since the above code will not raise an exception for no detection
        # as scalars are supported in PyTorch 0.4
        if image_pred_.shape[0] == 0:
            continue

OK, next we run nms for each class. First we find out which classes are present. unique is a helper function that is not shown in this post.

        # Get the various classes detected in the image
        img_classes = unique(image_pred_[:,-1])  # -1 index holds the class index

Then we handle each class in turn.

        for cls in img_classes:
            # perform NMS

            # get the detections with one particular class
            # keep the rows whose class is the current class and whose class prob != 0
            cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
            class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
            image_pred_class = image_pred_[class_mask_ind].view(-1,7)

            # sort the detections such that the entry with the maximum objectness
            # confidence is at the top (obj score from high to low)
            conf_sort_index = torch.sort(image_pred_class[:,4], descending = True)[1]
            image_pred_class = image_pred_class[conf_sort_index]
            idx = image_pred_class.size(0)  # Number of detections

            for i in range(idx):
                # Get the IOUs of all boxes that come after the one we are looking at
                # in the loop
                try:
                    # IoU between row i and every row after it
                    ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])
                except ValueError:
                    break
                except IndexError:
                    break

                # Zero out all the detections that have IoU > threshold:
                # boxes whose IoU with row i is > nms_conf are treated as the same object and set to 0
                iou_mask = (ious < nms_conf).float().unsqueeze(1)
                image_pred_class[i+1:] *= iou_mask

                # Remove the rows with IoU > nms_conf (the ones just zeroed)
                non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
                image_pred_class = image_pred_class[non_zero_ind].view(-1,7)

            # Repeat the batch_id for as many detections of the class cls in the image
            batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)
            seq = batch_ind, image_pred_class

The code that computes the IoU is below and needs little explanation. IoU = intersection area / union area.

    def bbox_iou(box1, box2):
        """
        Returns the IoU of two bounding boxes
        """
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]

        # get the coordinates of the intersection rectangle
        inter_rect_x1 = torch.max(b1_x1, b2_x1)
        inter_rect_y1 = torch.max(b1_y1, b2_y1)
        inter_rect_x2 = torch.min(b1_x2, b2_x2)
        inter_rect_y2 = torch.min(b1_y2, b2_y2)

        # Intersection area
        inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)

        # Union Area
        b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)

        iou = inter_area / (b1_area + b2_area - inter_area)

        return iou

The tensor index operations are used as follows:

    image_pred_ = torch.Tensor([[1,2,3,4,9],[5,6,7,8,9]])
    #print(image_pred_[:,-1] == 9)
    has_9 = (image_pred_[:,-1] == 9)
    print(has_9)

    # Evaluation order: (image_pred_[:,-1] == 9).float().unsqueeze(1) first, then the tensor multiplication
    cls_mask = image_pred_*(image_pred_[:,-1] == 9).float().unsqueeze(1)
    print(cls_mask)

    class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
    image_pred_class = image_pred_[class_mask_ind]

The output is as follows (more recent PyTorch versions print a bool tensor for the first line):

    tensor([1, 1], dtype=torch.uint8)
    tensor([[1., 2., 3., 4., 9.],
            [5., 6., 7., 8., 9.]])

torch.sort is used as follows:

    d = torch.Tensor([[1,2,3],[6,5,4]])
    e = d[:,2]
    print(e)
    print(torch.sort(e))

Output:

    tensor([3., 4.])
    torch.return_types.sort(
    values=tensor([3., 4.]),
    indices=tensor([0, 1]))

To summarize, our nms procedure is:

1. For each image, N detections are predicted, each holding 4+1+C values (4 coordinates, 1 obj score and C class probabilities).
2. First filter out the rows with obj_score < confidence.
3. For each row, keep only the highest class probability as the predicted class.
4. Loop over each class. Within a class, sort the predictions by obj_score from high to low, then run nms:
   - Compare the first box with every box after it and delete the boxes with iou > threshold, that is, remove all boxes similar to it.
   - Compare the next box with every box after it and delete all boxes similar to that box.
   - Repeat until no similar boxes remain.

At that point, the boxes left for the class being processed are all distinct from one another.

write_results

The final return value is an n*8 tensor, where the 8 values are (batch_index, 4 coordinates, 1 obj score, 1 class prob, 1 class index).

    def write_results(prediction, confidence, num_classes, nms_conf = 0.4):
        print("prediction.shape", prediction.shape)
        # Set the rows with obj_score < confidence to 0
        conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)
        prediction = prediction*conf_mask

        # Get the box corners (top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y)
        box_corner = prediction.new(prediction.shape)
        box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
        box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
        box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
        box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
        # Overwrite the first four columns of prediction's third dimension
        prediction[:,:,:4] = box_corner[:,:,:4]

        batch_size = prediction.size(0)
        write = False

        for ind in range(batch_size):
            # image_pred.shape = boxnum*boxattr
            image_pred = prediction[ind]  # image Tensor
            # confidence thresholding
            # NMS

            # Take the largest class score of each row.
            # Here max_conf_score holds the values and max_conf holds the class indices.
            max_conf_score, max_conf = torch.max(image_pred[:,5:5+num_classes], 1)
            max_conf = max_conf.float().unsqueeze(1)
            max_conf_score = max_conf_score.float().unsqueeze(1)
            seq = (image_pred[:,:5], max_conf_score, max_conf)
            # Now 7 columns: top-left x, top-left y, right-bottom x, right-bottom y,
            # obj score, max probability, corresponding class index
            image_pred = torch.cat(seq, 1)
            print(image_pred.shape)

            non_zero_ind = (torch.nonzero(image_pred[:,4]))
            try:
                image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)
            except:
                continue

            # For PyTorch 0.4 compatibility
            # Since the above code will not raise an exception for no detection
            # as scalars are supported in PyTorch 0.4
            if image_pred_.shape[0] == 0:
                continue

            # Get the various classes detected in the image
            img_classes = unique(image_pred_[:,-1])  # -1 index holds the class index

            for cls in img_classes:
                # perform NMS

                # get the detections with one particular class
                # keep the rows whose class is the current class and whose class prob != 0
                cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
                class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
                image_pred_class = image_pred_[class_mask_ind].view(-1,7)

                # sort the detections such that the entry with the maximum objectness
                # confidence is at the top (obj score from high to low)
                conf_sort_index = torch.sort(image_pred_class[:,4], descending = True)[1]
                image_pred_class = image_pred_class[conf_sort_index]
                idx = image_pred_class.size(0)  # Number of detections

                for i in range(idx):
                    # Get the IOUs of all boxes that come after the one we are looking at
                    # in the loop
                    try:
                        # IoU between row i and every row after it
                        ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])
                    except ValueError:
                        break
                    except IndexError:
                        break

                    # Zero out all the detections that have IoU > threshold:
                    # boxes whose IoU with row i is > nms_conf are treated as the same object and set to 0
                    iou_mask = (ious < nms_conf).float().unsqueeze(1)
                    image_pred_class[i+1:] *= iou_mask

                    # Remove the rows with IoU > nms_conf (the ones just zeroed)
                    non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
                    image_pred_class = image_pred_class[non_zero_ind].view(-1,7)

                # Repeat the batch_id for as many detections of the class cls in the image
                batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)
                seq = batch_ind, image_pred_class

                if not write:
                    output = torch.cat(seq,1)  # concatenate along columns; each row has 8 values
                    write = True
                else:
                    out = torch.cat(seq,1)
                    output = torch.cat((output,out))  # concatenate along rows, shape n*8

        try:
            return output
        except:
            return 0
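
To make the first half of the pipeline concrete (thresholding on obj score, converting x, y, w, h to corners, and keeping only the best class), here is a minimal sketch in plain Python on a two-row toy input. It is not the tensorized code from this post: the names `to_corners` and `filter_and_convert` and the sample numbers are made up for illustration, and it assumes each row is laid out as [x, y, w, h, obj_score, p_0, ..., p_{C-1}].

```python
def to_corners(box):
    """Convert (center x, center y, width, height) to (x1, y1, x2, y2)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def filter_and_convert(rows, confidence):
    """Hypothetical helper: drop low-score rows and reduce each row to 7 values.

    rows: list of [cx, cy, w, h, obj_score, p_0, ..., p_{C-1}]
    returns: list of (x1, y1, x2, y2, obj_score, best_prob, best_class)
    """
    out = []
    for row in rows:
        # Same rule as the conf_mask step: keep only obj_score > confidence.
        if row[4] <= confidence:
            continue
        probs = row[5:]
        # Index of the largest class probability (what torch.max(..., 1) returns per row).
        best_class = max(range(len(probs)), key=probs.__getitem__)
        out.append(to_corners(row[:4]) + (row[4], probs[best_class], float(best_class)))
    return out


# Made-up sample data: two boxes, three classes.
rows = [
    [50.0, 50.0, 20.0, 10.0, 0.9, 0.1, 0.7, 0.2],
    [30.0, 30.0, 10.0, 10.0, 0.2, 0.6, 0.3, 0.1],  # obj score below the threshold
]
result = filter_and_convert(rows, confidence=0.5)
print(result)
```

Only the first row survives, and it comes out with the same seven columns that image_pred has after torch.cat: four corner coordinates, obj score, the maximum class probability and its class index.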
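
The per-class nms loop can also be sketched without tensors, which may make the control flow easier to see. The following is a simplified plain-Python version under stated assumptions: `iou` and `nms_single_class` are hypothetical names, boxes are already in corner format, the IoU uses the same +1 pixel convention as bbox_iou, and the three detections are invented sample data for a single class.

```python
def iou(a, b):
    """IoU of two corner-format boxes (x1, y1, x2, y2), with the +1 pixel convention."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    # Clamp each side at 0 so non-overlapping boxes give zero intersection.
    inter = max(ix2 - ix1 + 1, 0) * max(iy2 - iy1 + 1, 0)
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def nms_single_class(dets, nms_conf=0.4):
    """Greedy NMS over detections (x1, y1, x2, y2, obj_score) of one class."""
    # Highest obj score first, as with torch.sort(..., descending=True).
    remaining = sorted(dets, key=lambda d: d[4], reverse=True)
    kept = []
    while remaining:
        best = remaining.pop(0)
        kept.append(best)
        # Keep only the boxes whose IoU with the current box is below nms_conf.
        remaining = [d for d in remaining if iou(best[:4], d[:4]) < nms_conf]
    return kept


# Made-up sample data: the second box overlaps the first heavily.
detections = [
    (10, 10, 50, 50, 0.9),
    (12, 12, 52, 52, 0.8),
    (100, 100, 140, 140, 0.7),
]
kept = nms_single_class(detections)
print(kept)
```

The second box is dropped because its IoU with the first is 1521/1841, roughly 0.83, which is above nms_conf, while the third box does not overlap the first at all and is kept. The tensor version in write_results reaches the same result by zeroing the suppressed rows and then filtering them out with torch.nonzero.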