post_process.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.ppdet.core.workspace import register
  19. from paddlex.ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly
  20. from paddlex.ppdet.modeling.layers import TTFBox
  21. from .transformers import bbox_cxcywh_to_xyxy
  22. try:
  23. from collections.abc import Sequence
  24. except Exception:
  25. from collections import Sequence
  26. __all__ = [
  27. 'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess',
  28. 'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', 'CenterNetPostProcess',
  29. 'DETRBBoxPostProcess', 'SparsePostProcess'
  30. ]
  31. @register
  32. class BBoxPostProcess(object):
  33. __shared__ = ['num_classes', 'export_onnx']
  34. __inject__ = ['decode', 'nms']
  35. def __init__(self,
  36. num_classes=80,
  37. decode=None,
  38. nms=None,
  39. export_onnx=False):
  40. super(BBoxPostProcess, self).__init__()
  41. self.num_classes = num_classes
  42. self.decode = decode
  43. self.nms = nms
  44. self.export_onnx = export_onnx
  45. def __call__(self, head_out, rois, im_shape, scale_factor):
  46. """
  47. Decode the bbox and do NMS if needed.
  48. Args:
  49. head_out (tuple): bbox_pred and cls_prob of bbox_head output.
  50. rois (tuple): roi and rois_num of rpn_head output.
  51. im_shape (Tensor): The shape of the input image.
  52. scale_factor (Tensor): The scale factor of the input image.
  53. export_onnx (bool): whether export model to onnx
  54. Returns:
  55. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  56. labels, scores and bboxes. The size of bboxes are corresponding
  57. to the input image, the bboxes may be used in other branch.
  58. bbox_num (Tensor): The number of prediction boxes of each batch with
  59. shape [1], and is N.
  60. """
  61. if self.nms is not None:
  62. bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
  63. bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
  64. else:
  65. bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
  66. scale_factor)
  67. if self.export_onnx:
  68. # add fake box after postprocess when exporting onnx
  69. fake_bboxes = paddle.to_tensor(
  70. np.array(
  71. [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
  72. bbox_pred = paddle.concat([bbox_pred, fake_bboxes])
  73. bbox_num = bbox_num + 1
  74. return bbox_pred, bbox_num
  75. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  76. """
  77. Rescale, clip and filter the bbox from the output of NMS to
  78. get final prediction.
  79. Notes:
  80. Currently only support bs = 1.
  81. Args:
  82. bboxes (Tensor): The output bboxes with shape [N, 6] after decode
  83. and NMS, including labels, scores and bboxes.
  84. bbox_num (Tensor): The number of prediction boxes of each batch with
  85. shape [1], and is N.
  86. im_shape (Tensor): The shape of the input image.
  87. scale_factor (Tensor): The scale factor of the input image.
  88. Returns:
  89. pred_result (Tensor): The final prediction results with shape [N, 6]
  90. including labels, scores and bboxes.
  91. """
  92. if not self.export_onnx:
  93. bboxes_list = []
  94. bbox_num_list = []
  95. id_start = 0
  96. fake_bboxes = paddle.to_tensor(
  97. np.array(
  98. [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
  99. fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  100. # add fake bbox when output is empty for each batch
  101. for i in range(bbox_num.shape[0]):
  102. if bbox_num[i] == 0:
  103. bboxes_i = fake_bboxes
  104. bbox_num_i = fake_bbox_num
  105. else:
  106. bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
  107. bbox_num_i = bbox_num[i]
  108. id_start += bbox_num[i]
  109. bboxes_list.append(bboxes_i)
  110. bbox_num_list.append(bbox_num_i)
  111. bboxes = paddle.concat(bboxes_list)
  112. bbox_num = paddle.concat(bbox_num_list)
  113. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  114. if not self.export_onnx:
  115. origin_shape_list = []
  116. scale_factor_list = []
  117. # scale_factor: scale_y, scale_x
  118. for i in range(bbox_num.shape[0]):
  119. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  120. [bbox_num[i], 2])
  121. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  122. scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
  123. expand_scale = paddle.expand(scale, [bbox_num[i], 4])
  124. origin_shape_list.append(expand_shape)
  125. scale_factor_list.append(expand_scale)
  126. self.origin_shape_list = paddle.concat(origin_shape_list)
  127. scale_factor_list = paddle.concat(scale_factor_list)
  128. else:
  129. # simplify the computation for bs=1 when exporting onnx
  130. scale_y, scale_x = scale_factor[0][0], scale_factor[0][1]
  131. scale = paddle.concat(
  132. [scale_x, scale_y, scale_x, scale_y]).unsqueeze(0)
  133. self.origin_shape_list = paddle.expand(origin_shape,
  134. [bbox_num[0], 2])
  135. scale_factor_list = paddle.expand(scale, [bbox_num[0], 4])
  136. # bboxes: [N, 6], label, score, bbox
  137. pred_label = bboxes[:, 0:1]
  138. pred_score = bboxes[:, 1:2]
  139. pred_bbox = bboxes[:, 2:]
  140. # rescale bbox to original image
  141. scaled_bbox = pred_bbox / scale_factor_list
  142. origin_h = self.origin_shape_list[:, 0]
  143. origin_w = self.origin_shape_list[:, 1]
  144. zeros = paddle.zeros_like(origin_h)
  145. # clip bbox to [0, original_size]
  146. x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
  147. y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
  148. x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
  149. y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
  150. pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
  151. # filter empty bbox
  152. keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
  153. keep_mask = paddle.unsqueeze(keep_mask, [1])
  154. pred_label = paddle.where(keep_mask, pred_label,
  155. paddle.ones_like(pred_label) * -1)
  156. pred_result = paddle.concat(
  157. [pred_label, pred_score, pred_bbox], axis=1)
  158. return bboxes, pred_result, bbox_num
  159. def get_origin_shape(self, ):
  160. return self.origin_shape_list
  161. @register
  162. class MaskPostProcess(object):
  163. __shared__ = ['export_onnx', 'assign_on_cpu']
  164. """
  165. refer to:
  166. https://github.com/facebookresearch/detectron2/layers/mask_ops.py
  167. Get Mask output according to the output from model
  168. """
  169. def __init__(self,
  170. binary_thresh=0.5,
  171. export_onnx=False,
  172. assign_on_cpu=False):
  173. super(MaskPostProcess, self).__init__()
  174. self.binary_thresh = binary_thresh
  175. self.export_onnx = export_onnx
  176. self.assign_on_cpu = assign_on_cpu
  177. def paste_mask(self, masks, boxes, im_h, im_w):
  178. """
  179. Paste the mask prediction to the original image.
  180. """
  181. x0_int, y0_int = 0, 0
  182. x1_int, y1_int = im_w, im_h
  183. x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
  184. N = masks.shape[0]
  185. img_y = paddle.arange(y0_int, y1_int) + 0.5
  186. img_x = paddle.arange(x0_int, x1_int) + 0.5
  187. img_y = (img_y - y0) / (y1 - y0) * 2 - 1
  188. img_x = (img_x - x0) / (x1 - x0) * 2 - 1
  189. # img_x, img_y have shapes (N, w), (N, h)
  190. if self.assign_on_cpu:
  191. paddle.set_device('cpu')
  192. gx = img_x[:, None, :].expand(
  193. [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
  194. gy = img_y[:, :, None].expand(
  195. [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
  196. grid = paddle.stack([gx, gy], axis=3)
  197. img_masks = F.grid_sample(masks, grid, align_corners=False)
  198. return img_masks[:, 0]
  199. def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
  200. """
  201. Decode the mask_out and paste the mask to the origin image.
  202. Args:
  203. mask_out (Tensor): mask_head output with shape [N, 28, 28].
  204. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  205. and NMS, including labels, scores and bboxes.
  206. bbox_num (Tensor): The number of prediction boxes of each batch with
  207. shape [1], and is N.
  208. origin_shape (Tensor): The origin shape of the input image, the tensor
  209. shape is [N, 2], and each row is [h, w].
  210. Returns:
  211. pred_result (Tensor): The final prediction mask results with shape
  212. [N, h, w] in binary mask style.
  213. """
  214. num_mask = mask_out.shape[0]
  215. origin_shape = paddle.cast(origin_shape, 'int32')
  216. device = paddle.device.get_device()
  217. if self.export_onnx:
  218. h, w = origin_shape[0][0], origin_shape[0][1]
  219. mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
  220. h, w)
  221. mask_onnx = mask_onnx >= self.binary_thresh
  222. pred_result = paddle.cast(mask_onnx, 'int32')
  223. else:
  224. max_h = paddle.max(origin_shape[:, 0])
  225. max_w = paddle.max(origin_shape[:, 1])
  226. pred_result = paddle.zeros(
  227. [num_mask, max_h, max_w], dtype='int32') - 1
  228. id_start = 0
  229. for i in range(paddle.shape(bbox_num)[0]):
  230. bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
  231. mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :]
  232. im_h = origin_shape[i, 0]
  233. im_w = origin_shape[i, 1]
  234. bbox_num_i = bbox_num[id_start]
  235. pred_mask = self.paste_mask(mask_out_i[:, None, :, :],
  236. bboxes_i[:, 2:], im_h, im_w)
  237. pred_mask = paddle.cast(pred_mask >= self.binary_thresh,
  238. 'int32')
  239. pred_result[id_start:id_start + bbox_num[i], :im_h, :
  240. im_w] = pred_mask
  241. id_start += bbox_num[i]
  242. if self.assign_on_cpu:
  243. paddle.set_device(device)
  244. return pred_result
  245. @register
  246. class FCOSPostProcess(object):
  247. __inject__ = ['decode', 'nms']
  248. def __init__(self, decode=None, nms=None):
  249. super(FCOSPostProcess, self).__init__()
  250. self.decode = decode
  251. self.nms = nms
  252. def __call__(self, fcos_head_outs, scale_factor):
  253. """
  254. Decode the bbox and do NMS in FCOS.
  255. """
  256. locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
  257. bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
  258. centerness, scale_factor)
  259. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  260. return bbox_pred, bbox_num
  261. @register
  262. class S2ANetBBoxPostProcess(nn.Layer):
  263. __shared__ = ['num_classes']
  264. __inject__ = ['nms']
  265. def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0,
  266. nms=None):
  267. super(S2ANetBBoxPostProcess, self).__init__()
  268. self.num_classes = num_classes
  269. self.nms_pre = nms_pre
  270. self.min_bbox_size = min_bbox_size
  271. self.nms = nms
  272. self.origin_shape_list = []
  273. self.fake_pred_cls_score_bbox = paddle.to_tensor(
  274. np.array(
  275. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
  276. dtype='float32'))
  277. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  278. def forward(self, pred_scores, pred_bboxes):
  279. """
  280. pred_scores : [N, M] score
  281. pred_bboxes : [N, 5] xc, yc, w, h, a
  282. im_shape : [N, 2] im_shape
  283. scale_factor : [N, 2] scale_factor
  284. """
  285. pred_ploys0 = rbox2poly(pred_bboxes)
  286. pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)
  287. # pred_scores [NA, 16] --> [16, NA]
  288. pred_scores0 = paddle.transpose(pred_scores, [1, 0])
  289. pred_scores = paddle.unsqueeze(pred_scores0, axis=0)
  290. pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
  291. self.num_classes)
  292. # Prevent empty bbox_pred from decode or NMS.
  293. # Bboxes and score before NMS may be empty due to the score threshold.
  294. if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
  295. 1] <= 1:
  296. pred_cls_score_bbox = self.fake_pred_cls_score_bbox
  297. bbox_num = self.fake_bbox_num
  298. pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
  299. return pred_cls_score_bbox, bbox_num
  300. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  301. """
  302. Rescale, clip and filter the bbox from the output of NMS to
  303. get final prediction.
  304. Args:
  305. bboxes(Tensor): bboxes [N, 10]
  306. bbox_num(Tensor): bbox_num
  307. im_shape(Tensor): [1 2]
  308. scale_factor(Tensor): [1 2]
  309. Returns:
  310. bbox_pred(Tensor): The output is the prediction with shape [N, 8]
  311. including labels, scores and bboxes. The size of
  312. bboxes are corresponding to the original image.
  313. """
  314. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  315. origin_shape_list = []
  316. scale_factor_list = []
  317. # scale_factor: scale_y, scale_x
  318. for i in range(bbox_num.shape[0]):
  319. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  320. [bbox_num[i], 2])
  321. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  322. scale = paddle.concat([
  323. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  324. scale_y
  325. ])
  326. expand_scale = paddle.expand(scale, [bbox_num[i], 8])
  327. origin_shape_list.append(expand_shape)
  328. scale_factor_list.append(expand_scale)
  329. origin_shape_list = paddle.concat(origin_shape_list)
  330. scale_factor_list = paddle.concat(scale_factor_list)
  331. # bboxes: [N, 10], label, score, bbox
  332. pred_label_score = bboxes[:, 0:2]
  333. pred_bbox = bboxes[:, 2:]
  334. # rescale bbox to original image
  335. pred_bbox = pred_bbox.reshape([-1, 8])
  336. scaled_bbox = pred_bbox / scale_factor_list
  337. origin_h = origin_shape_list[:, 0]
  338. origin_w = origin_shape_list[:, 1]
  339. bboxes = scaled_bbox
  340. zeros = paddle.zeros_like(origin_h)
  341. x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
  342. y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
  343. x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
  344. y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
  345. x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
  346. y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
  347. x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
  348. y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
  349. pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
  350. pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
  351. return pred_result
  352. @register
  353. class JDEBBoxPostProcess(nn.Layer):
  354. __shared__ = ['num_classes']
  355. __inject__ = ['decode', 'nms']
  356. def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True):
  357. super(JDEBBoxPostProcess, self).__init__()
  358. self.num_classes = num_classes
  359. self.decode = decode
  360. self.nms = nms
  361. self.return_idx = return_idx
  362. self.fake_bbox_pred = paddle.to_tensor(
  363. np.array(
  364. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  365. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  366. self.fake_nms_keep_idx = paddle.to_tensor(
  367. np.array(
  368. [[0]], dtype='int32'))
  369. self.fake_yolo_boxes_out = paddle.to_tensor(
  370. np.array(
  371. [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
  372. self.fake_yolo_scores_out = paddle.to_tensor(
  373. np.array(
  374. [[[0.0]]], dtype='float32'))
  375. self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))
  376. def forward(self, head_out, anchors):
  377. """
  378. Decode the bbox and do NMS for JDE model.
  379. Args:
  380. head_out (list): Bbox_pred and cls_prob of bbox_head output.
  381. anchors (list): Anchors of JDE model.
  382. Returns:
  383. boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
  384. bbox_pred (Tensor): The output is the prediction with shape [N, 6]
  385. including labels, scores and bboxes.
  386. bbox_num (Tensor): The number of prediction of each batch with shape [N].
  387. nms_keep_idx (Tensor): The index of kept bboxes after NMS.
  388. """
  389. boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors)
  390. if len(boxes_idx) == 0:
  391. boxes_idx = self.fake_boxes_idx
  392. yolo_boxes_out = self.fake_yolo_boxes_out
  393. yolo_scores_out = self.fake_yolo_scores_out
  394. else:
  395. yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx)
  396. # TODO: only support bs=1 now
  397. yolo_boxes_out = paddle.reshape(
  398. yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4])
  399. yolo_scores_out = paddle.reshape(
  400. yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
  401. boxes_idx = boxes_idx[:, 1:]
  402. if self.return_idx:
  403. bbox_pred, bbox_num, nms_keep_idx = self.nms(
  404. yolo_boxes_out, yolo_scores_out, self.num_classes)
  405. if bbox_pred.shape[0] == 0:
  406. bbox_pred = self.fake_bbox_pred
  407. bbox_num = self.fake_bbox_num
  408. nms_keep_idx = self.fake_nms_keep_idx
  409. return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
  410. else:
  411. bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out,
  412. self.num_classes)
  413. if bbox_pred.shape[0] == 0:
  414. bbox_pred = self.fake_bbox_pred
  415. bbox_num = self.fake_bbox_num
  416. return _, bbox_pred, bbox_num, _
  417. @register
  418. class CenterNetPostProcess(TTFBox):
  419. """
  420. Postprocess the model outputs to get final prediction:
  421. 1. Do NMS for heatmap to get top `max_per_img` bboxes.
  422. 2. Decode bboxes using center offset and box size.
  423. 3. Rescale decoded bboxes reference to the origin image shape.
  424. Args:
  425. max_per_img(int): the maximum number of predicted objects in a image,
  426. 500 by default.
  427. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  428. regress_ltrb (bool): whether to regress left/top/right/bottom or
  429. width/height for a box, true by default.
  430. for_mot (bool): whether return other features used in tracking model.
  431. """
  432. __shared__ = ['down_ratio', 'for_mot']
  433. def __init__(self,
  434. max_per_img=500,
  435. down_ratio=4,
  436. regress_ltrb=True,
  437. for_mot=False):
  438. super(TTFBox, self).__init__()
  439. self.max_per_img = max_per_img
  440. self.down_ratio = down_ratio
  441. self.regress_ltrb = regress_ltrb
  442. self.for_mot = for_mot
  443. def __call__(self, hm, wh, reg, im_shape, scale_factor):
  444. heat = self._simple_nms(hm)
  445. scores, inds, topk_clses, ys, xs = self._topk(heat)
  446. scores = scores.unsqueeze(1)
  447. clses = topk_clses.unsqueeze(1)
  448. reg_t = paddle.transpose(reg, [0, 2, 3, 1])
  449. # Like TTFBox, batch size is 1.
  450. # TODO: support batch size > 1
  451. reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
  452. reg = paddle.gather(reg, inds)
  453. xs = paddle.cast(xs, 'float32')
  454. ys = paddle.cast(ys, 'float32')
  455. xs = xs + reg[:, 0:1]
  456. ys = ys + reg[:, 1:2]
  457. wh_t = paddle.transpose(wh, [0, 2, 3, 1])
  458. wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
  459. wh = paddle.gather(wh, inds)
  460. if self.regress_ltrb:
  461. x1 = xs - wh[:, 0:1]
  462. y1 = ys - wh[:, 1:2]
  463. x2 = xs + wh[:, 2:3]
  464. y2 = ys + wh[:, 3:4]
  465. else:
  466. x1 = xs - wh[:, 0:1] / 2
  467. y1 = ys - wh[:, 1:2] / 2
  468. x2 = xs + wh[:, 0:1] / 2
  469. y2 = ys + wh[:, 1:2] / 2
  470. n, c, feat_h, feat_w = paddle.shape(hm)
  471. padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
  472. padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
  473. x1 = x1 * self.down_ratio
  474. y1 = y1 * self.down_ratio
  475. x2 = x2 * self.down_ratio
  476. y2 = y2 * self.down_ratio
  477. x1 = x1 - padw
  478. y1 = y1 - padh
  479. x2 = x2 - padw
  480. y2 = y2 - padh
  481. bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
  482. scale_y = scale_factor[:, 0:1]
  483. scale_x = scale_factor[:, 1:2]
  484. scale_expand = paddle.concat(
  485. [scale_x, scale_y, scale_x, scale_y], axis=1)
  486. boxes_shape = bboxes.shape[:]
  487. scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
  488. bboxes = paddle.divide(bboxes, scale_expand)
  489. results = paddle.concat([clses, scores, bboxes], axis=1)
  490. if self.for_mot:
  491. return results, inds, topk_clses
  492. else:
  493. return results, paddle.shape(results)[0:1], topk_clses
  494. @register
  495. class DETRBBoxPostProcess(object):
  496. __shared__ = ['num_classes', 'use_focal_loss']
  497. __inject__ = []
  498. def __init__(self,
  499. num_classes=80,
  500. num_top_queries=100,
  501. use_focal_loss=False):
  502. super(DETRBBoxPostProcess, self).__init__()
  503. self.num_classes = num_classes
  504. self.num_top_queries = num_top_queries
  505. self.use_focal_loss = use_focal_loss
  506. def __call__(self, head_out, im_shape, scale_factor):
  507. """
  508. Decode the bbox.
  509. Args:
  510. head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
  511. im_shape (Tensor): The shape of the input image.
  512. scale_factor (Tensor): The scale factor of the input image.
  513. Returns:
  514. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  515. labels, scores and bboxes. The size of bboxes are corresponding
  516. to the input image, the bboxes may be used in other branch.
  517. bbox_num (Tensor): The number of prediction boxes of each batch with
  518. shape [bs], and is N.
  519. """
  520. bboxes, logits, masks = head_out
  521. bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
  522. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  523. img_h, img_w = origin_shape.unbind(1)
  524. origin_shape = paddle.stack(
  525. [img_w, img_h, img_w, img_h], axis=-1).unsqueeze(0)
  526. bbox_pred *= origin_shape
  527. scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
  528. logits)[:, :, :-1]
  529. if not self.use_focal_loss:
  530. scores, labels = scores.max(-1), scores.argmax(-1)
  531. if scores.shape[1] > self.num_top_queries:
  532. scores, index = paddle.topk(
  533. scores, self.num_top_queries, axis=-1)
  534. labels = paddle.stack(
  535. [paddle.gather(l, i) for l, i in zip(labels, index)])
  536. bbox_pred = paddle.stack(
  537. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  538. else:
  539. scores, index = paddle.topk(
  540. scores.reshape([logits.shape[0], -1]),
  541. self.num_top_queries,
  542. axis=-1)
  543. labels = index % logits.shape[2]
  544. index = index // logits.shape[2]
  545. bbox_pred = paddle.stack(
  546. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  547. bbox_pred = paddle.concat(
  548. [
  549. labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
  550. bbox_pred
  551. ],
  552. axis=-1)
  553. bbox_num = paddle.to_tensor(
  554. bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
  555. bbox_pred = bbox_pred.reshape([-1, 6])
  556. return bbox_pred, bbox_num
  557. @register
  558. class SparsePostProcess(object):
  559. __shared__ = ['num_classes']
  560. def __init__(self, num_proposals, num_classes=80):
  561. super(SparsePostProcess, self).__init__()
  562. self.num_classes = num_classes
  563. self.num_proposals = num_proposals
  564. def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
  565. """
  566. Arguments:
  567. box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
  568. The tensor predicts the classification probability for each proposal.
  569. box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
  570. The tensor predicts 4-vector (x,y,w,h) box
  571. regression values for every proposal
  572. scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img
  573. img_whwh (Tensor): tensors of shape [batch_size, 4]
  574. Returns:
  575. bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
  576. [label, confidence, xmin, ymin, xmax, ymax]
  577. bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image.
  578. """
  579. assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)
  580. img_wh = img_whwh[:, :2]
  581. scores = F.sigmoid(box_cls)
  582. labels = paddle.arange(0, self.num_classes). \
  583. unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)
  584. classes_all = []
  585. scores_all = []
  586. boxes_all = []
  587. for i, (scores_per_image,
  588. box_pred_per_image) in enumerate(zip(scores, box_pred)):
  589. scores_per_image, topk_indices = scores_per_image.flatten(
  590. 0, 1).topk(
  591. self.num_proposals, sorted=False)
  592. labels_per_image = paddle.gather(labels, topk_indices, axis=0)
  593. box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
  594. [1, self.num_classes, 1]).reshape([-1, 4])
  595. box_pred_per_image = paddle.gather(
  596. box_pred_per_image, topk_indices, axis=0)
  597. classes_all.append(labels_per_image)
  598. scores_all.append(scores_per_image)
  599. boxes_all.append(box_pred_per_image)
  600. bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
  601. boxes_final = []
  602. for i in range(len(scale_factor_wh)):
  603. classes = classes_all[i]
  604. boxes = boxes_all[i]
  605. scores = scores_all[i]
  606. boxes[:, 0::2] = paddle.clip(
  607. boxes[:, 0::2], min=0,
  608. max=img_wh[i][0]) / scale_factor_wh[i][0]
  609. boxes[:, 1::2] = paddle.clip(
  610. boxes[:, 1::2], min=0,
  611. max=img_wh[i][1]) / scale_factor_wh[i][1]
  612. boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
  613. boxes[:, 3] - boxes[:, 1]).numpy()
  614. keep = (boxes_w > 1.) & (boxes_h > 1.)
  615. if (keep.sum() == 0):
  616. bboxes = paddle.zeros([1, 6]).astype("float32")
  617. else:
  618. boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
  619. classes = paddle.to_tensor(classes.numpy()[keep]).astype(
  620. "float32").unsqueeze(-1)
  621. scores = paddle.to_tensor(scores.numpy()[keep]).astype(
  622. "float32").unsqueeze(-1)
  623. bboxes = paddle.concat([classes, scores, boxes], axis=-1)
  624. boxes_final.append(bboxes)
  625. bbox_num[i] = bboxes.shape[0]
  626. bbox_pred = paddle.concat(boxes_final)
  627. return bbox_pred, bbox_num
  628. def nms(dets, thresh):
  629. """Apply classic DPM-style greedy NMS."""
  630. if dets.shape[0] == 0:
  631. return dets[[], :]
  632. scores = dets[:, 0]
  633. x1 = dets[:, 1]
  634. y1 = dets[:, 2]
  635. x2 = dets[:, 3]
  636. y2 = dets[:, 4]
  637. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  638. order = scores.argsort()[::-1]
  639. ndets = dets.shape[0]
  640. suppressed = np.zeros((ndets), dtype=np.int)
  641. # nominal indices
  642. # _i, _j
  643. # sorted indices
  644. # i, j
  645. # temp variables for box i's (the box currently under consideration)
  646. # ix1, iy1, ix2, iy2, iarea
  647. # variables for computing overlap with box j (lower scoring box)
  648. # xx1, yy1, xx2, yy2
  649. # w, h
  650. # inter, ovr
  651. for _i in range(ndets):
  652. i = order[_i]
  653. if suppressed[i] == 1:
  654. continue
  655. ix1 = x1[i]
  656. iy1 = y1[i]
  657. ix2 = x2[i]
  658. iy2 = y2[i]
  659. iarea = areas[i]
  660. for _j in range(_i + 1, ndets):
  661. j = order[_j]
  662. if suppressed[j] == 1:
  663. continue
  664. xx1 = max(ix1, x1[j])
  665. yy1 = max(iy1, y1[j])
  666. xx2 = min(ix2, x2[j])
  667. yy2 = min(iy2, y2[j])
  668. w = max(0.0, xx2 - xx1 + 1)
  669. h = max(0.0, yy2 - yy1 + 1)
  670. inter = w * h
  671. ovr = inter / (iarea + areas[j] - inter)
  672. if ovr >= thresh:
  673. suppressed[j] = 1
  674. keep = np.where(suppressed == 0)[0]
  675. dets = dets[keep, :]
  676. return dets