# bbox_utils.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import math
  15. import paddle
  16. import paddle.nn.functional as F
  17. import math
  18. import numpy as np
  19. def bbox2delta(src_boxes, tgt_boxes, weights):
  20. src_w = src_boxes[:, 2] - src_boxes[:, 0]
  21. src_h = src_boxes[:, 3] - src_boxes[:, 1]
  22. src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
  23. src_ctr_y = src_boxes[:, 1] + 0.5 * src_h
  24. tgt_w = tgt_boxes[:, 2] - tgt_boxes[:, 0]
  25. tgt_h = tgt_boxes[:, 3] - tgt_boxes[:, 1]
  26. tgt_ctr_x = tgt_boxes[:, 0] + 0.5 * tgt_w
  27. tgt_ctr_y = tgt_boxes[:, 1] + 0.5 * tgt_h
  28. wx, wy, ww, wh = weights
  29. dx = wx * (tgt_ctr_x - src_ctr_x) / src_w
  30. dy = wy * (tgt_ctr_y - src_ctr_y) / src_h
  31. dw = ww * paddle.log(tgt_w / src_w)
  32. dh = wh * paddle.log(tgt_h / src_h)
  33. deltas = paddle.stack((dx, dy, dw, dh), axis=1)
  34. return deltas
  35. def delta2bbox(deltas, boxes, weights):
  36. clip_scale = math.log(1000.0 / 16)
  37. widths = boxes[:, 2] - boxes[:, 0]
  38. heights = boxes[:, 3] - boxes[:, 1]
  39. ctr_x = boxes[:, 0] + 0.5 * widths
  40. ctr_y = boxes[:, 1] + 0.5 * heights
  41. wx, wy, ww, wh = weights
  42. dx = deltas[:, 0::4] / wx
  43. dy = deltas[:, 1::4] / wy
  44. dw = deltas[:, 2::4] / ww
  45. dh = deltas[:, 3::4] / wh
  46. # Prevent sending too large values into paddle.exp()
  47. dw = paddle.clip(dw, max=clip_scale)
  48. dh = paddle.clip(dh, max=clip_scale)
  49. pred_ctr_x = dx * widths.unsqueeze(1) + ctr_x.unsqueeze(1)
  50. pred_ctr_y = dy * heights.unsqueeze(1) + ctr_y.unsqueeze(1)
  51. pred_w = paddle.exp(dw) * widths.unsqueeze(1)
  52. pred_h = paddle.exp(dh) * heights.unsqueeze(1)
  53. pred_boxes = []
  54. pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
  55. pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
  56. pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
  57. pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
  58. pred_boxes = paddle.stack(pred_boxes, axis=-1)
  59. return pred_boxes
  60. def expand_bbox(bboxes, scale):
  61. w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
  62. h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
  63. x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
  64. y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
  65. w_half *= scale
  66. h_half *= scale
  67. bboxes_exp = np.zeros(bboxes.shape, dtype=np.float32)
  68. bboxes_exp[:, 0] = x_c - w_half
  69. bboxes_exp[:, 2] = x_c + w_half
  70. bboxes_exp[:, 1] = y_c - h_half
  71. bboxes_exp[:, 3] = y_c + h_half
  72. return bboxes_exp
  73. def clip_bbox(boxes, im_shape):
  74. h, w = im_shape[0], im_shape[1]
  75. x1 = boxes[:, 0].clip(0, w)
  76. y1 = boxes[:, 1].clip(0, h)
  77. x2 = boxes[:, 2].clip(0, w)
  78. y2 = boxes[:, 3].clip(0, h)
  79. return paddle.stack([x1, y1, x2, y2], axis=1)
  80. def nonempty_bbox(boxes, min_size=0, return_mask=False):
  81. w = boxes[:, 2] - boxes[:, 0]
  82. h = boxes[:, 3] - boxes[:, 1]
  83. mask = paddle.logical_and(w > min_size, w > min_size)
  84. if return_mask:
  85. return mask
  86. keep = paddle.nonzero(mask).flatten()
  87. return keep
  88. def bbox_area(boxes):
  89. return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
  90. def bbox_overlaps(boxes1, boxes2):
  91. """
  92. Calculate overlaps between boxes1 and boxes2
  93. Args:
  94. boxes1 (Tensor): boxes with shape [M, 4]
  95. boxes2 (Tensor): boxes with shape [N, 4]
  96. Return:
  97. overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
  98. """
  99. M = boxes1.shape[0]
  100. N = boxes2.shape[0]
  101. if M * N == 0:
  102. return paddle.zeros([M, N], dtype='float32')
  103. area1 = bbox_area(boxes1)
  104. area2 = bbox_area(boxes2)
  105. xy_max = paddle.minimum(
  106. paddle.unsqueeze(boxes1, 1)[:, :, 2:], boxes2[:, 2:])
  107. xy_min = paddle.maximum(
  108. paddle.unsqueeze(boxes1, 1)[:, :, :2], boxes2[:, :2])
  109. width_height = xy_max - xy_min
  110. width_height = width_height.clip(min=0)
  111. inter = width_height.prod(axis=2)
  112. overlaps = paddle.where(inter > 0, inter /
  113. (paddle.unsqueeze(area1, 1) + area2 - inter),
  114. paddle.zeros_like(inter))
  115. return overlaps
  116. def xywh2xyxy(box):
  117. x, y, w, h = box
  118. x1 = x - w * 0.5
  119. y1 = y - h * 0.5
  120. x2 = x + w * 0.5
  121. y2 = y + h * 0.5
  122. return [x1, y1, x2, y2]
  123. def make_grid(h, w, dtype):
  124. yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)])
  125. return paddle.stack((xv, yv), 2).cast(dtype=dtype)
  126. def decode_yolo(box, anchor, downsample_ratio):
  127. """decode yolo box
  128. Args:
  129. box (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  130. anchor (list): anchor with the shape [na, 2]
  131. downsample_ratio (int): downsample ratio, default 32
  132. scale (float): scale, default 1.
  133. Return:
  134. box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1]
  135. """
  136. x, y, w, h = box
  137. na, grid_h, grid_w = x.shape[1:4]
  138. grid = make_grid(grid_h, grid_w, x.dtype).reshape(
  139. (1, 1, grid_h, grid_w, 2))
  140. x1 = (x + grid[:, :, :, :, 0:1]) / grid_w
  141. y1 = (y + grid[:, :, :, :, 1:2]) / grid_h
  142. anchor = paddle.to_tensor(anchor)
  143. anchor = paddle.cast(anchor, x.dtype)
  144. anchor = anchor.reshape((1, na, 1, 1, 2))
  145. w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w)
  146. h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h)
  147. return [x1, y1, w1, h1]
  148. def iou_similarity(box1, box2, eps=1e-9):
  149. """Calculate iou of box1 and box2
  150. Args:
  151. box1 (Tensor): box with the shape [N, M1, 4]
  152. box2 (Tensor): box with the shape [N, M2, 4]
  153. Return:
  154. iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2]
  155. """
  156. box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
  157. box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
  158. px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
  159. gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
  160. x1y1 = paddle.maximum(px1y1, gx1y1)
  161. x2y2 = paddle.minimum(px2y2, gx2y2)
  162. overlap = (x2y2 - x1y1).clip(0).prod(-1)
  163. area1 = (px2y2 - px1y1).clip(0).prod(-1)
  164. area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
  165. union = area1 + area2 - overlap + eps
  166. return overlap / union
  167. def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
  168. """calculate the iou of box1 and box2
  169. Args:
  170. box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  171. box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  172. giou (bool): whether use giou or not, default False
  173. diou (bool): whether use diou or not, default False
  174. ciou (bool): whether use ciou or not, default False
  175. eps (float): epsilon to avoid divide by zero
  176. Return:
  177. iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
  178. """
  179. px1, py1, px2, py2 = box1
  180. gx1, gy1, gx2, gy2 = box2
  181. x1 = paddle.maximum(px1, gx1)
  182. y1 = paddle.maximum(py1, gy1)
  183. x2 = paddle.minimum(px2, gx2)
  184. y2 = paddle.minimum(py2, gy2)
  185. overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0))
  186. area1 = (px2 - px1) * (py2 - py1)
  187. area1 = area1.clip(0)
  188. area2 = (gx2 - gx1) * (gy2 - gy1)
  189. area2 = area2.clip(0)
  190. union = area1 + area2 - overlap + eps
  191. iou = overlap / union
  192. if giou or ciou or diou:
  193. # convex w, h
  194. cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1)
  195. ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1)
  196. if giou:
  197. c_area = cw * ch + eps
  198. return iou - (c_area - union) / c_area
  199. else:
  200. # convex diagonal squared
  201. c2 = cw**2 + ch**2 + eps
  202. # center distance
  203. rho2 = (
  204. (px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
  205. if diou:
  206. return iou - rho2 / c2
  207. else:
  208. w1, h1 = px2 - px1, py2 - py1 + eps
  209. w2, h2 = gx2 - gx1, gy2 - gy1 + eps
  210. delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2)
  211. v = (4 / math.pi**2) * paddle.pow(delta, 2)
  212. alpha = v / (1 + eps - iou + v)
  213. alpha.stop_gradient = True
  214. return iou - (rho2 / c2 + v * alpha)
  215. else:
  216. return iou
  217. def poly2rbox(polys):
  218. """
  219. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  220. to
  221. rotated_boxes:[x_ctr,y_ctr,w,h,angle]
  222. """
  223. rotated_boxes = []
  224. for poly in polys:
  225. poly = np.array(poly[:8], dtype=np.float32)
  226. pt1 = (poly[0], poly[1])
  227. pt2 = (poly[2], poly[3])
  228. pt3 = (poly[4], poly[5])
  229. pt4 = (poly[6], poly[7])
  230. edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[
  231. 1]) * (pt1[1] - pt2[1]))
  232. edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[1] - pt3[
  233. 1]) * (pt2[1] - pt3[1]))
  234. width = max(edge1, edge2)
  235. height = min(edge1, edge2)
  236. rbox_angle = 0
  237. if edge1 > edge2:
  238. rbox_angle = np.arctan2(
  239. np.float(pt2[1] - pt1[1]), np.float(pt2[0] - pt1[0]))
  240. elif edge2 >= edge1:
  241. rbox_angle = np.arctan2(
  242. np.float(pt4[1] - pt1[1]), np.float(pt4[0] - pt1[0]))
  243. def norm_angle(angle, range=[-np.pi / 4, np.pi]):
  244. return (angle - range[0]) % range[1] + range[0]
  245. rbox_angle = norm_angle(rbox_angle)
  246. x_ctr = np.float(pt1[0] + pt3[0]) / 2
  247. y_ctr = np.float(pt1[1] + pt3[1]) / 2
  248. rotated_box = np.array([x_ctr, y_ctr, width, height, rbox_angle])
  249. rotated_boxes.append(rotated_box)
  250. ret_rotated_boxes = np.array(rotated_boxes)
  251. assert ret_rotated_boxes.shape[1] == 5
  252. return ret_rotated_boxes
  253. def cal_line_length(point1, point2):
  254. import math
  255. return math.sqrt(
  256. math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1],
  257. 2))
  258. def get_best_begin_point_single(coordinate):
  259. x1, y1, x2, y2, x3, y3, x4, y4 = coordinate
  260. xmin = min(x1, x2, x3, x4)
  261. ymin = min(y1, y2, y3, y4)
  262. xmax = max(x1, x2, x3, x4)
  263. ymax = max(y1, y2, y3, y4)
  264. combinate = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
  265. [[x4, y4], [x1, y1], [x2, y2], [x3, y3]],
  266. [[x3, y3], [x4, y4], [x1, y1], [x2, y2]],
  267. [[x2, y2], [x3, y3], [x4, y4], [x1, y1]]]
  268. dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
  269. force = 100000000.0
  270. force_flag = 0
  271. for i in range(4):
  272. temp_force = cal_line_length(combinate[i][0], dst_coordinate[0]) \
  273. + cal_line_length(combinate[i][1], dst_coordinate[1]) \
  274. + cal_line_length(combinate[i][2], dst_coordinate[2]) \
  275. + cal_line_length(combinate[i][3], dst_coordinate[3])
  276. if temp_force < force:
  277. force = temp_force
  278. force_flag = i
  279. if force_flag != 0:
  280. pass
  281. return np.array(combinate[force_flag]).reshape(8)
  282. def rbox2poly_np(rrects):
  283. """
  284. rrect:[x_ctr,y_ctr,w,h,angle]
  285. to
  286. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  287. """
  288. polys = []
  289. for i in range(rrects.shape[0]):
  290. rrect = rrects[i]
  291. # x_ctr, y_ctr, width, height, angle = rrect[:5]
  292. x_ctr = rrect[0]
  293. y_ctr = rrect[1]
  294. width = rrect[2]
  295. height = rrect[3]
  296. angle = rrect[4]
  297. tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
  298. rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
  299. R = np.array([[np.cos(angle), -np.sin(angle)],
  300. [np.sin(angle), np.cos(angle)]])
  301. poly = R.dot(rect)
  302. x0, x1, x2, x3 = poly[0, :4] + x_ctr
  303. y0, y1, y2, y3 = poly[1, :4] + y_ctr
  304. poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32)
  305. poly = get_best_begin_point_single(poly)
  306. polys.append(poly)
  307. polys = np.array(polys)
  308. return polys
  309. def rbox2poly(rrects):
  310. """
  311. rrect:[x_ctr,y_ctr,w,h,angle]
  312. to
  313. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  314. """
  315. N = paddle.shape(rrects)[0]
  316. x_ctr = rrects[:, 0]
  317. y_ctr = rrects[:, 1]
  318. width = rrects[:, 2]
  319. height = rrects[:, 3]
  320. angle = rrects[:, 4]
  321. tl_x, tl_y, br_x, br_y = -width * 0.5, -height * 0.5, width * 0.5, height * 0.5
  322. normal_rects = paddle.stack(
  323. [tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y], axis=0)
  324. normal_rects = paddle.reshape(normal_rects, [2, 4, N])
  325. normal_rects = paddle.transpose(normal_rects, [2, 0, 1])
  326. sin, cos = paddle.sin(angle), paddle.cos(angle)
  327. # M.shape=[N,2,2]
  328. M = paddle.stack([cos, -sin, sin, cos], axis=0)
  329. M = paddle.reshape(M, [2, 2, N])
  330. M = paddle.transpose(M, [2, 0, 1])
  331. # polys:[N,8]
  332. polys = paddle.matmul(M, normal_rects)
  333. polys = paddle.transpose(polys, [2, 1, 0])
  334. polys = paddle.reshape(polys, [-1, N])
  335. polys = paddle.transpose(polys, [1, 0])
  336. tmp = paddle.stack(
  337. [x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr], axis=1)
  338. polys = polys + tmp
  339. return polys
  340. def bbox_iou_np_expand(box1, box2, x1y1x2y2=True, eps=1e-16):
  341. """
  342. Calculate the iou of box1 and box2 with numpy.
  343. Args:
  344. box1 (ndarray): [N, 4]
  345. box2 (ndarray): [M, 4], usually N != M
  346. x1y1x2y2 (bool): whether in x1y1x2y2 stype, default True
  347. eps (float): epsilon to avoid divide by zero
  348. Return:
  349. iou (ndarray): iou of box1 and box2, [N, M]
  350. """
  351. N, M = len(box1), len(box2) # usually N != M
  352. if x1y1x2y2:
  353. b1_x1, b1_y1 = box1[:, 0], box1[:, 1]
  354. b1_x2, b1_y2 = box1[:, 2], box1[:, 3]
  355. b2_x1, b2_y1 = box2[:, 0], box2[:, 1]
  356. b2_x2, b2_y2 = box2[:, 2], box2[:, 3]
  357. else:
  358. # cxcywh style
  359. # Transform from center and width to exact coordinates
  360. b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
  361. b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
  362. b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
  363. b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
  364. # get the coordinates of the intersection rectangle
  365. inter_rect_x1 = np.zeros((N, M), dtype=np.float32)
  366. inter_rect_y1 = np.zeros((N, M), dtype=np.float32)
  367. inter_rect_x2 = np.zeros((N, M), dtype=np.float32)
  368. inter_rect_y2 = np.zeros((N, M), dtype=np.float32)
  369. for i in range(len(box2)):
  370. inter_rect_x1[:, i] = np.maximum(b1_x1, b2_x1[i])
  371. inter_rect_y1[:, i] = np.maximum(b1_y1, b2_y1[i])
  372. inter_rect_x2[:, i] = np.minimum(b1_x2, b2_x2[i])
  373. inter_rect_y2[:, i] = np.minimum(b1_y2, b2_y2[i])
  374. # Intersection area
  375. inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * np.maximum(
  376. inter_rect_y2 - inter_rect_y1, 0)
  377. # Union Area
  378. b1_area = np.repeat(
  379. ((b1_x2 - b1_x1) * (b1_y2 - b1_y1)).reshape(-1, 1), M, axis=-1)
  380. b2_area = np.repeat(
  381. ((b2_x2 - b2_x1) * (b2_y2 - b2_y1)).reshape(1, -1), N, axis=0)
  382. ious = inter_area / (b1_area + b2_area - inter_area + eps)
  383. return ious