target.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from ..bbox_utils import bbox2delta, bbox_overlaps
  17. def rpn_anchor_target(anchors,
  18. gt_boxes,
  19. rpn_batch_size_per_im,
  20. rpn_positive_overlap,
  21. rpn_negative_overlap,
  22. rpn_fg_fraction,
  23. use_random=True,
  24. batch_size=1,
  25. ignore_thresh=-1,
  26. is_crowd=None,
  27. weights=[1., 1., 1., 1.],
  28. assign_on_cpu=False):
  29. tgt_labels = []
  30. tgt_bboxes = []
  31. tgt_deltas = []
  32. for i in range(batch_size):
  33. gt_bbox = gt_boxes[i]
  34. is_crowd_i = is_crowd[i] if is_crowd else None
  35. # Step1: match anchor and gt_bbox
  36. matches, match_labels = label_box(
  37. anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
  38. ignore_thresh, is_crowd_i, assign_on_cpu)
  39. # Step2: sample anchor
  40. fg_inds, bg_inds = subsample_labels(match_labels,
  41. rpn_batch_size_per_im,
  42. rpn_fg_fraction, 0, use_random)
  43. # Fill with the ignore label (-1), then set positive and negative labels
  44. labels = paddle.full(match_labels.shape, -1, dtype='int32')
  45. if bg_inds.shape[0] > 0:
  46. labels = paddle.scatter(labels, bg_inds,
  47. paddle.zeros_like(bg_inds))
  48. if fg_inds.shape[0] > 0:
  49. labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
  50. # Step3: make output
  51. if gt_bbox.shape[0] == 0:
  52. matched_gt_boxes = paddle.zeros([matches.shape[0], 4])
  53. tgt_delta = paddle.zeros([matches.shape[0], 4])
  54. else:
  55. matched_gt_boxes = paddle.gather(gt_bbox, matches)
  56. tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
  57. matched_gt_boxes.stop_gradient = True
  58. tgt_delta.stop_gradient = True
  59. labels.stop_gradient = True
  60. tgt_labels.append(labels)
  61. tgt_bboxes.append(matched_gt_boxes)
  62. tgt_deltas.append(tgt_delta)
  63. return tgt_labels, tgt_bboxes, tgt_deltas
  64. def label_box(anchors,
  65. gt_boxes,
  66. positive_overlap,
  67. negative_overlap,
  68. allow_low_quality,
  69. ignore_thresh,
  70. is_crowd=None,
  71. assign_on_cpu=False):
  72. if assign_on_cpu:
  73. device = paddle.device.get_device()
  74. paddle.set_device("cpu")
  75. iou = bbox_overlaps(gt_boxes, anchors)
  76. paddle.set_device(device)
  77. else:
  78. iou = bbox_overlaps(gt_boxes, anchors)
  79. n_gt = gt_boxes.shape[0]
  80. if n_gt == 0 or is_crowd is None:
  81. n_gt_crowd = 0
  82. else:
  83. n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
  84. if iou.shape[0] == 0 or n_gt_crowd == n_gt:
  85. # No truth, assign everything to background
  86. default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
  87. default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
  88. return default_matches, default_match_labels
  89. # if ignore_thresh > 0, remove anchor if it is closed to
  90. # one of the crowded ground-truth
  91. if n_gt_crowd > 0:
  92. N_a = anchors.shape[0]
  93. ones = paddle.ones([N_a])
  94. mask = is_crowd * ones
  95. if ignore_thresh > 0:
  96. crowd_iou = iou * mask
  97. valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
  98. axis=0) > 0).cast('float32')
  99. iou = iou * (1 - valid) - valid
  100. # ignore the iou between anchor and crowded ground-truth
  101. iou = iou * (1 - mask) - mask
  102. matched_vals, matches = paddle.topk(iou, k=1, axis=0)
  103. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  104. # set ignored anchor with iou = -1
  105. neg_cond = paddle.logical_and(matched_vals > -1,
  106. matched_vals < negative_overlap)
  107. match_labels = paddle.where(neg_cond,
  108. paddle.zeros_like(match_labels), match_labels)
  109. match_labels = paddle.where(matched_vals >= positive_overlap,
  110. paddle.ones_like(match_labels), match_labels)
  111. if allow_low_quality:
  112. highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
  113. pred_inds_with_highest_quality = paddle.logical_and(
  114. iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
  115. 0, keepdim=True)
  116. match_labels = paddle.where(pred_inds_with_highest_quality > 0,
  117. paddle.ones_like(match_labels),
  118. match_labels)
  119. matches = matches.flatten()
  120. match_labels = match_labels.flatten()
  121. return matches, match_labels
  122. def subsample_labels(labels,
  123. num_samples,
  124. fg_fraction,
  125. bg_label=0,
  126. use_random=True):
  127. positive = paddle.nonzero(
  128. paddle.logical_and(labels != -1, labels != bg_label))
  129. negative = paddle.nonzero(labels == bg_label)
  130. fg_num = int(num_samples * fg_fraction)
  131. fg_num = min(positive.numel(), fg_num)
  132. bg_num = num_samples - fg_num
  133. bg_num = min(negative.numel(), bg_num)
  134. if fg_num == 0 and bg_num == 0:
  135. fg_inds = paddle.zeros([0], dtype='int32')
  136. bg_inds = paddle.zeros([0], dtype='int32')
  137. return fg_inds, bg_inds
  138. # randomly select positive and negative examples
  139. negative = negative.cast('int32').flatten()
  140. bg_perm = paddle.randperm(negative.numel(), dtype='int32')
  141. bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
  142. if use_random:
  143. bg_inds = paddle.gather(negative, bg_perm)
  144. else:
  145. bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
  146. if fg_num == 0:
  147. fg_inds = paddle.zeros([0], dtype='int32')
  148. return fg_inds, bg_inds
  149. positive = positive.cast('int32').flatten()
  150. fg_perm = paddle.randperm(positive.numel(), dtype='int32')
  151. fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
  152. if use_random:
  153. fg_inds = paddle.gather(positive, fg_perm)
  154. else:
  155. fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
  156. return fg_inds, bg_inds
  157. def generate_proposal_target(rpn_rois,
  158. gt_classes,
  159. gt_boxes,
  160. batch_size_per_im,
  161. fg_fraction,
  162. fg_thresh,
  163. bg_thresh,
  164. num_classes,
  165. ignore_thresh=-1.,
  166. is_crowd=None,
  167. use_random=True,
  168. is_cascade=False,
  169. cascade_iou=0.5,
  170. assign_on_cpu=False):
  171. rois_with_gt = []
  172. tgt_labels = []
  173. tgt_bboxes = []
  174. tgt_gt_inds = []
  175. new_rois_num = []
  176. # In cascade rcnn, the threshold for foreground and background
  177. # is used from cascade_iou
  178. fg_thresh = cascade_iou if is_cascade else fg_thresh
  179. bg_thresh = cascade_iou if is_cascade else bg_thresh
  180. for i, rpn_roi in enumerate(rpn_rois):
  181. gt_bbox = gt_boxes[i]
  182. is_crowd_i = is_crowd[i] if is_crowd else None
  183. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  184. # Concat RoIs and gt boxes except cascade rcnn or none gt
  185. if not is_cascade and gt_bbox.shape[0] > 0:
  186. bbox = paddle.concat([rpn_roi, gt_bbox])
  187. else:
  188. bbox = rpn_roi
  189. # Step1: label bbox
  190. matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
  191. False, ignore_thresh, is_crowd_i,
  192. assign_on_cpu)
  193. # Step2: sample bbox
  194. sampled_inds, sampled_gt_classes = sample_bbox(
  195. matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
  196. num_classes, use_random, is_cascade)
  197. # Step3: make output
  198. rois_per_image = bbox if is_cascade else paddle.gather(bbox,
  199. sampled_inds)
  200. sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
  201. sampled_inds)
  202. if gt_bbox.shape[0] > 0:
  203. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  204. else:
  205. num = rois_per_image.shape[0]
  206. sampled_bbox = paddle.zeros([num, 4], dtype='float32')
  207. rois_per_image.stop_gradient = True
  208. sampled_gt_ind.stop_gradient = True
  209. sampled_bbox.stop_gradient = True
  210. tgt_labels.append(sampled_gt_classes)
  211. tgt_bboxes.append(sampled_bbox)
  212. rois_with_gt.append(rois_per_image)
  213. tgt_gt_inds.append(sampled_gt_ind)
  214. new_rois_num.append(paddle.shape(sampled_inds)[0])
  215. new_rois_num = paddle.concat(new_rois_num)
  216. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  217. def sample_bbox(matches,
  218. match_labels,
  219. gt_classes,
  220. batch_size_per_im,
  221. fg_fraction,
  222. num_classes,
  223. use_random=True,
  224. is_cascade=False):
  225. n_gt = gt_classes.shape[0]
  226. if n_gt == 0:
  227. # No truth, assign everything to background
  228. gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
  229. #return matches, match_labels + num_classes
  230. else:
  231. gt_classes = paddle.gather(gt_classes, matches)
  232. gt_classes = paddle.where(match_labels == 0,
  233. paddle.ones_like(gt_classes) * num_classes,
  234. gt_classes)
  235. gt_classes = paddle.where(match_labels == -1,
  236. paddle.ones_like(gt_classes) * -1,
  237. gt_classes)
  238. if is_cascade:
  239. index = paddle.arange(matches.shape[0])
  240. return index, gt_classes
  241. rois_per_image = int(batch_size_per_im)
  242. fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image,
  243. fg_fraction, num_classes, use_random)
  244. if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
  245. # fake output labeled with -1 when all boxes are neither
  246. # foreground nor background
  247. sampled_inds = paddle.zeros([1], dtype='int32')
  248. else:
  249. sampled_inds = paddle.concat([fg_inds, bg_inds])
  250. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  251. return sampled_inds, sampled_gt_classes
  252. def polygons_to_mask(polygons, height, width):
  253. """
  254. Convert the polygons to mask format
  255. Args:
  256. polygons (list[ndarray]): each array has shape (Nx2,)
  257. height (int): mask height
  258. width (int): mask width
  259. Returns:
  260. ndarray: a bool mask of shape (height, width)
  261. """
  262. import pycocotools.mask as mask_util
  263. assert len(polygons) > 0, "COCOAPI does not support empty polygons"
  264. rles = mask_util.frPyObjects(polygons, height, width)
  265. rle = mask_util.merge(rles)
  266. return mask_util.decode(rle).astype(np.bool)
  267. def rasterize_polygons_within_box(poly, box, resolution):
  268. w, h = box[2] - box[0], box[3] - box[1]
  269. polygons = [np.asarray(p, dtype=np.float64) for p in poly]
  270. for p in polygons:
  271. p[0::2] = p[0::2] - box[0]
  272. p[1::2] = p[1::2] - box[1]
  273. ratio_h = resolution / max(h, 0.1)
  274. ratio_w = resolution / max(w, 0.1)
  275. if ratio_h == ratio_w:
  276. for p in polygons:
  277. p *= ratio_h
  278. else:
  279. for p in polygons:
  280. p[0::2] *= ratio_w
  281. p[1::2] *= ratio_h
  282. # 3. Rasterize the polygons with coco api
  283. mask = polygons_to_mask(polygons, resolution, resolution)
  284. mask = paddle.to_tensor(mask, dtype='int32')
  285. return mask
def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
                         num_classes, resolution):
    """Build mask-head targets from the sampled RoIs of each image.

    For every image, RoIs whose label is neither -1 (ignored) nor
    ``num_classes`` (background) are kept. The polygons of the gt each kept
    RoI is matched to are rasterized inside that RoI at
    ``resolution x resolution``.

    Args:
        gt_segms: per-image lists of gt polygons; ``gt_segms[k][i]`` is the
            polygon list of gt ``i`` in image ``k``.
        rois: per-image sampled RoIs.
        labels_int32: per-image class labels of the sampled RoIs.
        sampled_gt_inds: per-image index of the gt matched to each RoI.
        num_classes (int): number of foreground classes; this value is the
            background label.
        resolution (int): side length of each mask target.

    Returns:
        tuple: ``(mask_rois, mask_rois_num, tgt_classes, tgt_masks,
        mask_index, tgt_weights)``. ``mask_rois`` is a per-image list; the
        other outputs are concatenated across images.
    """
    mask_rois = []
    mask_rois_num = []
    tgt_masks = []
    tgt_classes = []
    mask_index = []
    tgt_weights = []
    for k in range(len(rois)):
        labels_per_im = labels_int32[k]
        # select rois labeled with foreground (neither ignored (-1) nor
        # background (num_classes))
        fg_inds = paddle.nonzero(
            paddle.logical_and(labels_per_im != -1, labels_per_im !=
                               num_classes))
        has_fg = True
        # generate fake roi if foreground is empty
        # NOTE(review): the fake RoI is index 1, which assumes every image has
        # at least two sampled RoIs -- confirm.
        if fg_inds.numel() == 0:
            has_fg = False
            fg_inds = paddle.ones([1], dtype='int32')
        # Index of the gt matched to each foreground RoI.
        inds_per_im = sampled_gt_inds[k]
        inds_per_im = paddle.gather(inds_per_im, fg_inds)

        rois_per_im = rois[k]
        fg_rois = paddle.gather(rois_per_im, fg_inds)
        # Copy the foreground roi to cpu
        # to generate mask target with ground-truth
        boxes = fg_rois.numpy()
        gt_segms_per_im = gt_segms[k]

        # Polygons of the matched gt for every foreground RoI.
        new_segm = []
        inds_per_im = inds_per_im.numpy()
        if len(gt_segms_per_im) > 0:
            for i in inds_per_im:
                new_segm.append(gt_segms_per_im[i])
        fg_inds_new = fg_inds.reshape([-1]).numpy()
        results = []
        if len(gt_segms_per_im) > 0:
            for j in range(fg_inds_new.shape[0]):
                results.append(
                    rasterize_polygons_within_box(new_segm[j], boxes[j],
                                                  resolution))
        else:
            # No gt polygons in this image: a single all-ones placeholder.
            # NOTE(review): presumably an image without gt never has real
            # foreground RoIs, so one mask matches the one fake RoI -- confirm.
            results.append(
                paddle.ones(
                    [resolution, resolution], dtype='int32'))

        fg_classes = paddle.gather(labels_per_im, fg_inds)
        weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
        if not has_fg:
            # now all sampled classes are background
            # which will cause error in loss calculation,
            # make fake classes with weight of 0.
            fg_classes = paddle.zeros([1], dtype='int32')
            weight = weight - 1
        tgt_mask = paddle.stack(results)
        tgt_mask.stop_gradient = True
        fg_rois.stop_gradient = True

        mask_index.append(fg_inds)
        mask_rois.append(fg_rois)
        mask_rois_num.append(paddle.shape(fg_rois)[0])
        tgt_classes.append(fg_classes)
        tgt_masks.append(tgt_mask)
        tgt_weights.append(weight)

    # Merge the per-image results (except mask_rois) across the batch.
    mask_index = paddle.concat(mask_index)
    mask_rois_num = paddle.concat(mask_rois_num)
    tgt_classes = paddle.concat(tgt_classes, axis=0)
    tgt_masks = paddle.concat(tgt_masks, axis=0)
    tgt_weights = paddle.concat(tgt_weights, axis=0)

    return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  352. def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
  353. if len(pos_inds) <= num_expected:
  354. return pos_inds
  355. else:
  356. unique_gt_inds = np.unique(max_classes[pos_inds])
  357. num_gts = len(unique_gt_inds)
  358. num_per_gt = int(round(num_expected / float(num_gts)) + 1)
  359. sampled_inds = []
  360. for i in unique_gt_inds:
  361. inds = np.nonzero(max_classes == i)[0]
  362. before_len = len(inds)
  363. inds = list(set(inds) & set(pos_inds))
  364. after_len = len(inds)
  365. if len(inds) > num_per_gt:
  366. inds = np.random.choice(inds, size=num_per_gt, replace=False)
  367. sampled_inds.extend(list(inds)) # combine as a new sampler
  368. if len(sampled_inds) < num_expected:
  369. num_extra = num_expected - len(sampled_inds)
  370. extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
  371. assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
  372. "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
  373. len(sampled_inds), len(extra_inds), len(pos_inds))
  374. if len(extra_inds) > num_extra:
  375. extra_inds = np.random.choice(
  376. extra_inds, size=num_extra, replace=False)
  377. sampled_inds.extend(extra_inds.tolist())
  378. elif len(sampled_inds) > num_expected:
  379. sampled_inds = np.random.choice(
  380. sampled_inds, size=num_expected, replace=False)
  381. return paddle.to_tensor(sampled_inds)
  382. def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
  383. num_bins, bg_thresh):
  384. max_iou = max_overlaps.max()
  385. iou_interval = (max_iou - floor_thr) / num_bins
  386. per_num_expected = int(num_expected / num_bins)
  387. sampled_inds = []
  388. for i in range(num_bins):
  389. start_iou = floor_thr + i * iou_interval
  390. end_iou = floor_thr + (i + 1) * iou_interval
  391. tmp_set = set(
  392. np.where(
  393. np.logical_and(max_overlaps >= start_iou, max_overlaps <
  394. end_iou))[0])
  395. tmp_inds = list(tmp_set & full_set)
  396. if len(tmp_inds) > per_num_expected:
  397. tmp_sampled_set = np.random.choice(
  398. tmp_inds, size=per_num_expected, replace=False)
  399. else:
  400. tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
  401. sampled_inds.append(tmp_sampled_set)
  402. sampled_inds = np.concatenate(sampled_inds)
  403. if len(sampled_inds) < num_expected:
  404. num_extra = num_expected - len(sampled_inds)
  405. extra_inds = np.array(list(full_set - set(sampled_inds)))
  406. assert len(sampled_inds) + len(extra_inds) == len(full_set), \
  407. "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
  408. len(sampled_inds), len(extra_inds), len(full_set))
  409. if len(extra_inds) > num_extra:
  410. extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
  411. sampled_inds = np.concatenate([sampled_inds, extra_inds])
  412. return sampled_inds
  413. def libra_sample_neg(max_overlaps,
  414. max_classes,
  415. neg_inds,
  416. num_expected,
  417. floor_thr=-1,
  418. floor_fraction=0,
  419. num_bins=3,
  420. bg_thresh=0.5):
  421. if len(neg_inds) <= num_expected:
  422. return neg_inds
  423. else:
  424. # balance sampling for negative samples
  425. neg_set = set(neg_inds.tolist())
  426. if floor_thr > 0:
  427. floor_set = set(
  428. np.where(
  429. np.logical_and(max_overlaps >= 0, max_overlaps <
  430. floor_thr))[0])
  431. iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
  432. elif floor_thr == 0:
  433. floor_set = set(np.where(max_overlaps == 0)[0])
  434. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  435. else:
  436. floor_set = set()
  437. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  438. floor_thr = 0
  439. floor_neg_inds = list(floor_set & neg_set)
  440. iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
  441. num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
  442. if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
  443. if num_bins >= 2:
  444. iou_sampled_inds = libra_sample_via_interval(
  445. max_overlaps,
  446. set(iou_sampling_neg_inds), num_expected_iou_sampling,
  447. floor_thr, num_bins, bg_thresh)
  448. else:
  449. iou_sampled_inds = np.random.choice(
  450. iou_sampling_neg_inds,
  451. size=num_expected_iou_sampling,
  452. replace=False)
  453. else:
  454. iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int)
  455. num_expected_floor = num_expected - len(iou_sampled_inds)
  456. if len(floor_neg_inds) > num_expected_floor:
  457. sampled_floor_inds = np.random.choice(
  458. floor_neg_inds, size=num_expected_floor, replace=False)
  459. else:
  460. sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
  461. sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
  462. if len(sampled_inds) < num_expected:
  463. num_extra = num_expected - len(sampled_inds)
  464. extra_inds = np.array(list(neg_set - set(sampled_inds)))
  465. if len(extra_inds) > num_extra:
  466. extra_inds = np.random.choice(
  467. extra_inds, size=num_extra, replace=False)
  468. sampled_inds = np.concatenate((sampled_inds, extra_inds))
  469. return paddle.to_tensor(sampled_inds)
  470. def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
  471. negative_overlap, num_classes):
  472. # TODO: use paddle API to speed up
  473. gt_classes = gt_classes.numpy()
  474. gt_overlaps = np.zeros((anchors.shape[0], num_classes))
  475. matches = np.zeros((anchors.shape[0]), dtype=np.int32)
  476. if len(gt_boxes) > 0:
  477. proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
  478. overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
  479. overlaps_max = proposal_to_gt_overlaps.max(axis=1)
  480. # Boxes which with non-zero overlap with gt boxes
  481. overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
  482. overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
  483. overlapped_boxes_ind]]
  484. for idx in range(len(overlapped_boxes_ind)):
  485. gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
  486. idx]] = overlaps_max[overlapped_boxes_ind[idx]]
  487. matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
  488. overlapped_boxes_ind[idx]]
  489. gt_overlaps = paddle.to_tensor(gt_overlaps)
  490. matches = paddle.to_tensor(matches)
  491. matched_vals = paddle.max(gt_overlaps, axis=1)
  492. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  493. match_labels = paddle.where(matched_vals < negative_overlap,
  494. paddle.zeros_like(match_labels), match_labels)
  495. match_labels = paddle.where(matched_vals >= positive_overlap,
  496. paddle.ones_like(match_labels), match_labels)
  497. return matches, match_labels, matched_vals
  498. def libra_sample_bbox(matches,
  499. match_labels,
  500. matched_vals,
  501. gt_classes,
  502. batch_size_per_im,
  503. num_classes,
  504. fg_fraction,
  505. fg_thresh,
  506. bg_thresh,
  507. num_bins,
  508. use_random=True,
  509. is_cascade_rcnn=False):
  510. rois_per_image = int(batch_size_per_im)
  511. fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
  512. bg_rois_per_im = rois_per_image - fg_rois_per_im
  513. if is_cascade_rcnn:
  514. fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
  515. bg_inds = paddle.nonzero(matched_vals < bg_thresh)
  516. else:
  517. matched_vals_np = matched_vals.numpy()
  518. match_labels_np = match_labels.numpy()
  519. # sample fg
  520. fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
  521. fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
  522. if (fg_inds.shape[0] > fg_nums) and use_random:
  523. fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
  524. fg_inds.numpy(), fg_rois_per_im)
  525. fg_inds = fg_inds[:fg_nums]
  526. # sample bg
  527. bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
  528. bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
  529. if (bg_inds.shape[0] > bg_nums) and use_random:
  530. bg_inds = libra_sample_neg(
  531. matched_vals_np,
  532. match_labels_np,
  533. bg_inds.numpy(),
  534. bg_rois_per_im,
  535. num_bins=num_bins,
  536. bg_thresh=bg_thresh)
  537. bg_inds = bg_inds[:bg_nums]
  538. sampled_inds = paddle.concat([fg_inds, bg_inds])
  539. gt_classes = paddle.gather(gt_classes, matches)
  540. gt_classes = paddle.where(match_labels == 0,
  541. paddle.ones_like(gt_classes) * num_classes,
  542. gt_classes)
  543. gt_classes = paddle.where(match_labels == -1,
  544. paddle.ones_like(gt_classes) * -1,
  545. gt_classes)
  546. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  547. return sampled_inds, sampled_gt_classes
  548. def libra_generate_proposal_target(rpn_rois,
  549. gt_classes,
  550. gt_boxes,
  551. batch_size_per_im,
  552. fg_fraction,
  553. fg_thresh,
  554. bg_thresh,
  555. num_classes,
  556. use_random=True,
  557. is_cascade_rcnn=False,
  558. max_overlaps=None,
  559. num_bins=3):
  560. rois_with_gt = []
  561. tgt_labels = []
  562. tgt_bboxes = []
  563. sampled_max_overlaps = []
  564. tgt_gt_inds = []
  565. new_rois_num = []
  566. for i, rpn_roi in enumerate(rpn_rois):
  567. max_overlap = max_overlaps[i] if is_cascade_rcnn else None
  568. gt_bbox = gt_boxes[i]
  569. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  570. if is_cascade_rcnn:
  571. rpn_roi = filter_roi(rpn_roi, max_overlap)
  572. bbox = paddle.concat([rpn_roi, gt_bbox])
  573. # Step1: label bbox
  574. matches, match_labels, matched_vals = libra_label_box(
  575. bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)
  576. # Step2: sample bbox
  577. sampled_inds, sampled_gt_classes = libra_sample_bbox(
  578. matches, match_labels, matched_vals, gt_class, batch_size_per_im,
  579. num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
  580. use_random, is_cascade_rcnn)
  581. # Step3: make output
  582. rois_per_image = paddle.gather(bbox, sampled_inds)
  583. sampled_gt_ind = paddle.gather(matches, sampled_inds)
  584. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  585. sampled_overlap = paddle.gather(matched_vals, sampled_inds)
  586. rois_per_image.stop_gradient = True
  587. sampled_gt_ind.stop_gradient = True
  588. sampled_bbox.stop_gradient = True
  589. sampled_overlap.stop_gradient = True
  590. tgt_labels.append(sampled_gt_classes)
  591. tgt_bboxes.append(sampled_bbox)
  592. rois_with_gt.append(rois_per_image)
  593. sampled_max_overlaps.append(sampled_overlap)
  594. tgt_gt_inds.append(sampled_gt_ind)
  595. new_rois_num.append(paddle.shape(sampled_inds)[0])
  596. new_rois_num = paddle.concat(new_rois_num)
  597. # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  598. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num