atss_assigner.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddlex.ppdet.core.workspace import register
  22. from ..ops import iou_similarity
  23. from ..bbox_utils import bbox_center
  24. from .utils import (pad_gt, check_points_inside_bboxes, compute_max_iou_anchor,
  25. compute_max_iou_gt)
  @register
  class ATSSAssigner(nn.Layer):
      """Bridging the Gap Between Anchor-based and Anchor-free Detection
      via Adaptive Training Sample Selection (ATSS).

      Assigns every anchor either to exactly one gt box or to the background;
      see ``forward`` for the individual steps.

      Args:
          topk (int): number of candidate anchors picked per pyramid level for
              every gt, by anchor-center to gt-center distance. A level with
              fewer anchors than ``topk`` contributes all of its anchors.
          num_classes (int): number of classes, i.e. the size of the last
              dimension of the returned ``assigned_scores``.
          force_gt_matching (bool): if True, run the extra matching step 8 of
              ``forward`` (based on ``compute_max_iou_gt``).
          eps (float): stored as ``self.eps``; not read by this class.
      """
      # NOTE(review): presumably lets ppdet's config system (``register``)
      # inject ``num_classes`` from the shared global config -- confirm in
      # paddlex.ppdet.core.workspace.
      __shared__ = ['num_classes']

      def __init__(self,
                   topk=9,
                   num_classes=80,
                   force_gt_matching=False,
                   eps=1e-9):
          super(ATSSAssigner, self).__init__()
          self.topk = topk
          self.num_classes = num_classes
          self.force_gt_matching = force_gt_matching
          self.eps = eps

      def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
                               pad_gt_mask):
          """Select, on every pyramid level, the anchors closest to each gt.

          Args:
              gt2anchor_distances (Tensor): center distance between every gt
                  and every anchor, shape(B, n, L).
              num_anchors_list (List[int]): number of anchors on each level,
                  in the same order as the anchor axis of the distances.
              pad_gt_mask (Tensor): non-zero for real gts, zero for padded
                  ones, shape(B, n, 1).

          Returns:
              is_in_topk (Tensor): shape(B, n, L); 1 where the anchor is a
                  candidate of a real gt, 0 otherwise (padded gts are all 0).
              topk_idxs (Tensor): shape(B, n, K), indices of the candidates
                  into the full anchor axis, where K is the sum over levels of
                  ``min(topk, anchors_on_level)``. Rows of padded gts are not
                  masked here.
          """
          distances_per_level = paddle.split(
              gt2anchor_distances, num_anchors_list, axis=-1)
          # Offset of each level's first anchor inside the full anchor axis.
          level_offsets = [0] + np.cumsum(num_anchors_list).tolist()[:-1]
          # 1.0 for real gts and 0.0 for padding; broadcasts over anchors.
          valid_gt = pad_gt_mask.astype(paddle.bool).astype(
              gt2anchor_distances.dtype)
          is_in_topk_list = []
          topk_idxs_list = []
          for distances, offset in zip(distances_per_level, level_offsets):
              num_anchors = distances.shape[-1]
              # paddle.topk requires k <= axis size; a level can have fewer
              # anchors than ``topk`` (e.g. a coarse level of a small input).
              k = min(self.topk, num_anchors)
              _, topk_idxs = paddle.topk(
                  distances, k, axis=-1, largest=False)
              topk_idxs_list.append(topk_idxs + offset)
              # topk indices are distinct along the last axis, so the summed
              # one-hot is a 0/1 membership mask; multiplying by valid_gt
              # clears the rows of padded gts for any k (including k == 1).
              is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
              is_in_topk_list.append(
                  is_in_topk.astype(gt2anchor_distances.dtype) * valid_gt)
          return (paddle.concat(is_in_topk_list, axis=-1),
                  paddle.concat(topk_idxs_list, axis=-1))

      @paddle.no_grad()
      def forward(self,
                  anchor_bboxes,
                  num_anchors_list,
                  gt_labels,
                  gt_bboxes,
                  bg_index,
                  gt_scores=None):
          r"""This code is based on
          https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py

          The assignment is done in the following steps
          1. compute the iou between all anchors (of all pyramid levels) and
             all gts
          2. compute the center distance between all anchors and all gts
          3. on each pyramid level, for each gt, select the k anchors whose
             centers are closest to the gt center (all anchors of a level that
             has fewer than k), so about k*l candidates are selected per gt
          4. gather the ious of these candidates and compute their mean and
             std; mean + std is the iou threshold of that gt
          5. keep the candidates whose iou is strictly greater than the
             threshold as positive
          6. keep only positives whose anchor center lies inside the gt
          7. if an anchor is positive for multiple gts, keep the gt chosen by
             ``compute_max_iou_anchor`` (highest iou for that anchor)
          8. if ``force_gt_matching``, anchors marked by exactly one gt in
             ``compute_max_iou_gt`` are re-assigned to that gt, so that gts
             can still obtain an anchor
          Anchors not assigned to any gt are background.

          Args:
              anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
                  "xmin, ymin, xmax, ymax" format (assumed to match what
                  ``iou_similarity`` / ``bbox_center`` expect -- TODO confirm)
              num_anchors_list (List): num of anchors in each level
              gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes,
                  shape(B, n, 1)
              gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes,
                  shape(B, n, 4)
              bg_index (int): label written for background anchors
              gt_scores (Tensor|List[Tensor]|None, float32): Score of
                  gt_bboxes, shape(B, n, 1); if None, the one-hot label is
                  used as score
          Returns:
              assigned_labels (Tensor): (B, L), same dtype as the padded
                  ``gt_labels``; ``bg_index`` for background anchors
              assigned_bboxes (Tensor): (B, L, 4)
              assigned_scores (Tensor): (B, L, C); all-zero rows for
                  background anchors
          """
          gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
              gt_labels, gt_bboxes, gt_scores)
          assert gt_labels.ndim == gt_bboxes.ndim and \
              gt_bboxes.ndim == 3

          num_anchors, _ = anchor_bboxes.shape
          batch_size, num_max_boxes, _ = gt_bboxes.shape
          # negative batch: there is no gt at all, so every anchor is
          # background. Use the dtypes the regular path produces (labels are
          # gathered from ``gt_labels``, boxes from ``gt_bboxes``);
          # ``paddle.full`` without a dtype would create float32 labels.
          if num_max_boxes == 0:
              assigned_labels = paddle.full(
                  [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype)
              assigned_bboxes = paddle.zeros(
                  [batch_size, num_anchors, 4], dtype=gt_bboxes.dtype)
              assigned_scores = paddle.zeros(
                  [batch_size, num_anchors, self.num_classes])
              return assigned_labels, assigned_bboxes, assigned_scores

          # 1. compute iou between gt and anchor bbox, [B, n, L]
          ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
          ious = ious.reshape([batch_size, -1, num_anchors])

          # 2. compute center distance between all anchors and gt, [B, n, L]
          gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
          anchor_centers = bbox_center(anchor_bboxes)
          gt2anchor_distances = (gt_centers - anchor_centers.unsqueeze(0)) \
              .norm(2, axis=-1).reshape([batch_size, -1, num_anchors])

          # 3. on each pyramid level, select the topk closest candidates
          # based on the center distance, [B, n, L]
          is_in_topk, topk_idxs = self._gather_topk_pyramid(
              gt2anchor_distances, num_anchors_list, pad_gt_mask)

          # 4. gather the iou of every candidate and compute their mean and
          # std per gt; 5. mean + std is the iou threshold of that gt
          iou_candidates = ious * is_in_topk
          iou_threshold = paddle.index_sample(
              iou_candidates.flatten(stop_axis=-2),
              topk_idxs.flatten(stop_axis=-2))
          iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
          iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
              iou_threshold.std(axis=-1, keepdim=True)
          # candidates must be strictly above the threshold
          is_in_topk = paddle.where(
              iou_candidates > iou_threshold.tile([1, 1, num_anchors]),
              is_in_topk, paddle.zeros_like(is_in_topk))

          # 6. check the positive sample's center in gt, [B, n, L]
          is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
          # select positive sample, [B, n, L]
          mask_positive = is_in_topk * is_in_gts * pad_gt_mask

          # 7. if an anchor box is assigned to multiple gts,
          # the one with the highest iou will be selected.
          mask_positive_sum = mask_positive.sum(axis=-2)
          if mask_positive_sum.max() > 1:
              mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
                  [1, num_max_boxes, 1])
              is_max_iou = compute_max_iou_anchor(ious)
              mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
                                           mask_positive)
              mask_positive_sum = mask_positive.sum(axis=-2)

          # 8. make sure every gt_bbox matches the anchor
          if self.force_gt_matching:
              is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
              mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile(
                  [1, num_max_boxes, 1])
              mask_positive = paddle.where(mask_max_iou, is_max_iou,
                                           mask_positive)
              mask_positive_sum = mask_positive.sum(axis=-2)

          # Invariant: an anchor belongs to at most one gt. A sum of 0 is a
          # legitimate background anchor -- a batch can have no positive
          # anchor at all (e.g. only padded gts) -- so do not require == 1.
          assert mask_positive_sum.max() <= 1, \
              ("one anchor can be assigned to at most one gt. "
               "Received: %f" % mask_positive_sum.max().item())
          assigned_gt_index = mask_positive.argmax(axis=-2)
          # [B, L] bool: anchor is assigned to some gt (else background)
          is_positive_anchor = mask_positive_sum > 0

          # assigned target: offset the per-image gt index so it indexes the
          # flattened (B * n) gt axis
          batch_ind = paddle.arange(
              end=batch_size, dtype=assigned_gt_index.dtype).unsqueeze(-1)
          assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
          assigned_labels = paddle.gather(
              gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
          assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
          assigned_labels = paddle.where(
              is_positive_anchor, assigned_labels,
              paddle.full_like(assigned_labels, bg_index))

          assigned_bboxes = paddle.gather(
              gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
          assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])

          # Background rows must be all-zero. ``bg_index`` may lie outside
          # [0, num_classes) (e.g. bg_index == num_classes), which F.one_hot
          # rejects, so encode a placeholder class 0 for background anchors
          # and zero those rows afterwards (the gt_scores branch below zeroes
          # background rows as well).
          labels_for_onehot = paddle.where(
              is_positive_anchor, assigned_labels,
              paddle.zeros_like(assigned_labels))
          assigned_scores = F.one_hot(labels_for_onehot, self.num_classes)
          assigned_scores *= is_positive_anchor.astype(
              assigned_scores.dtype).unsqueeze(-1)
          if gt_scores is not None:
              gather_scores = paddle.gather(
                  pad_gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
              gather_scores = gather_scores.reshape([batch_size, num_anchors])
              gather_scores = paddle.where(is_positive_anchor, gather_scores,
                                           paddle.zeros_like(gather_scores))
              assigned_scores *= gather_scores.unsqueeze(-1)
          return assigned_labels, assigned_bboxes, assigned_scores