atss_assigner.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddlex.ppdet.core.workspace import register
  22. from ..bbox_utils import iou_similarity, batch_iou_similarity
  23. from ..bbox_utils import bbox_center
  24. from .utils import (check_points_inside_bboxes, compute_max_iou_anchor,
  25. compute_max_iou_gt)
  26. __all__ = ['ATSSAssigner']
  27. @register
  28. class ATSSAssigner(nn.Layer):
  29. """Bridging the Gap Between Anchor-based and Anchor-free Detection
  30. via Adaptive Training Sample Selection
  31. """
  32. __shared__ = ['num_classes']
  33. def __init__(self,
  34. topk=9,
  35. num_classes=80,
  36. force_gt_matching=False,
  37. eps=1e-9):
  38. super(ATSSAssigner, self).__init__()
  39. self.topk = topk
  40. self.num_classes = num_classes
  41. self.force_gt_matching = force_gt_matching
  42. self.eps = eps
  43. def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
  44. pad_gt_mask):
  45. gt2anchor_distances_list = paddle.split(
  46. gt2anchor_distances, num_anchors_list, axis=-1)
  47. num_anchors_index = np.cumsum(num_anchors_list).tolist()
  48. num_anchors_index = [0, ] + num_anchors_index[:-1]
  49. is_in_topk_list = []
  50. topk_idxs_list = []
  51. for distances, anchors_index in zip(gt2anchor_distances_list,
  52. num_anchors_index):
  53. num_anchors = distances.shape[-1]
  54. _, topk_idxs = paddle.topk(
  55. distances, self.topk, axis=-1, largest=False)
  56. topk_idxs_list.append(topk_idxs + anchors_index)
  57. is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(
  58. axis=-2).astype(gt2anchor_distances.dtype)
  59. is_in_topk_list.append(is_in_topk * pad_gt_mask)
  60. is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
  61. topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
  62. return is_in_topk_list, topk_idxs_list
  63. @paddle.no_grad()
  64. def forward(self,
  65. anchor_bboxes,
  66. num_anchors_list,
  67. gt_labels,
  68. gt_bboxes,
  69. pad_gt_mask,
  70. bg_index,
  71. gt_scores=None,
  72. pred_bboxes=None):
  73. r"""This code is based on
  74. https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
  75. The assignment is done in following steps
  76. 1. compute iou between all bbox (bbox of all pyramid levels) and gt
  77. 2. compute center distance between all bbox and gt
  78. 3. on each pyramid level, for each gt, select k bbox whose center
  79. are closest to the gt center, so we total select k*l bbox as
  80. candidates for each gt
  81. 4. get corresponding iou for the these candidates, and compute the
  82. mean and std, set mean + std as the iou threshold
  83. 5. select these candidates whose iou are greater than or equal to
  84. the threshold as positive
  85. 6. limit the positive sample's center in gt
  86. 7. if an anchor box is assigned to multiple gts, the one with the
  87. highest iou will be selected.
  88. Args:
  89. anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
  90. "xmin, xmax, ymin, ymax" format
  91. num_anchors_list (List): num of anchors in each level
  92. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  93. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
  94. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  95. bg_index (int): background index
  96. gt_scores (Tensor|None, float32) Score of gt_bboxes,
  97. shape(B, n, 1), if None, then it will initialize with one_hot label
  98. pred_bboxes (Tensor, float32, optional): predicted bounding boxes, shape(B, L, 4)
  99. Returns:
  100. assigned_labels (Tensor): (B, L)
  101. assigned_bboxes (Tensor): (B, L, 4)
  102. assigned_scores (Tensor): (B, L, C), if pred_bboxes is not None, then output ious
  103. """
  104. assert gt_labels.ndim == gt_bboxes.ndim and \
  105. gt_bboxes.ndim == 3
  106. num_anchors, _ = anchor_bboxes.shape
  107. batch_size, num_max_boxes, _ = gt_bboxes.shape
  108. # negative batch
  109. if num_max_boxes == 0:
  110. assigned_labels = paddle.full(
  111. [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype)
  112. assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
  113. assigned_scores = paddle.zeros(
  114. [batch_size, num_anchors, self.num_classes])
  115. return assigned_labels, assigned_bboxes, assigned_scores
  116. # 1. compute iou between gt and anchor bbox, [B, n, L]
  117. ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
  118. ious = ious.reshape([batch_size, -1, num_anchors])
  119. # 2. compute center distance between all anchors and gt, [B, n, L]
  120. gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
  121. anchor_centers = bbox_center(anchor_bboxes)
  122. gt2anchor_distances = (gt_centers - anchor_centers.unsqueeze(0)) \
  123. .norm(2, axis=-1).reshape([batch_size, -1, num_anchors])
  124. # 3. on each pyramid level, selecting topk closest candidates
  125. # based on the center distance, [B, n, L]
  126. is_in_topk, topk_idxs = self._gather_topk_pyramid(
  127. gt2anchor_distances, num_anchors_list, pad_gt_mask)
  128. # 4. get corresponding iou for the these candidates, and compute the
  129. # mean and std, 5. set mean + std as the iou threshold
  130. iou_candidates = ious * is_in_topk
  131. iou_threshold = paddle.index_sample(
  132. iou_candidates.flatten(stop_axis=-2),
  133. topk_idxs.flatten(stop_axis=-2))
  134. iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
  135. iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
  136. iou_threshold.std(axis=-1, keepdim=True)
  137. is_in_topk = paddle.where(iou_candidates > iou_threshold, is_in_topk,
  138. paddle.zeros_like(is_in_topk))
  139. # 6. check the positive sample's center in gt, [B, n, L]
  140. is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
  141. # select positive sample, [B, n, L]
  142. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  143. # 7. if an anchor box is assigned to multiple gts,
  144. # the one with the highest iou will be selected.
  145. mask_positive_sum = mask_positive.sum(axis=-2)
  146. if mask_positive_sum.max() > 1:
  147. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
  148. [1, num_max_boxes, 1])
  149. is_max_iou = compute_max_iou_anchor(ious)
  150. mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
  151. mask_positive)
  152. mask_positive_sum = mask_positive.sum(axis=-2)
  153. # 8. make sure every gt_bbox matches the anchor
  154. if self.force_gt_matching:
  155. is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
  156. mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile(
  157. [1, num_max_boxes, 1])
  158. mask_positive = paddle.where(mask_max_iou, is_max_iou,
  159. mask_positive)
  160. mask_positive_sum = mask_positive.sum(axis=-2)
  161. assigned_gt_index = mask_positive.argmax(axis=-2)
  162. # assigned target
  163. batch_ind = paddle.arange(
  164. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  165. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  166. assigned_labels = paddle.gather(
  167. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  168. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  169. assigned_labels = paddle.where(
  170. mask_positive_sum > 0, assigned_labels,
  171. paddle.full_like(assigned_labels, bg_index))
  172. assigned_bboxes = paddle.gather(
  173. gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
  174. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  175. assigned_scores = F.one_hot(assigned_labels, self.num_classes + 1)
  176. ind = list(range(self.num_classes + 1))
  177. ind.remove(bg_index)
  178. assigned_scores = paddle.index_select(
  179. assigned_scores, paddle.to_tensor(ind), axis=-1)
  180. if pred_bboxes is not None:
  181. # assigned iou
  182. ious = batch_iou_similarity(gt_bboxes, pred_bboxes) * mask_positive
  183. ious = ious.max(axis=-2).unsqueeze(-1)
  184. assigned_scores *= ious
  185. elif gt_scores is not None:
  186. gather_scores = paddle.gather(
  187. gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
  188. gather_scores = gather_scores.reshape([batch_size, num_anchors])
  189. gather_scores = paddle.where(mask_positive_sum > 0, gather_scores,
  190. paddle.zeros_like(gather_scores))
  191. assigned_scores *= gather_scores.unsqueeze(-1)
  192. return assigned_labels, assigned_bboxes, assigned_scores