target.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import six
  15. import math
  16. import numpy as np
  17. import paddle
  18. from ..bbox_utils import bbox2delta, bbox_overlaps
  19. import copy
  20. def rpn_anchor_target(anchors,
  21. gt_boxes,
  22. rpn_batch_size_per_im,
  23. rpn_positive_overlap,
  24. rpn_negative_overlap,
  25. rpn_fg_fraction,
  26. use_random=True,
  27. batch_size=1,
  28. ignore_thresh=-1,
  29. is_crowd=None,
  30. weights=[1., 1., 1., 1.]):
  31. tgt_labels = []
  32. tgt_bboxes = []
  33. tgt_deltas = []
  34. for i in range(batch_size):
  35. gt_bbox = gt_boxes[i]
  36. is_crowd_i = is_crowd[i] if is_crowd else None
  37. # Step1: match anchor and gt_bbox
  38. matches, match_labels = label_box(
  39. anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
  40. ignore_thresh, is_crowd_i)
  41. # Step2: sample anchor
  42. fg_inds, bg_inds = subsample_labels(match_labels,
  43. rpn_batch_size_per_im,
  44. rpn_fg_fraction, 0, use_random)
  45. # Fill with the ignore label (-1), then set positive and negative labels
  46. labels = paddle.full(match_labels.shape, -1, dtype='int32')
  47. if bg_inds.shape[0] > 0:
  48. labels = paddle.scatter(labels, bg_inds,
  49. paddle.zeros_like(bg_inds))
  50. if fg_inds.shape[0] > 0:
  51. labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
  52. # Step3: make output
  53. if gt_bbox.shape[0] == 0:
  54. matched_gt_boxes = paddle.zeros([0, 4])
  55. tgt_delta = paddle.zeros([0, 4])
  56. else:
  57. matched_gt_boxes = paddle.gather(gt_bbox, matches)
  58. tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
  59. matched_gt_boxes.stop_gradient = True
  60. tgt_delta.stop_gradient = True
  61. labels.stop_gradient = True
  62. tgt_labels.append(labels)
  63. tgt_bboxes.append(matched_gt_boxes)
  64. tgt_deltas.append(tgt_delta)
  65. return tgt_labels, tgt_bboxes, tgt_deltas
  66. def label_box(anchors,
  67. gt_boxes,
  68. positive_overlap,
  69. negative_overlap,
  70. allow_low_quality,
  71. ignore_thresh,
  72. is_crowd=None):
  73. iou = bbox_overlaps(gt_boxes, anchors)
  74. n_gt = gt_boxes.shape[0]
  75. if n_gt == 0 or is_crowd is None:
  76. n_gt_crowd = 0
  77. else:
  78. n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
  79. if iou.shape[0] == 0 or n_gt_crowd == n_gt:
  80. # No truth, assign everything to background
  81. default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
  82. default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
  83. return default_matches, default_match_labels
  84. # if ignore_thresh > 0, remove anchor if it is closed to
  85. # one of the crowded ground-truth
  86. if n_gt_crowd > 0:
  87. N_a = anchors.shape[0]
  88. ones = paddle.ones([N_a])
  89. mask = is_crowd * ones
  90. if ignore_thresh > 0:
  91. crowd_iou = iou * mask
  92. valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
  93. axis=0) > 0).cast('float32')
  94. iou = iou * (1 - valid) - valid
  95. # ignore the iou between anchor and crowded ground-truth
  96. iou = iou * (1 - mask) - mask
  97. matched_vals, matches = paddle.topk(iou, k=1, axis=0)
  98. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  99. # set ignored anchor with iou = -1
  100. neg_cond = paddle.logical_and(matched_vals > -1,
  101. matched_vals < negative_overlap)
  102. match_labels = paddle.where(neg_cond,
  103. paddle.zeros_like(match_labels), match_labels)
  104. match_labels = paddle.where(matched_vals >= positive_overlap,
  105. paddle.ones_like(match_labels), match_labels)
  106. if allow_low_quality:
  107. highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
  108. pred_inds_with_highest_quality = paddle.logical_and(
  109. iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
  110. 0, keepdim=True)
  111. match_labels = paddle.where(pred_inds_with_highest_quality > 0,
  112. paddle.ones_like(match_labels),
  113. match_labels)
  114. matches = matches.flatten()
  115. match_labels = match_labels.flatten()
  116. return matches, match_labels
  117. def subsample_labels(labels,
  118. num_samples,
  119. fg_fraction,
  120. bg_label=0,
  121. use_random=True):
  122. positive = paddle.nonzero(
  123. paddle.logical_and(labels != -1, labels != bg_label))
  124. negative = paddle.nonzero(labels == bg_label)
  125. fg_num = int(num_samples * fg_fraction)
  126. fg_num = min(positive.numel(), fg_num)
  127. bg_num = num_samples - fg_num
  128. bg_num = min(negative.numel(), bg_num)
  129. if fg_num == 0 and bg_num == 0:
  130. fg_inds = paddle.zeros([0], dtype='int32')
  131. bg_inds = paddle.zeros([0], dtype='int32')
  132. return fg_inds, bg_inds
  133. # randomly select positive and negative examples
  134. negative = negative.cast('int32').flatten()
  135. bg_perm = paddle.randperm(negative.numel(), dtype='int32')
  136. bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
  137. if use_random:
  138. bg_inds = paddle.gather(negative, bg_perm)
  139. else:
  140. bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
  141. if fg_num == 0:
  142. fg_inds = paddle.zeros([0], dtype='int32')
  143. return fg_inds, bg_inds
  144. positive = positive.cast('int32').flatten()
  145. fg_perm = paddle.randperm(positive.numel(), dtype='int32')
  146. fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
  147. if use_random:
  148. fg_inds = paddle.gather(positive, fg_perm)
  149. else:
  150. fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
  151. return fg_inds, bg_inds
  152. def generate_proposal_target(rpn_rois,
  153. gt_classes,
  154. gt_boxes,
  155. batch_size_per_im,
  156. fg_fraction,
  157. fg_thresh,
  158. bg_thresh,
  159. num_classes,
  160. ignore_thresh=-1.,
  161. is_crowd=None,
  162. use_random=True,
  163. is_cascade=False,
  164. cascade_iou=0.5):
  165. rois_with_gt = []
  166. tgt_labels = []
  167. tgt_bboxes = []
  168. tgt_gt_inds = []
  169. new_rois_num = []
  170. # In cascade rcnn, the threshold for foreground and background
  171. # is used from cascade_iou
  172. fg_thresh = cascade_iou if is_cascade else fg_thresh
  173. bg_thresh = cascade_iou if is_cascade else bg_thresh
  174. for i, rpn_roi in enumerate(rpn_rois):
  175. gt_bbox = gt_boxes[i]
  176. is_crowd_i = is_crowd[i] if is_crowd else None
  177. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  178. # Concat RoIs and gt boxes except cascade rcnn or none gt
  179. if not is_cascade and gt_bbox.shape[0] > 0:
  180. bbox = paddle.concat([rpn_roi, gt_bbox])
  181. else:
  182. bbox = rpn_roi
  183. # Step1: label bbox
  184. matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
  185. False, ignore_thresh, is_crowd_i)
  186. # Step2: sample bbox
  187. sampled_inds, sampled_gt_classes = sample_bbox(
  188. matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
  189. num_classes, use_random, is_cascade)
  190. # Step3: make output
  191. rois_per_image = bbox if is_cascade else paddle.gather(bbox,
  192. sampled_inds)
  193. sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
  194. sampled_inds)
  195. if gt_bbox.shape[0] > 0:
  196. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  197. else:
  198. sampled_bbox = paddle.zeros([0, 4], dtype='float32')
  199. rois_per_image.stop_gradient = True
  200. sampled_gt_ind.stop_gradient = True
  201. sampled_bbox.stop_gradient = True
  202. tgt_labels.append(sampled_gt_classes)
  203. tgt_bboxes.append(sampled_bbox)
  204. rois_with_gt.append(rois_per_image)
  205. tgt_gt_inds.append(sampled_gt_ind)
  206. new_rois_num.append(paddle.shape(sampled_inds)[0])
  207. new_rois_num = paddle.concat(new_rois_num)
  208. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  209. def sample_bbox(matches,
  210. match_labels,
  211. gt_classes,
  212. batch_size_per_im,
  213. fg_fraction,
  214. num_classes,
  215. use_random=True,
  216. is_cascade=False):
  217. n_gt = gt_classes.shape[0]
  218. if n_gt == 0:
  219. # No truth, assign everything to background
  220. gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
  221. #return matches, match_labels + num_classes
  222. else:
  223. gt_classes = paddle.gather(gt_classes, matches)
  224. gt_classes = paddle.where(match_labels == 0,
  225. paddle.ones_like(gt_classes) * num_classes,
  226. gt_classes)
  227. gt_classes = paddle.where(match_labels == -1,
  228. paddle.ones_like(gt_classes) * -1,
  229. gt_classes)
  230. if is_cascade:
  231. index = paddle.arange(matches.shape[0])
  232. return index, gt_classes
  233. rois_per_image = int(batch_size_per_im)
  234. fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image,
  235. fg_fraction, num_classes, use_random)
  236. if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
  237. # fake output labeled with -1 when all boxes are neither
  238. # foreground nor background
  239. sampled_inds = paddle.zeros([1], dtype='int32')
  240. else:
  241. sampled_inds = paddle.concat([fg_inds, bg_inds])
  242. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  243. return sampled_inds, sampled_gt_classes
  244. def polygons_to_mask(polygons, height, width):
  245. """
  246. Args:
  247. polygons (list[ndarray]): each array has shape (Nx2,)
  248. height, width (int)
  249. Returns:
  250. ndarray: a bool mask of shape (height, width)
  251. """
  252. import pycocotools.mask as mask_util
  253. assert len(polygons) > 0, "COCOAPI does not support empty polygons"
  254. rles = mask_util.frPyObjects(polygons, height, width)
  255. rle = mask_util.merge(rles)
  256. return mask_util.decode(rle).astype(np.bool)
  257. def rasterize_polygons_within_box(poly, box, resolution):
  258. w, h = box[2] - box[0], box[3] - box[1]
  259. polygons = [np.asarray(p, dtype=np.float64) for p in poly]
  260. for p in polygons:
  261. p[0::2] = p[0::2] - box[0]
  262. p[1::2] = p[1::2] - box[1]
  263. ratio_h = resolution / max(h, 0.1)
  264. ratio_w = resolution / max(w, 0.1)
  265. if ratio_h == ratio_w:
  266. for p in polygons:
  267. p *= ratio_h
  268. else:
  269. for p in polygons:
  270. p[0::2] *= ratio_w
  271. p[1::2] *= ratio_h
  272. # 3. Rasterize the polygons with coco api
  273. mask = polygons_to_mask(polygons, resolution, resolution)
  274. mask = paddle.to_tensor(mask, dtype='int32')
  275. return mask
  276. def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
  277. num_classes, resolution):
  278. mask_rois = []
  279. mask_rois_num = []
  280. tgt_masks = []
  281. tgt_classes = []
  282. mask_index = []
  283. tgt_weights = []
  284. for k in range(len(rois)):
  285. labels_per_im = labels_int32[k]
  286. # select rois labeled with foreground
  287. fg_inds = paddle.nonzero(
  288. paddle.logical_and(labels_per_im != -1, labels_per_im !=
  289. num_classes))
  290. has_fg = True
  291. # generate fake roi if foreground is empty
  292. if fg_inds.numel() == 0:
  293. has_fg = False
  294. fg_inds = paddle.ones([1], dtype='int32')
  295. inds_per_im = sampled_gt_inds[k]
  296. inds_per_im = paddle.gather(inds_per_im, fg_inds)
  297. rois_per_im = rois[k]
  298. fg_rois = paddle.gather(rois_per_im, fg_inds)
  299. # Copy the foreground roi to cpu
  300. # to generate mask target with ground-truth
  301. boxes = fg_rois.numpy()
  302. gt_segms_per_im = gt_segms[k]
  303. new_segm = []
  304. inds_per_im = inds_per_im.numpy()
  305. if len(gt_segms_per_im) > 0:
  306. for i in inds_per_im:
  307. new_segm.append(gt_segms_per_im[i])
  308. fg_inds_new = fg_inds.reshape([-1]).numpy()
  309. results = []
  310. if len(gt_segms_per_im) > 0:
  311. for j in fg_inds_new:
  312. results.append(
  313. rasterize_polygons_within_box(new_segm[j], boxes[j],
  314. resolution))
  315. else:
  316. results.append(
  317. paddle.ones(
  318. [resolution, resolution], dtype='int32'))
  319. fg_classes = paddle.gather(labels_per_im, fg_inds)
  320. weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
  321. if not has_fg:
  322. # now all sampled classes are background
  323. # which will cause error in loss calculation,
  324. # make fake classes with weight of 0.
  325. fg_classes = paddle.zeros([1], dtype='int32')
  326. weight = weight - 1
  327. tgt_mask = paddle.stack(results)
  328. tgt_mask.stop_gradient = True
  329. fg_rois.stop_gradient = True
  330. mask_index.append(fg_inds)
  331. mask_rois.append(fg_rois)
  332. mask_rois_num.append(paddle.shape(fg_rois)[0])
  333. tgt_classes.append(fg_classes)
  334. tgt_masks.append(tgt_mask)
  335. tgt_weights.append(weight)
  336. mask_index = paddle.concat(mask_index)
  337. mask_rois_num = paddle.concat(mask_rois_num)
  338. tgt_classes = paddle.concat(tgt_classes, axis=0)
  339. tgt_masks = paddle.concat(tgt_masks, axis=0)
  340. tgt_weights = paddle.concat(tgt_weights, axis=0)
  341. return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  342. def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
  343. if len(pos_inds) <= num_expected:
  344. return pos_inds
  345. else:
  346. unique_gt_inds = np.unique(max_classes[pos_inds])
  347. num_gts = len(unique_gt_inds)
  348. num_per_gt = int(round(num_expected / float(num_gts)) + 1)
  349. sampled_inds = []
  350. for i in unique_gt_inds:
  351. inds = np.nonzero(max_classes == i)[0]
  352. before_len = len(inds)
  353. inds = list(set(inds) & set(pos_inds))
  354. after_len = len(inds)
  355. if len(inds) > num_per_gt:
  356. inds = np.random.choice(inds, size=num_per_gt, replace=False)
  357. sampled_inds.extend(list(inds)) # combine as a new sampler
  358. if len(sampled_inds) < num_expected:
  359. num_extra = num_expected - len(sampled_inds)
  360. extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
  361. assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
  362. "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
  363. len(sampled_inds), len(extra_inds), len(pos_inds))
  364. if len(extra_inds) > num_extra:
  365. extra_inds = np.random.choice(
  366. extra_inds, size=num_extra, replace=False)
  367. sampled_inds.extend(extra_inds.tolist())
  368. elif len(sampled_inds) > num_expected:
  369. sampled_inds = np.random.choice(
  370. sampled_inds, size=num_expected, replace=False)
  371. return paddle.to_tensor(sampled_inds)
  372. def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
  373. num_bins, bg_thresh):
  374. max_iou = max_overlaps.max()
  375. iou_interval = (max_iou - floor_thr) / num_bins
  376. per_num_expected = int(num_expected / num_bins)
  377. sampled_inds = []
  378. for i in range(num_bins):
  379. start_iou = floor_thr + i * iou_interval
  380. end_iou = floor_thr + (i + 1) * iou_interval
  381. tmp_set = set(
  382. np.where(
  383. np.logical_and(max_overlaps >= start_iou, max_overlaps <
  384. end_iou))[0])
  385. tmp_inds = list(tmp_set & full_set)
  386. if len(tmp_inds) > per_num_expected:
  387. tmp_sampled_set = np.random.choice(
  388. tmp_inds, size=per_num_expected, replace=False)
  389. else:
  390. tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
  391. sampled_inds.append(tmp_sampled_set)
  392. sampled_inds = np.concatenate(sampled_inds)
  393. if len(sampled_inds) < num_expected:
  394. num_extra = num_expected - len(sampled_inds)
  395. extra_inds = np.array(list(full_set - set(sampled_inds)))
  396. assert len(sampled_inds) + len(extra_inds) == len(full_set), \
  397. "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
  398. len(sampled_inds), len(extra_inds), len(full_set))
  399. if len(extra_inds) > num_extra:
  400. extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
  401. sampled_inds = np.concatenate([sampled_inds, extra_inds])
  402. return sampled_inds
  403. def libra_sample_neg(max_overlaps,
  404. max_classes,
  405. neg_inds,
  406. num_expected,
  407. floor_thr=-1,
  408. floor_fraction=0,
  409. num_bins=3,
  410. bg_thresh=0.5):
  411. if len(neg_inds) <= num_expected:
  412. return neg_inds
  413. else:
  414. # balance sampling for negative samples
  415. neg_set = set(neg_inds.tolist())
  416. if floor_thr > 0:
  417. floor_set = set(
  418. np.where(
  419. np.logical_and(max_overlaps >= 0, max_overlaps <
  420. floor_thr))[0])
  421. iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
  422. elif floor_thr == 0:
  423. floor_set = set(np.where(max_overlaps == 0)[0])
  424. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  425. else:
  426. floor_set = set()
  427. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  428. floor_thr = 0
  429. floor_neg_inds = list(floor_set & neg_set)
  430. iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
  431. num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
  432. if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
  433. if num_bins >= 2:
  434. iou_sampled_inds = libra_sample_via_interval(
  435. max_overlaps,
  436. set(iou_sampling_neg_inds), num_expected_iou_sampling,
  437. floor_thr, num_bins, bg_thresh)
  438. else:
  439. iou_sampled_inds = np.random.choice(
  440. iou_sampling_neg_inds,
  441. size=num_expected_iou_sampling,
  442. replace=False)
  443. else:
  444. iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int)
  445. num_expected_floor = num_expected - len(iou_sampled_inds)
  446. if len(floor_neg_inds) > num_expected_floor:
  447. sampled_floor_inds = np.random.choice(
  448. floor_neg_inds, size=num_expected_floor, replace=False)
  449. else:
  450. sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
  451. sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
  452. if len(sampled_inds) < num_expected:
  453. num_extra = num_expected - len(sampled_inds)
  454. extra_inds = np.array(list(neg_set - set(sampled_inds)))
  455. if len(extra_inds) > num_extra:
  456. extra_inds = np.random.choice(
  457. extra_inds, size=num_extra, replace=False)
  458. sampled_inds = np.concatenate((sampled_inds, extra_inds))
  459. return paddle.to_tensor(sampled_inds)
  460. def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
  461. negative_overlap, num_classes):
  462. # TODO: use paddle API to speed up
  463. gt_classes = gt_classes.numpy()
  464. gt_overlaps = np.zeros((anchors.shape[0], num_classes))
  465. matches = np.zeros((anchors.shape[0]), dtype=np.int32)
  466. if len(gt_boxes) > 0:
  467. proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
  468. overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
  469. overlaps_max = proposal_to_gt_overlaps.max(axis=1)
  470. # Boxes which with non-zero overlap with gt boxes
  471. overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
  472. overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
  473. overlapped_boxes_ind]]
  474. for idx in range(len(overlapped_boxes_ind)):
  475. gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
  476. idx]] = overlaps_max[overlapped_boxes_ind[idx]]
  477. matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
  478. overlapped_boxes_ind[idx]]
  479. gt_overlaps = paddle.to_tensor(gt_overlaps)
  480. matches = paddle.to_tensor(matches)
  481. matched_vals = paddle.max(gt_overlaps, axis=1)
  482. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  483. match_labels = paddle.where(matched_vals < negative_overlap,
  484. paddle.zeros_like(match_labels), match_labels)
  485. match_labels = paddle.where(matched_vals >= positive_overlap,
  486. paddle.ones_like(match_labels), match_labels)
  487. return matches, match_labels, matched_vals
  488. def libra_sample_bbox(matches,
  489. match_labels,
  490. matched_vals,
  491. gt_classes,
  492. batch_size_per_im,
  493. num_classes,
  494. fg_fraction,
  495. fg_thresh,
  496. bg_thresh,
  497. num_bins,
  498. use_random=True,
  499. is_cascade_rcnn=False):
  500. rois_per_image = int(batch_size_per_im)
  501. fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
  502. bg_rois_per_im = rois_per_image - fg_rois_per_im
  503. if is_cascade_rcnn:
  504. fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
  505. bg_inds = paddle.nonzero(matched_vals < bg_thresh)
  506. else:
  507. matched_vals_np = matched_vals.numpy()
  508. match_labels_np = match_labels.numpy()
  509. # sample fg
  510. fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
  511. fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
  512. if (fg_inds.shape[0] > fg_nums) and use_random:
  513. fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
  514. fg_inds.numpy(), fg_rois_per_im)
  515. fg_inds = fg_inds[:fg_nums]
  516. # sample bg
  517. bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
  518. bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
  519. if (bg_inds.shape[0] > bg_nums) and use_random:
  520. bg_inds = libra_sample_neg(
  521. matched_vals_np,
  522. match_labels_np,
  523. bg_inds.numpy(),
  524. bg_rois_per_im,
  525. num_bins=num_bins,
  526. bg_thresh=bg_thresh)
  527. bg_inds = bg_inds[:bg_nums]
  528. sampled_inds = paddle.concat([fg_inds, bg_inds])
  529. gt_classes = paddle.gather(gt_classes, matches)
  530. gt_classes = paddle.where(match_labels == 0,
  531. paddle.ones_like(gt_classes) * num_classes,
  532. gt_classes)
  533. gt_classes = paddle.where(match_labels == -1,
  534. paddle.ones_like(gt_classes) * -1,
  535. gt_classes)
  536. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  537. return sampled_inds, sampled_gt_classes
  538. def libra_generate_proposal_target(rpn_rois,
  539. gt_classes,
  540. gt_boxes,
  541. batch_size_per_im,
  542. fg_fraction,
  543. fg_thresh,
  544. bg_thresh,
  545. num_classes,
  546. use_random=True,
  547. is_cascade_rcnn=False,
  548. max_overlaps=None,
  549. num_bins=3):
  550. rois_with_gt = []
  551. tgt_labels = []
  552. tgt_bboxes = []
  553. sampled_max_overlaps = []
  554. tgt_gt_inds = []
  555. new_rois_num = []
  556. for i, rpn_roi in enumerate(rpn_rois):
  557. max_overlap = max_overlaps[i] if is_cascade_rcnn else None
  558. gt_bbox = gt_boxes[i]
  559. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  560. if is_cascade_rcnn:
  561. rpn_roi = filter_roi(rpn_roi, max_overlap)
  562. bbox = paddle.concat([rpn_roi, gt_bbox])
  563. # Step1: label bbox
  564. matches, match_labels, matched_vals = libra_label_box(
  565. bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)
  566. # Step2: sample bbox
  567. sampled_inds, sampled_gt_classes = libra_sample_bbox(
  568. matches, match_labels, matched_vals, gt_class, batch_size_per_im,
  569. num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
  570. use_random, is_cascade_rcnn)
  571. # Step3: make output
  572. rois_per_image = paddle.gather(bbox, sampled_inds)
  573. sampled_gt_ind = paddle.gather(matches, sampled_inds)
  574. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  575. sampled_overlap = paddle.gather(matched_vals, sampled_inds)
  576. rois_per_image.stop_gradient = True
  577. sampled_gt_ind.stop_gradient = True
  578. sampled_bbox.stop_gradient = True
  579. sampled_overlap.stop_gradient = True
  580. tgt_labels.append(sampled_gt_classes)
  581. tgt_bboxes.append(sampled_bbox)
  582. rois_with_gt.append(rois_per_image)
  583. sampled_max_overlaps.append(sampled_overlap)
  584. tgt_gt_inds.append(sampled_gt_ind)
  585. new_rois_num.append(paddle.shape(sampled_inds)[0])
  586. new_rois_num = paddle.concat(new_rois_num)
  587. # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  588. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num