op_helper.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # this file contains helper methods for BBOX processing
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import numpy as np
  19. import random
  20. import math
  21. import cv2
  22. def meet_emit_constraint(src_bbox, sample_bbox):
  23. center_x = (src_bbox[2] + src_bbox[0]) / 2
  24. center_y = (src_bbox[3] + src_bbox[1]) / 2
  25. if center_x >= sample_bbox[0] and \
  26. center_x <= sample_bbox[2] and \
  27. center_y >= sample_bbox[1] and \
  28. center_y <= sample_bbox[3]:
  29. return True
  30. return False
  31. def clip_bbox(src_bbox):
  32. src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0)
  33. src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0)
  34. src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0)
  35. src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0)
  36. return src_bbox
  37. def bbox_area(src_bbox):
  38. if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
  39. return 0.
  40. else:
  41. width = src_bbox[2] - src_bbox[0]
  42. height = src_bbox[3] - src_bbox[1]
  43. return width * height
  44. def is_overlap(object_bbox, sample_bbox):
  45. if object_bbox[0] >= sample_bbox[2] or \
  46. object_bbox[2] <= sample_bbox[0] or \
  47. object_bbox[1] >= sample_bbox[3] or \
  48. object_bbox[3] <= sample_bbox[1]:
  49. return False
  50. else:
  51. return True
  52. def filter_and_process(sample_bbox,
  53. bboxes,
  54. labels,
  55. scores=None,
  56. keypoints=None):
  57. new_bboxes = []
  58. new_labels = []
  59. new_scores = []
  60. new_keypoints = []
  61. new_kp_ignore = []
  62. for i in range(len(bboxes)):
  63. new_bbox = [0, 0, 0, 0]
  64. obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
  65. if not meet_emit_constraint(obj_bbox, sample_bbox):
  66. continue
  67. if not is_overlap(obj_bbox, sample_bbox):
  68. continue
  69. sample_width = sample_bbox[2] - sample_bbox[0]
  70. sample_height = sample_bbox[3] - sample_bbox[1]
  71. new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width
  72. new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height
  73. new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width
  74. new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height
  75. new_bbox = clip_bbox(new_bbox)
  76. if bbox_area(new_bbox) > 0:
  77. new_bboxes.append(new_bbox)
  78. new_labels.append([labels[i][0]])
  79. if scores is not None:
  80. new_scores.append([scores[i][0]])
  81. if keypoints is not None:
  82. sample_keypoint = keypoints[0][i]
  83. for j in range(len(sample_keypoint)):
  84. kp_len = sample_height if j % 2 else sample_width
  85. sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0]
  86. sample_keypoint[j] = (
  87. sample_keypoint[j] - sample_coord) / kp_len
  88. sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0)
  89. new_keypoints.append(sample_keypoint)
  90. new_kp_ignore.append(keypoints[1][i])
  91. bboxes = np.array(new_bboxes)
  92. labels = np.array(new_labels)
  93. scores = np.array(new_scores)
  94. if keypoints is not None:
  95. keypoints = np.array(new_keypoints)
  96. new_kp_ignore = np.array(new_kp_ignore)
  97. return bboxes, labels, scores, (keypoints, new_kp_ignore)
  98. return bboxes, labels, scores
  99. def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
  100. new_bboxes = []
  101. new_labels = []
  102. new_scores = []
  103. for i, bbox in enumerate(bboxes):
  104. w = float((bbox[2] - bbox[0]) * target_size)
  105. h = float((bbox[3] - bbox[1]) * target_size)
  106. if w * h < float(min_size * min_size):
  107. continue
  108. else:
  109. new_bboxes.append(bbox)
  110. new_labels.append(labels[i])
  111. if scores is not None and scores.size != 0:
  112. new_scores.append(scores[i])
  113. bboxes = np.array(new_bboxes)
  114. labels = np.array(new_labels)
  115. scores = np.array(new_scores)
  116. return bboxes, labels, scores
  117. def generate_sample_bbox(sampler):
  118. scale = np.random.uniform(sampler[2], sampler[3])
  119. aspect_ratio = np.random.uniform(sampler[4], sampler[5])
  120. aspect_ratio = max(aspect_ratio, (scale**2.0))
  121. aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
  122. bbox_width = scale * (aspect_ratio**0.5)
  123. bbox_height = scale / (aspect_ratio**0.5)
  124. xmin_bound = 1 - bbox_width
  125. ymin_bound = 1 - bbox_height
  126. xmin = np.random.uniform(0, xmin_bound)
  127. ymin = np.random.uniform(0, ymin_bound)
  128. xmax = xmin + bbox_width
  129. ymax = ymin + bbox_height
  130. sampled_bbox = [xmin, ymin, xmax, ymax]
  131. return sampled_bbox
  132. def generate_sample_bbox_square(sampler, image_width, image_height):
  133. scale = np.random.uniform(sampler[2], sampler[3])
  134. aspect_ratio = np.random.uniform(sampler[4], sampler[5])
  135. aspect_ratio = max(aspect_ratio, (scale**2.0))
  136. aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
  137. bbox_width = scale * (aspect_ratio**0.5)
  138. bbox_height = scale / (aspect_ratio**0.5)
  139. if image_height < image_width:
  140. bbox_width = bbox_height * image_height / image_width
  141. else:
  142. bbox_height = bbox_width * image_width / image_height
  143. xmin_bound = 1 - bbox_width
  144. ymin_bound = 1 - bbox_height
  145. xmin = np.random.uniform(0, xmin_bound)
  146. ymin = np.random.uniform(0, ymin_bound)
  147. xmax = xmin + bbox_width
  148. ymax = ymin + bbox_height
  149. sampled_bbox = [xmin, ymin, xmax, ymax]
  150. return sampled_bbox
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
                         resize_width):
    """Sample a square crop box around one randomly chosen ground-truth box.

    One gt box is picked at random. Its pixel area is compared with the
    squared entries of ``scale_array`` to pick a bucket ``range_size``. A
    target scale ``scale_choose`` is then drawn from a bucket at or below
    that one, and the crop side length (in pixels) is set so that the
    chosen box's width maps to ``scale_choose`` after resizing to
    ``resize_width``.

    Args:
        bbox_labels: sequence of boxes whose first four entries are
            normalised ``[xmin, ymin, xmax, ymax]`` (they are multiplied by
            the image size below).
        image_width, image_height: image size in pixels.
        scale_array: sequence of scales; presumably ascending anchor sizes
            (TODO confirm against callers).
        resize_width: size the crop will be resized to.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` normalised by the image size (values
        may fall outside ``[0, 1]``), or the integer ``0`` when
        ``bbox_labels`` is empty.

    NOTE(review): draws from both ``np.random`` and Python's ``random``
    module, so seeding only one of them does not make this deterministic.
    """
    num_gt = len(bbox_labels)
    # np.random.randint range: [low, high)
    rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0

    if num_gt != 0:
        norm_xmin = bbox_labels[rand_idx][0]
        norm_ymin = bbox_labels[rand_idx][1]
        norm_xmax = bbox_labels[rand_idx][2]
        norm_ymax = bbox_labels[rand_idx][3]

        # Chosen gt box in pixel units.
        xmin = norm_xmin * image_width
        ymin = norm_ymin * image_height
        wid = image_width * (norm_xmax - norm_xmin)
        hei = image_height * (norm_ymax - norm_ymin)
        range_size = 0

        # Find the first bucket (scale_array[k]**2, scale_array[k+1]**2)
        # that strictly contains the box area.
        area = wid * hei
        for scale_ind in range(0, len(scale_array) - 1):
            if area > scale_array[scale_ind] ** 2 and area < \
                    scale_array[scale_ind + 1] ** 2:
                range_size = scale_ind + 1
                break

        # Areas above the second-to-last scale are capped at that bucket.
        if area > scale_array[len(scale_array) - 2]**2:
            range_size = len(scale_array) - 2

        scale_choose = 0.0
        if range_size == 0:
            rand_idx_size = 0
        else:
            # np.random.randint range: [low, high)
            # The modulo is a no-op here since the draw is already
            # in [0, range_size].
            rng_rand_size = np.random.randint(0, range_size + 1)
            rand_idx_size = rng_rand_size % (range_size + 1)

        # Target scale drawn from [s/2, 2s] for the chosen bucket s; for the
        # box's own bucket the upper bound is also capped at twice the box's
        # geometric side length.
        if rand_idx_size == range_size:
            min_resize_val = scale_array[rand_idx_size] / 2.0
            max_resize_val = min(2.0 * scale_array[rand_idx_size],
                                 2 * math.sqrt(wid * hei))
            scale_choose = random.uniform(min_resize_val, max_resize_val)
        else:
            min_resize_val = scale_array[rand_idx_size] / 2.0
            max_resize_val = 2.0 * scale_array[rand_idx_size]
            scale_choose = random.uniform(min_resize_val, max_resize_val)

        # Side (pixels) of the square crop in the original image.
        sample_bbox_size = wid * resize_width / scale_choose

        w_off_orig = 0.0
        h_off_orig = 0.0
        if sample_bbox_size < max(image_height, image_width):
            # Crop smaller than the image: place it so that, along each axis,
            # it either contains the gt box (crop larger than the box) or
            # lies within the box's extent (crop smaller than the box).
            if wid <= sample_bbox_size:
                w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
                                               xmin)
            else:
                w_off_orig = np.random.uniform(xmin,
                                               xmin + wid - sample_bbox_size)

            if hei <= sample_bbox_size:
                h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
                                               ymin)
            else:
                h_off_orig = np.random.uniform(ymin,
                                               ymin + hei - sample_bbox_size)

        else:
            # Crop at least as large as the image: offset is non-positive so
            # the crop covers the image (padding handled by the caller).
            w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
            h_off_orig = np.random.uniform(image_height - sample_bbox_size,
                                           0.0)

        w_off_orig = math.floor(w_off_orig)
        h_off_orig = math.floor(h_off_orig)

        # Figure out top left coordinates.
        w_off = float(w_off_orig / image_width)
        h_off = float(h_off_orig / image_height)

        # Normalise the square crop separately per axis.
        sampled_bbox = [
            w_off, h_off, w_off + float(sample_bbox_size / image_width),
            h_off + float(sample_bbox_size / image_height)
        ]
        return sampled_bbox
    else:
        return 0
  222. def jaccard_overlap(sample_bbox, object_bbox):
  223. if sample_bbox[0] >= object_bbox[2] or \
  224. sample_bbox[2] <= object_bbox[0] or \
  225. sample_bbox[1] >= object_bbox[3] or \
  226. sample_bbox[3] <= object_bbox[1]:
  227. return 0
  228. intersect_xmin = max(sample_bbox[0], object_bbox[0])
  229. intersect_ymin = max(sample_bbox[1], object_bbox[1])
  230. intersect_xmax = min(sample_bbox[2], object_bbox[2])
  231. intersect_ymax = min(sample_bbox[3], object_bbox[3])
  232. intersect_size = (intersect_xmax - intersect_xmin) * (
  233. intersect_ymax - intersect_ymin)
  234. sample_bbox_size = bbox_area(sample_bbox)
  235. object_bbox_size = bbox_area(object_bbox)
  236. overlap = intersect_size / (
  237. sample_bbox_size + object_bbox_size - intersect_size)
  238. return overlap
  239. def intersect_bbox(bbox1, bbox2):
  240. if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
  241. bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
  242. intersection_box = [0.0, 0.0, 0.0, 0.0]
  243. else:
  244. intersection_box = [
  245. max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]),
  246. min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
  247. ]
  248. return intersection_box
  249. def bbox_coverage(bbox1, bbox2):
  250. inter_box = intersect_bbox(bbox1, bbox2)
  251. intersect_size = bbox_area(inter_box)
  252. if intersect_size > 0:
  253. bbox1_size = bbox_area(bbox1)
  254. return intersect_size / bbox1_size
  255. else:
  256. return 0.
  257. def satisfy_sample_constraint(sampler,
  258. sample_bbox,
  259. gt_bboxes,
  260. satisfy_all=False):
  261. if sampler[6] == 0 and sampler[7] == 0:
  262. return True
  263. satisfied = []
  264. for i in range(len(gt_bboxes)):
  265. object_bbox = [
  266. gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
  267. ]
  268. overlap = jaccard_overlap(sample_bbox, object_bbox)
  269. if sampler[6] != 0 and \
  270. overlap < sampler[6]:
  271. satisfied.append(False)
  272. continue
  273. if sampler[7] != 0 and \
  274. overlap > sampler[7]:
  275. satisfied.append(False)
  276. continue
  277. satisfied.append(True)
  278. if not satisfy_all:
  279. return True
  280. if satisfy_all:
  281. return np.all(satisfied)
  282. else:
  283. return False
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
    """Check ``sample_bbox`` against IoU and object-coverage bounds.

    ``sampler[6]`` / ``sampler[7]`` are min / max IoU with a gt box, and
    ``sampler[8]`` / ``sampler[9]`` are min / max fraction of a gt box
    covered by ``sample_bbox``. A value of ``0`` means "unset". With every
    bound unset the sample is accepted.

    Returns True as soon as a gt box sets ``found``; otherwise False.

    NOTE(review): ``found`` is not reset per gt box. A box that passes the
    IoU check sets ``found = True`` even if it then fails the coverage check
    (``continue`` skips the early return but keeps the flag), so the final
    ``return found`` is True. In effect, when IoU bounds are set, the
    coverage bounds never cause a rejection. Confirm whether this is
    intended before relying on the coverage bounds.
    """
    if sampler[6] == 0 and sampler[7] == 0:
        has_jaccard_overlap = False
    else:
        has_jaccard_overlap = True
    if sampler[8] == 0 and sampler[9] == 0:
        has_object_coverage = False
    else:
        has_object_coverage = True

    if not has_jaccard_overlap and not has_object_coverage:
        return True
    found = False
    for i in range(len(gt_bboxes)):
        object_bbox = [
            gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
        ]
        if has_jaccard_overlap:
            overlap = jaccard_overlap(sample_bbox, object_bbox)
            if sampler[6] != 0 and \
                    overlap < sampler[6]:
                continue
            if sampler[7] != 0 and \
                    overlap > sampler[7]:
                continue
            found = True
        if has_object_coverage:
            # Fraction of the gt box (object_bbox) covered by the sample.
            object_coverage = bbox_coverage(object_bbox, sample_bbox)
            if sampler[8] != 0 and \
                    object_coverage < sampler[8]:
                continue
            if sampler[9] != 0 and \
                    object_coverage > sampler[9]:
                continue
            found = True
        if found:
            return True
    return found
  321. def crop_image_sampling(img, sample_bbox, image_width, image_height,
  322. target_size):
  323. # no clipping here
  324. xmin = int(sample_bbox[0] * image_width)
  325. xmax = int(sample_bbox[2] * image_width)
  326. ymin = int(sample_bbox[1] * image_height)
  327. ymax = int(sample_bbox[3] * image_height)
  328. w_off = xmin
  329. h_off = ymin
  330. width = xmax - xmin
  331. height = ymax - ymin
  332. cross_xmin = max(0.0, float(w_off))
  333. cross_ymin = max(0.0, float(h_off))
  334. cross_xmax = min(float(w_off + width - 1.0), float(image_width))
  335. cross_ymax = min(float(h_off + height - 1.0), float(image_height))
  336. cross_width = cross_xmax - cross_xmin
  337. cross_height = cross_ymax - cross_ymin
  338. roi_xmin = 0 if w_off >= 0 else abs(w_off)
  339. roi_ymin = 0 if h_off >= 0 else abs(h_off)
  340. roi_width = cross_width
  341. roi_height = cross_height
  342. roi_y1 = int(roi_ymin)
  343. roi_y2 = int(roi_ymin + roi_height)
  344. roi_x1 = int(roi_xmin)
  345. roi_x2 = int(roi_xmin + roi_width)
  346. cross_y1 = int(cross_ymin)
  347. cross_y2 = int(cross_ymin + cross_height)
  348. cross_x1 = int(cross_xmin)
  349. cross_x2 = int(cross_xmin + cross_width)
  350. sample_img = np.zeros((height, width, 3))
  351. sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
  352. img[cross_y1: cross_y2, cross_x1: cross_x2]
  353. sample_img = cv2.resize(
  354. sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
  355. return sample_img
  356. def is_poly(segm):
  357. assert isinstance(segm, (list, dict)), \
  358. "Invalid segm type: {}".format(type(segm))
  359. return isinstance(segm, list)
  360. def gaussian_radius(bbox_size, min_overlap):
  361. height, width = bbox_size
  362. a1 = 1
  363. b1 = (height + width)
  364. c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
  365. sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
  366. radius1 = (b1 - sq1) / (2 * a1)
  367. a2 = 4
  368. b2 = 2 * (height + width)
  369. c2 = (1 - min_overlap) * width * height
  370. sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
  371. radius2 = (b2 - sq2) / (2 * a2)
  372. a3 = 4 * min_overlap
  373. b3 = -2 * min_overlap * (height + width)
  374. c3 = (min_overlap - 1) * width * height
  375. sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
  376. radius3 = (b3 + sq3) / (2 * a3)
  377. return min(radius1, radius2, radius3)
  378. def draw_gaussian(heatmap, center, radius, k=1, delte=6):
  379. diameter = 2 * radius + 1
  380. sigma = diameter / delte
  381. gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma)
  382. x, y = center
  383. height, width = heatmap.shape[0:2]
  384. left, right = min(x, radius), min(width - x, radius + 1)
  385. top, bottom = min(y, radius), min(height - y, radius + 1)
  386. masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
  387. masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
  388. radius + right]
  389. np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
  390. def gaussian2D(shape, sigma_x=1, sigma_y=1):
  391. m, n = [(ss - 1.) / 2. for ss in shape]
  392. y, x = np.ogrid[-m:m + 1, -n:n + 1]
  393. h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
  394. sigma_y)))
  395. h[h < np.finfo(h.dtype).eps * h.max()] = 0
  396. return h
  397. def transform_bbox(sample,
  398. M,
  399. w,
  400. h,
  401. area_thr=0.25,
  402. wh_thr=2,
  403. ar_thr=20,
  404. perspective=False):
  405. """
  406. transfrom bbox according to tranformation matrix M,
  407. refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
  408. """
  409. bbox = sample['gt_bbox']
  410. label = sample['gt_class']
  411. # rotate bbox
  412. n = len(bbox)
  413. xy = np.ones((n * 4, 3), dtype=np.float32)
  414. xy[:, :2] = bbox[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
  415. # xy = xy @ M.T
  416. xy = np.matmul(xy, M.T)
  417. if perspective:
  418. xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
  419. else:
  420. xy = xy[:, :2].reshape(n, 8)
  421. # get new bboxes
  422. x = xy[:, [0, 2, 4, 6]]
  423. y = xy[:, [1, 3, 5, 7]]
  424. bbox = np.concatenate(
  425. (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
  426. # clip boxes
  427. mask = filter_bbox(bbox, w, h, area_thr)
  428. sample['gt_bbox'] = bbox[mask]
  429. sample['gt_class'] = sample['gt_class'][mask]
  430. if 'is_crowd' in sample:
  431. sample['is_crowd'] = sample['is_crowd'][mask]
  432. if 'difficult' in sample:
  433. sample['difficult'] = sample['difficult'][mask]
  434. return sample
  435. def filter_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
  436. """
  437. filter bbox, refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
  438. """
  439. # clip boxes
  440. area1 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
  441. bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w)
  442. bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h)
  443. # compute
  444. area2 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
  445. area_ratio = area2 / (area1 + 1e-16)
  446. wh = bbox[:, 2:4] - bbox[:, 0:2]
  447. ar_ratio = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16),
  448. wh[:, 0] / (wh[:, 1] + 1e-16))
  449. mask = (area_ratio > area_thr) & (
  450. (wh > wh_thr).all(1)) & (ar_ratio < ar_thr)
  451. return mask