post_process.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.ppdet.core.workspace import register
  19. from paddlex.ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, rbox2poly
  20. from paddlex.ppdet.modeling.layers import TTFBox
  21. try:
  22. from collections.abc import Sequence
  23. except Exception:
  24. from collections import Sequence
# Public post-processing classes exported by this module.
__all__ = [
    'BBoxPostProcess',
    'MaskPostProcess',
    'FCOSPostProcess',
    'S2ANetBBoxPostProcess',
    'JDEBBoxPostProcess',
    'CenterNetPostProcess',
]
  33. @register
  34. class BBoxPostProcess(object):
  35. __shared__ = ['num_classes']
  36. __inject__ = ['decode', 'nms']
  37. def __init__(self, num_classes=80, decode=None, nms=None):
  38. super(BBoxPostProcess, self).__init__()
  39. self.num_classes = num_classes
  40. self.decode = decode
  41. self.nms = nms
  42. def __call__(self, head_out, rois, im_shape, scale_factor):
  43. """
  44. Decode the bbox and do NMS if needed.
  45. Args:
  46. head_out (tuple): bbox_pred and cls_prob of bbox_head output.
  47. rois (tuple): roi and rois_num of rpn_head output.
  48. im_shape (Tensor): The shape of the input image.
  49. scale_factor (Tensor): The scale factor of the input image.
  50. Returns:
  51. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  52. labels, scores and bboxes. The size of bboxes are corresponding
  53. to the input image, the bboxes may be used in other branch.
  54. bbox_num (Tensor): The number of prediction boxes of each batch with
  55. shape [1], and is N.
  56. """
  57. if self.nms is not None:
  58. bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
  59. bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
  60. else:
  61. bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
  62. scale_factor)
  63. return bbox_pred, bbox_num
  64. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  65. """
  66. Rescale, clip and filter the bbox from the output of NMS to
  67. get final prediction.
  68. Notes:
  69. Currently only support bs = 1.
  70. Args:
  71. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  72. and NMS, including labels, scores and bboxes.
  73. bbox_num (Tensor): The number of prediction boxes of each batch with
  74. shape [1], and is N.
  75. im_shape (Tensor): The shape of the input image.
  76. scale_factor (Tensor): The scale factor of the input image.
  77. Returns:
  78. pred_result (Tensor): The final prediction results with shape [N, 6]
  79. including labels, scores and bboxes.
  80. """
  81. if bboxes.shape[0] == 0:
  82. bboxes = paddle.to_tensor(
  83. np.array(
  84. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  85. bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  86. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  87. origin_shape_list = []
  88. scale_factor_list = []
  89. # scale_factor: scale_y, scale_x
  90. for i in range(bbox_num.shape[0]):
  91. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  92. [bbox_num[i], 2])
  93. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  94. scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
  95. expand_scale = paddle.expand(scale, [bbox_num[i], 4])
  96. origin_shape_list.append(expand_shape)
  97. scale_factor_list.append(expand_scale)
  98. self.origin_shape_list = paddle.concat(origin_shape_list)
  99. scale_factor_list = paddle.concat(scale_factor_list)
  100. # bboxes: [N, 6], label, score, bbox
  101. pred_label = bboxes[:, 0:1]
  102. pred_score = bboxes[:, 1:2]
  103. pred_bbox = bboxes[:, 2:]
  104. # rescale bbox to original image
  105. scaled_bbox = pred_bbox / scale_factor_list
  106. origin_h = self.origin_shape_list[:, 0]
  107. origin_w = self.origin_shape_list[:, 1]
  108. zeros = paddle.zeros_like(origin_h)
  109. # clip bbox to [0, original_size]
  110. x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
  111. y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
  112. x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
  113. y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
  114. pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
  115. # filter empty bbox
  116. keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
  117. keep_mask = paddle.unsqueeze(keep_mask, [1])
  118. pred_label = paddle.where(keep_mask, pred_label,
  119. paddle.ones_like(pred_label) * -1)
  120. pred_result = paddle.concat(
  121. [pred_label, pred_score, pred_bbox], axis=1)
  122. return pred_result
  123. def get_origin_shape(self, ):
  124. return self.origin_shape_list
  125. @register
  126. class MaskPostProcess(object):
  127. def __init__(self, binary_thresh=0.5):
  128. super(MaskPostProcess, self).__init__()
  129. self.binary_thresh = binary_thresh
  130. def paste_mask(self, masks, boxes, im_h, im_w):
  131. """
  132. Paste the mask prediction to the original image.
  133. """
  134. x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
  135. masks = paddle.unsqueeze(masks, [0, 1])
  136. img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
  137. img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
  138. img_y = (img_y - y0) / (y1 - y0) * 2 - 1
  139. img_x = (img_x - x0) / (x1 - x0) * 2 - 1
  140. img_x = paddle.unsqueeze(img_x, [1])
  141. img_y = paddle.unsqueeze(img_y, [2])
  142. N = boxes.shape[0]
  143. gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
  144. gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
  145. grid = paddle.stack([gx, gy], axis=3)
  146. img_masks = F.grid_sample(masks, grid, align_corners=False)
  147. return img_masks[:, 0]
  148. def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
  149. """
  150. Decode the mask_out and paste the mask to the origin image.
  151. Args:
  152. mask_out (Tensor): mask_head output with shape [N, 28, 28].
  153. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  154. and NMS, including labels, scores and bboxes.
  155. bbox_num (Tensor): The number of prediction boxes of each batch with
  156. shape [1], and is N.
  157. origin_shape (Tensor): The origin shape of the input image, the tensor
  158. shape is [N, 2], and each row is [h, w].
  159. Returns:
  160. pred_result (Tensor): The final prediction mask results with shape
  161. [N, h, w] in binary mask style.
  162. """
  163. num_mask = mask_out.shape[0]
  164. origin_shape = paddle.cast(origin_shape, 'int32')
  165. # TODO: support bs > 1 and mask output dtype is bool
  166. pred_result = paddle.zeros(
  167. [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
  168. if bbox_num == 1 and bboxes[0][0] == -1:
  169. return pred_result
  170. # TODO: optimize chunk paste
  171. pred_result = []
  172. for i in range(bboxes.shape[0]):
  173. im_h, im_w = origin_shape[i][0], origin_shape[i][1]
  174. pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
  175. im_w)
  176. pred_mask = pred_mask >= self.binary_thresh
  177. pred_mask = paddle.cast(pred_mask, 'int32')
  178. pred_result.append(pred_mask)
  179. pred_result = paddle.concat(pred_result)
  180. return pred_result
  181. @register
  182. class FCOSPostProcess(object):
  183. __inject__ = ['decode', 'nms']
  184. def __init__(self, decode=None, nms=None):
  185. super(FCOSPostProcess, self).__init__()
  186. self.decode = decode
  187. self.nms = nms
  188. def __call__(self, fcos_head_outs, scale_factor):
  189. """
  190. Decode the bbox and do NMS in FCOS.
  191. """
  192. locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
  193. bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
  194. centerness, scale_factor)
  195. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  196. return bbox_pred, bbox_num
  197. @register
  198. class S2ANetBBoxPostProcess(nn.Layer):
  199. __shared__ = ['num_classes']
  200. __inject__ = ['nms']
  201. def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0,
  202. nms=None):
  203. super(S2ANetBBoxPostProcess, self).__init__()
  204. self.num_classes = num_classes
  205. self.nms_pre = nms_pre
  206. self.min_bbox_size = min_bbox_size
  207. self.nms = nms
  208. self.origin_shape_list = []
  209. self.fake_pred_cls_score_bbox = paddle.to_tensor(
  210. np.array(
  211. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
  212. dtype='float32'))
  213. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  214. def forward(self, pred_scores, pred_bboxes):
  215. """
  216. pred_scores : [N, M] score
  217. pred_bboxes : [N, 5] xc, yc, w, h, a
  218. im_shape : [N, 2] im_shape
  219. scale_factor : [N, 2] scale_factor
  220. """
  221. pred_ploys0 = rbox2poly(pred_bboxes)
  222. pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)
  223. # pred_scores [NA, 16] --> [16, NA]
  224. pred_scores0 = paddle.transpose(pred_scores, [1, 0])
  225. pred_scores = paddle.unsqueeze(pred_scores0, axis=0)
  226. pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
  227. self.num_classes)
  228. # Prevent empty bbox_pred from decode or NMS.
  229. # Bboxes and score before NMS may be empty due to the score threshold.
  230. if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
  231. 1] <= 1:
  232. pred_cls_score_bbox = self.fake_pred_cls_score_bbox
  233. bbox_num = self.fake_bbox_num
  234. pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
  235. return pred_cls_score_bbox, bbox_num
  236. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  237. """
  238. Rescale, clip and filter the bbox from the output of NMS to
  239. get final prediction.
  240. Args:
  241. bboxes(Tensor): bboxes [N, 10]
  242. bbox_num(Tensor): bbox_num
  243. im_shape(Tensor): [1 2]
  244. scale_factor(Tensor): [1 2]
  245. Returns:
  246. bbox_pred(Tensor): The output is the prediction with shape [N, 8]
  247. including labels, scores and bboxes. The size of
  248. bboxes are corresponding to the original image.
  249. """
  250. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  251. origin_shape_list = []
  252. scale_factor_list = []
  253. # scale_factor: scale_y, scale_x
  254. for i in range(bbox_num.shape[0]):
  255. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  256. [bbox_num[i], 2])
  257. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  258. scale = paddle.concat([
  259. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  260. scale_y
  261. ])
  262. expand_scale = paddle.expand(scale, [bbox_num[i], 8])
  263. origin_shape_list.append(expand_shape)
  264. scale_factor_list.append(expand_scale)
  265. origin_shape_list = paddle.concat(origin_shape_list)
  266. scale_factor_list = paddle.concat(scale_factor_list)
  267. # bboxes: [N, 10], label, score, bbox
  268. pred_label_score = bboxes[:, 0:2]
  269. pred_bbox = bboxes[:, 2:]
  270. # rescale bbox to original image
  271. pred_bbox = pred_bbox.reshape([-1, 8])
  272. scaled_bbox = pred_bbox / scale_factor_list
  273. origin_h = origin_shape_list[:, 0]
  274. origin_w = origin_shape_list[:, 1]
  275. bboxes = scaled_bbox
  276. zeros = paddle.zeros_like(origin_h)
  277. x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
  278. y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
  279. x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
  280. y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
  281. x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
  282. y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
  283. x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
  284. y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
  285. pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
  286. pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
  287. return pred_result
  288. @register
  289. class JDEBBoxPostProcess(BBoxPostProcess):
  290. def __call__(self, head_out, anchors):
  291. """
  292. Decode the bbox and do NMS for JDE model.
  293. Args:
  294. head_out (list): Bbox_pred and cls_prob of bbox_head output.
  295. anchors (list): Anchors of JDE model.
  296. Returns:
  297. boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
  298. bbox_pred (Tensor): The output is the prediction with shape [N, 6]
  299. including labels, scores and bboxes.
  300. bbox_num (Tensor): The number of prediction of each batch with shape [N].
  301. nms_keep_idx (Tensor): The index of kept bboxes after NMS.
  302. """
  303. boxes_idx, bboxes, score = self.decode(head_out, anchors)
  304. bbox_pred, bbox_num, nms_keep_idx = self.nms(bboxes, score,
  305. self.num_classes)
  306. if bbox_pred.shape[0] == 0:
  307. bbox_pred = paddle.to_tensor(
  308. np.array(
  309. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  310. bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  311. nms_keep_idx = paddle.to_tensor(np.array([[0]], dtype='int32'))
  312. return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
@register
class CenterNetPostProcess(TTFBox):
    """
    Postprocess the model outputs to get final prediction:
        1. Do NMS for heatmap to get top `max_per_img` bboxes.
        2. Decode bboxes using center offset and box size.
        3. Rescale decoded bboxes reference to the origin image shape.

    Args:
        max_per_img(int): the maximum number of predicted objects in a image,
            500 by default.
        down_ratio(int): the down ratio from images to heatmap, 4 by default.
        regress_ltrb (bool): whether to regress left/top/right/bottom or
            width/height for a box, true by default.
        for_mot (bool): whether return other features used in tracking model.
    """
    __shared__ = ['down_ratio']

    def __init__(self,
                 max_per_img=500,
                 down_ratio=4,
                 regress_ltrb=True,
                 for_mot=False):
        # super(TTFBox, ...) starts the MRO lookup *after* TTFBox, so
        # TTFBox.__init__ itself is not run; the attributes this class needs
        # are assigned below instead.
        # NOTE(review): confirm TTFBox.__init__ sets nothing else that the
        # inherited _simple_nms / _topk helpers rely on.
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.down_ratio = down_ratio
        self.regress_ltrb = regress_ltrb
        self.for_mot = for_mot

    def __call__(self, hm, wh, reg, im_shape, scale_factor):
        """Decode CenterNet head outputs into boxes on the origin image.

        Args:
            hm (Tensor): 4-D heatmap; its shape is unpacked as
                (n, c, feat_h, feat_w).
            wh (Tensor): 4-D box-size map; transposed channels-last before
                gathering. 4 channels are read when regress_ltrb is True,
                otherwise 2.
            reg (Tensor): 4-D center-offset map; channels 0 and 1 are added
                to x and y respectively.
            im_shape (Tensor): only row 0 is read; [0, 0] is used as the
                height and [0, 1] as the width when computing padding.
            scale_factor (Tensor): column 0 is used as scale_y, column 1 as
                scale_x.

        Returns:
            If for_mot: (results, inds) with results = [bboxes, scores, clses]
            concatenated along axis 1.
            Otherwise: (results, num) with results = [clses, scores, bboxes]
            concatenated along axis 1 and num = paddle.shape(results)[0:1].
        """
        # Peak suppression on the heatmap, then top-k selection; both helpers
        # come from TTFBox and are not visible here.
        heat = self._simple_nms(hm)
        scores, inds, clses, ys, xs = self._topk(heat)
        scores = paddle.tensor.unsqueeze(scores, [1])
        clses = paddle.tensor.unsqueeze(clses, [1])

        reg_t = paddle.transpose(reg, [0, 2, 3, 1])
        # Like TTFBox, batch size is 1.
        # TODO: support batch size > 1
        # Flatten to [num_positions, channels] and pick the top-k rows.
        reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]])
        reg = paddle.gather(reg, inds)
        xs = paddle.cast(xs, 'float32')
        ys = paddle.cast(ys, 'float32')
        # Refine the integer peak locations with the predicted offsets.
        xs = xs + reg[:, 0:1]
        ys = ys + reg[:, 1:2]

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
        wh = paddle.gather(wh, inds)

        if self.regress_ltrb:
            # wh holds distances to left / top / right / bottom edges.
            x1 = xs - wh[:, 0:1]
            y1 = ys - wh[:, 1:2]
            x2 = xs + wh[:, 2:3]
            y2 = ys + wh[:, 3:4]
        else:
            # wh holds full width / height centred on (xs, ys).
            x1 = xs - wh[:, 0:1] / 2
            y1 = ys - wh[:, 1:2] / 2
            x2 = xs + wh[:, 0:1] / 2
            y2 = ys + wh[:, 1:2] / 2

        # Half of the difference between the up-sampled feature map size and
        # the input image size, subtracted from each coordinate after
        # scaling back to input resolution.
        n, c, feat_h, feat_w = paddle.shape(hm)
        padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
        padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
        x1 = x1 * self.down_ratio
        y1 = y1 * self.down_ratio
        x2 = x2 * self.down_ratio
        y2 = y2 * self.down_ratio

        x1 = x1 - padw
        y1 = y1 - padh
        x2 = x2 - padw
        y2 = y2 - padh

        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
        # Divide by (scale_x, scale_y, scale_x, scale_y) to map boxes to the
        # origin image.
        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = paddle.shape(bboxes)
        boxes_shape.stop_gradient = True
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        if self.for_mot:
            results = paddle.concat([bboxes, scores, clses], axis=1)
            return results, inds
        else:
            results = paddle.concat([clses, scores, bboxes], axis=1)
            return results, paddle.shape(results)[0:1]