# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
from numbers import Integral

import copy
import math
import random

import cv2
import numpy as np

from .operators import BaseOperator, register_op
from .batch_operators import Gt2TTFTarget
from paddlex.ppdet.modeling.bbox_utils import bbox_iou_np_expand
from paddlex.ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)

__all__ = [
    'RGBReverse', 'LetterBoxResize', 'MOTRandomAffine', 'Gt2JDETargetThres',
    'Gt2JDETargetMax', 'Gt2FairMOTTarget'
]


@register_op
class RGBReverse(BaseOperator):
    """Reverse the channel order of an image (RGB to BGR, or BGR to RGB).

    Channel order matters for MOTRandomAffine, whose border fill value
    assumes BGR images, so this operator is typically applied around it.
    """

    def __init__(self):
        super(RGBReverse, self).__init__()

    def apply(self, sample, context=None):
        im = sample['image']
        # reverse the last (channel) axis and make the result contiguous
        sample['image'] = np.ascontiguousarray(im[:, :, ::-1])
        return sample


@register_op
class LetterBoxResize(BaseOperator):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if not isinstance(target_size, (Integral, Sequence)):
            raise TypeError(
                "Type of target_size is invalid. Must be Integer or List or "
                "Tuple, now is {}".format(type(target_size)))
        if isinstance(target_size, Integral):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def apply_image(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image into a padded rectangle,
        # preserving aspect ratio and centering it with constant borders
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio),
                     round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(
            img, new_shape,
            interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
            value=color)  # padded rectangular
        return img, ratio, padw, padh
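
    # Worked example (illustrative numbers, not from the original source):
    # letterboxing a 1080x1920 frame into a 608x1088 target gives
    #     ratio = min(608 / 1080, 1088 / 1920) ~= 0.5630
    #     new_shape = (1081, 608), padw = 3.5, padh = 0.0
    # so copyMakeBorder pads 3 px on the left and 4 px on the right, and
    # apply_bbox() below maps the gt boxes with the same ratio and pads.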

    def apply_bbox(self, bbox0, h, w, ratio, padw, padh):
        # normalized [cx, cy, w, h] -> pixel [x0, y0, x1, y1] in the
        # letterboxed frame
        bboxes = bbox0.copy()
        bboxes[:, 0] = ratio * w * (bbox0[:, 0] - bbox0[:, 2] / 2) + padw
        bboxes[:, 1] = ratio * h * (bbox0[:, 1] - bbox0[:, 3] / 2) + padh
        bboxes[:, 2] = ratio * w * (bbox0[:, 0] + bbox0[:, 2] / 2) + padw
        bboxes[:, 3] = ratio * h * (bbox0[:, 1] + bbox0[:, 3] / 2) + padh
        return bboxes

    def apply(self, sample, context=None):
        """Letterbox-resize the image array and its annotations."""
        im = sample['image']
        h, w = sample['im_shape']
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image type is not numpy.".format(self))
        if len(im.shape) != 3:
            raise ValueError('{}: image is not 3-dimensional.'.format(self))

        # apply image
        height, width = self.target_size
        img, ratio, padw, padh = self.apply_image(
            im, height=height, width=width)
        sample['image'] = img
        new_shape = (round(h * ratio), round(w * ratio))
        sample['im_shape'] = np.asarray(new_shape, dtype=np.float32)
        sample['scale_factor'] = np.asarray([ratio, ratio], dtype=np.float32)

        # apply bbox
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], h, w,
                                                ratio, padw, padh)
        return sample
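

# A minimal usage sketch for LetterBoxResize (the sample dict and its
# values are hypothetical, but use the keys this operator reads/writes):
#
#   op = LetterBoxResize(target_size=[608, 1088])
#   sample = {
#       'image': np.zeros((1080, 1920, 3), dtype=np.uint8),
#       'im_shape': np.array([1080., 1920.], dtype=np.float32),
#       'gt_bbox': np.array([[0.5, 0.5, 0.2, 0.4]], dtype=np.float32),
#   }
#   sample = op.apply(sample)
#   # sample['image'].shape == (608, 1088, 3), gt_bbox is now pixel xyxy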


@register_op
class MOTRandomAffine(BaseOperator):
    """
    Apply a random affine transform (rotation, scale, translation and
    shear) to the training image and its coordinates.
    Args:
        degrees (list[2]): rotation range to sample from, [min, max]
        translate (list[2]): translation range to sample from, [min, max]
        scale (list[2]): scale range to sample from, [min, max]
        shear (list[2]): shear range to sample from, [min, max]
        borderValue (list[3]): constant border value used when applying
            the perspective transformation
        reject_outside (bool): clip warped bounding boxes to the image;
            boxes left too small afterwards are dropped
    Returns:
        sample (dict): the sample with transformed image and coords
    """

    def __init__(self,
                 degrees=(-5, 5),
                 translate=(0.10, 0.10),
                 scale=(0.50, 1.20),
                 shear=(-2, 2),
                 borderValue=(127.5, 127.5, 127.5),
                 reject_outside=True):
        super(MOTRandomAffine, self).__init__()
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.borderValue = borderValue
        self.reject_outside = reject_outside

    def apply(self, sample, context=None):
        # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
        border = 0  # width of added border (optional)

        img = sample['image']
        height, width = img.shape[0], img.shape[1]

        # Rotation and Scale
        R = np.eye(3)
        a = random.random() * (self.degrees[1] - self.degrees[0]
                               ) + self.degrees[0]
        s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
        R[:2] = cv2.getRotationMatrix2D(
            angle=a, center=(width / 2, height / 2), scale=s)

        # Translation
        T = np.eye(3)
        T[0, 2] = (
            random.random() * 2 - 1
        ) * self.translate[0] * height + border  # x translation (pixels)
        T[1, 2] = (
            random.random() * 2 - 1
        ) * self.translate[1] * width + border  # y translation (pixels)

        # Shear
        S = np.eye(3)
        S[0, 1] = math.tan((random.random() *
                            (self.shear[1] - self.shear[0]) + self.shear[0]) *
                           math.pi / 180)  # x shear (deg)
        S[1, 0] = math.tan((random.random() *
                            (self.shear[1] - self.shear[0]) + self.shear[0]) *
                           math.pi / 180)  # y shear (deg)

        # Combined affine matrix. ORDER IS IMPORTANT HERE!
        M = S @ T @ R
        imw = cv2.warpPerspective(
            img,
            M,
            dsize=(width, height),
            flags=cv2.INTER_LINEAR,
            borderValue=self.borderValue)  # BGR order borderValue

        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            targets = sample['gt_bbox']
            n = targets.shape[0]
            points = targets.copy()
            area0 = (points[:, 2] - points[:, 0]) * (
                points[:, 3] - points[:, 1])

            # warp points
            xy = np.ones((n * 4, 3))
            xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
                n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = (xy @ M.T)[:, :2].reshape(n, 8)

            # create new boxes from the axis-aligned hull of the 4 corners
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            xy = np.concatenate(
                (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # apply angle-based reduction: shrink boxes inflated by taking
            # the hull of rotated corners
            radians = a * math.pi / 180
            reduction = max(abs(math.sin(radians)),
                            abs(math.cos(radians)))**0.5
            x = (xy[:, 2] + xy[:, 0]) / 2
            y = (xy[:, 3] + xy[:, 1]) / 2
            w = (xy[:, 2] - xy[:, 0]) * reduction
            h = (xy[:, 3] - xy[:, 1]) * reduction
            xy = np.concatenate(
                (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

            # clip warped points to the image
            if self.reject_outside:
                np.clip(xy[:, 0], 0, width, out=xy[:, 0])
                np.clip(xy[:, 2], 0, width, out=xy[:, 2])
                np.clip(xy[:, 1], 0, height, out=xy[:, 1])
                np.clip(xy[:, 3], 0, height, out=xy[:, 3])

            # keep boxes that stay large enough, retain enough of their
            # original area, and have a sane aspect ratio
            w = xy[:, 2] - xy[:, 0]
            h = xy[:, 3] - xy[:, 1]
            area = w * h
            ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
            i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)

            if sum(i) > 0:
                sample['gt_bbox'] = xy[i].astype(sample['gt_bbox'].dtype)
                sample['gt_class'] = sample['gt_class'][i]
                if 'difficult' in sample:
                    sample['difficult'] = sample['difficult'][i]
                if 'gt_ide' in sample:
                    sample['gt_ide'] = sample['gt_ide'][i]
                if 'is_crowd' in sample:
                    sample['is_crowd'] = sample['is_crowd'][i]
                sample['image'] = imw
            return sample
        else:
            return sample
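

# A sketch of the matrix composition used above, with fixed parameters
# (values hypothetical): a 5-degree rotation about the image center at
# scale 1.0 plus a 10-pixel shift compose into a single matrix M, so the
# image and the box corners are warped in one pass instead of three:
#
#   R = np.eye(3)
#   R[:2] = cv2.getRotationMatrix2D((width / 2, height / 2), 5, 1.0)
#   T = np.eye(3)
#   T[:2, 2] = 10                         # tx = ty = 10 px
#   M = T @ R                             # no shear, so S == np.eye(3)
#   corner = M @ np.array([0., 0., 1.])   # warped top-left corner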


@register_op
class Gt2JDETargetThres(BaseOperator):
    __shared__ = ['num_classes']
    """
    Generate JDE targets from ground-truth data when training.
    Args:
        anchors (list): anchors of the JDE model
        anchor_masks (list): anchor masks of the JDE model
        downsample_ratios (list): downsample ratios of the JDE model
        ide_thresh (float): IoU threshold for identity targets; anchors
            above it take the ground-truth identity
        fg_thresh (float): IoU threshold for foreground; anchors above it
            are foreground
        bg_thresh (float): IoU threshold for background; anchors below it
            are background
        num_classes (int): number of classes
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 ide_thresh=0.5,
                 fg_thresh=0.5,
                 bg_thresh=0.4,
                 num_classes=1):
        super(Gt2JDETargetThres, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.ide_thresh = ide_thresh
        self.fg_thresh = fg_thresh
        self.bg_thresh = bg_thresh
        self.num_classes = num_classes

    def generate_anchor(self, nGh, nGw, anchor_hw):
        nA = len(anchor_hw)
        yy, xx = np.meshgrid(np.arange(nGh), np.arange(nGw))

        mesh = np.stack([xx.T, yy.T], axis=0)  # [2, nGh, nGw]
        mesh = np.repeat(mesh[None, :], nA, axis=0)  # [nA, 2, nGh, nGw]

        anchor_offset_mesh = anchor_hw[:, :, None][:, :, :, None]
        anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGh, axis=-2)
        anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGw, axis=-1)

        anchor_mesh = np.concatenate(
            [mesh, anchor_offset_mesh], axis=1)  # [nA, 4, nGh, nGw]
        return anchor_mesh
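
    # Shape note: generate_anchor returns an [nA, 4, nGh, nGw] tensor whose
    # first two channels are the (x, y) cell coordinates and whose last two
    # broadcast each anchor's size over the grid, i.e. one candidate box
    # per anchor per cell, in the cxcywh layout that encode_delta and
    # bbox_iou_np_expand(x1y1x2y2=False) consume.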

    def encode_delta(self, gt_box_list, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        gx, gy, gw, gh = gt_box_list[:, 0], gt_box_list[:, 1], \
            gt_box_list[:, 2], gt_box_list[:, 3]
        dx = (gx - px) / pw
        dy = (gy - py) / ph
        dw = np.log(gw / pw)
        dh = np.log(gh / ph)
        return np.stack([dx, dy, dw, dh], axis=1)
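
    # encode_delta uses the usual box-delta parameterization; a decoder
    # would invert it as (a sketch, not part of this file):
    #   gx = px + dx * pw;  gy = py + dy * ph
    #   gw = pw * np.exp(dw);  gh = ph * np.exp(dh)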

    def pad_box(self, sample, num_max):
        assert 'gt_bbox' in sample
        bbox = sample['gt_bbox']
        gt_num = len(bbox)
        # pad every per-instance field to a fixed length so samples batch
        pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
        if gt_num > 0:
            pad_bbox[:gt_num, :] = bbox[:gt_num, :]
        sample['gt_bbox'] = pad_bbox
        if 'gt_score' in sample:
            pad_score = np.zeros((num_max, ), dtype=np.float32)
            if gt_num > 0:
                pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
            sample['gt_score'] = pad_score
        if 'difficult' in sample:
            pad_diff = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
            sample['difficult'] = pad_diff
        if 'is_crowd' in sample:
            pad_crowd = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
            sample['is_crowd'] = pad_crowd
        if 'gt_ide' in sample:
            pad_ide = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
            sample['gt_ide'] = pad_ide
        return sample

    def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have the same length."
        h, w = samples[0]['image'].shape[1:3]

        num_max = 0
        for sample in samples:
            num_max = max(num_max, len(sample['gt_bbox']))

        for sample in samples:
            gt_bbox = sample['gt_bbox']
            gt_ide = sample['gt_ide']
            for i, (anchor_hw, downsample_ratio
                    ) in enumerate(zip(self.anchors, self.downsample_ratios)):
                anchor_hw = np.array(
                    anchor_hw, dtype=np.float32) / downsample_ratio
                nA = len(anchor_hw)
                nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
                tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
                tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
                tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)

                # scale normalized gt boxes to this level's grid
                gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
                gxy[:, 0] = gxy[:, 0] * nGw
                gxy[:, 1] = gxy[:, 1] * nGh
                gwh[:, 0] = gwh[:, 0] * nGw
                gwh[:, 1] = gwh[:, 1] * nGh
                gxy[:, 0] = np.clip(gxy[:, 0], 0, nGw - 1)
                gxy[:, 1] = np.clip(gxy[:, 1], 0, nGh - 1)
                tboxes = np.concatenate([gxy, gwh], axis=1)

                anchor_mesh = self.generate_anchor(nGh, nGw, anchor_hw)
                anchor_list = np.transpose(anchor_mesh,
                                           (0, 2, 3, 1)).reshape(-1, 4)
                iou_pdist = bbox_iou_np_expand(
                    anchor_list, tboxes, x1y1x2y2=False)

                iou_max = np.max(iou_pdist, axis=1)
                max_gt_index = np.argmax(iou_pdist, axis=1)
                iou_map = iou_max.reshape(nA, nGh, nGw)
                gt_index_map = max_gt_index.reshape(nA, nGh, nGw)

                # assign labels by IoU thresholds; cells between the
                # foreground and background thresholds are ignored
                id_index = iou_map > self.ide_thresh
                fg_index = iou_map > self.fg_thresh
                bg_index = iou_map < self.bg_thresh
                ign_index = (iou_map < self.fg_thresh) * (
                    iou_map > self.bg_thresh)
                tconf[fg_index] = 1
                tconf[bg_index] = 0
                tconf[ign_index] = -1

                gt_index = gt_index_map[fg_index]
                gt_box_list = tboxes[gt_index]
                gt_id_list = gt_ide[gt_index_map[id_index]]

                if np.sum(fg_index) > 0:
                    tid[id_index] = gt_id_list
                    fg_anchor_list = anchor_list.reshape(nA, nGh, nGw,
                                                         4)[fg_index]
                    delta_target = self.encode_delta(gt_box_list,
                                                     fg_anchor_list)
                    tbox[fg_index] = delta_target

                sample['tbox{}'.format(i)] = tbox
                sample['tconf{}'.format(i)] = tconf
                sample['tide{}'.format(i)] = tid
            sample.pop('gt_class')
            sample = self.pad_box(sample, num_max)
        return samples
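

# A configuration sketch for Gt2JDETargetThres (the anchor sizes below are
# hypothetical placeholders, not the JDE defaults): one instance serves
# all detection levels, taking a per-level anchor list and the matching
# downsample ratios in the same order:
#
#   op = Gt2JDETargetThres(
#       anchors=[[[85, 255], [120, 360], [170, 420]],
#                [[21, 64], [30, 90], [43, 128]],
#                [[6, 16], [8, 23], [11, 32]]],
#       anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
#       downsample_ratios=[32, 16, 8])
#   samples = op(samples)  # adds tbox{i}/tconf{i}/tide{i} per level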


@register_op
class Gt2JDETargetMax(BaseOperator):
    __shared__ = ['num_classes']
    """
    Generate JDE targets from ground-truth data when evaluating.
    Args:
        anchors (list): anchors of the JDE model
        anchor_masks (list): anchor masks of the JDE model
        downsample_ratios (list): downsample ratios of the JDE model
        max_iou_thresh (float): IoU threshold a best-matching anchor must
            exceed to become a target
        num_classes (int): number of classes
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 max_iou_thresh=0.60,
                 num_classes=1):
        super(Gt2JDETargetMax, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.max_iou_thresh = max_iou_thresh
        self.num_classes = num_classes

    def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have the same length."
        h, w = samples[0]['image'].shape[1:3]
        for sample in samples:
            gt_bbox = sample['gt_bbox']
            gt_ide = sample['gt_ide']
            for i, (anchor_hw, downsample_ratio
                    ) in enumerate(zip(self.anchors, self.downsample_ratios)):
                anchor_hw = np.array(
                    anchor_hw, dtype=np.float32) / downsample_ratio
                nA = len(anchor_hw)
                nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
                tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
                tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
                tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)

                gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
                gxy[:, 0] = gxy[:, 0] * nGw
                gxy[:, 1] = gxy[:, 1] * nGh
                gwh[:, 0] = gwh[:, 0] * nGw
                gwh[:, 1] = gwh[:, 1] * nGh
                gi = np.clip(gxy[:, 0], 0, nGw - 1).astype(int)
                gj = np.clip(gxy[:, 1], 0, nGh - 1).astype(int)

                # iou of targets-anchors (using wh only)
                box1 = gwh
                box2 = anchor_hw[:, None, :]
                inter_area = np.minimum(box1, box2).prod(2)
                iou = inter_area / (
                    box1.prod(1) + box2.prod(2) - inter_area + 1e-16)

                # Select best iou_pred and anchor
                iou_best = iou.max(0)  # best anchor [0-2] for each target
                a = np.argmax(iou, axis=0)

                # Select best unique target-anchor combinations
                iou_order = np.argsort(-iou_best)  # best to worst

                # Unique anchor selection
                u = np.stack((gi, gj, a), 0)[:, iou_order]
                _, first_unique = np.unique(u, axis=1, return_index=True)
                mask = iou_order[first_unique]

                # best anchor must share significant commonality (iou)
                # with target
                # TODO: examine arbitrary threshold
                idx = mask[iou_best[mask] > self.max_iou_thresh]
                if len(idx) > 0:
                    a_i, gj_i, gi_i = a[idx], gj[idx], gi[idx]
                    t_box = gt_bbox[idx]
                    t_id = gt_ide[idx]
                    if len(t_box.shape) == 1:
                        t_box = t_box.reshape(1, 4)

                    gxy, gwh = t_box[:, 0:2].copy(), t_box[:, 2:4].copy()
                    gxy[:, 0] = gxy[:, 0] * nGw
                    gxy[:, 1] = gxy[:, 1] * nGh
                    gwh[:, 0] = gwh[:, 0] * nGw
                    gwh[:, 1] = gwh[:, 1] * nGh

                    # XY coordinates
                    tbox[:, :, :, 0:2][a_i, gj_i, gi_i] = \
                        gxy - gxy.astype(int)
                    # Width and height in yolo method
                    tbox[:, :, :, 2:4][a_i, gj_i, gi_i] = np.log(
                        gwh / anchor_hw[a_i])
                    tconf[a_i, gj_i, gi_i] = 1
                    tid[a_i, gj_i, gi_i] = t_id

                sample['tbox{}'.format(i)] = tbox
                sample['tconf{}'.format(i)] = tconf
                sample['tide{}'.format(i)] = tid
        return samples
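

# Unlike Gt2JDETargetThres, which labels every anchor cell against IoU
# thresholds, Gt2JDETargetMax keeps only the single best-matching (and
# sufficiently overlapping) anchor per ground-truth box. Its IoU uses
# widths and heights only; e.g. a 4x6 target against a 3x8 anchor gives
# inter = min(4, 3) * min(6, 8) = 18 and IoU = 18 / (24 + 24 - 18) = 0.6
# (illustrative numbers, not from the original source).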


@register_op
class Gt2FairMOTTarget(Gt2TTFTarget):
    __shared__ = ['num_classes']
    """
    Generate FairMOT targets from ground-truth data.
    Differences between Gt2FairMOTTarget and Gt2TTFTarget:
        1. the Gaussian kernel radius used to generate the heatmap.
        2. the targets needed during training.
    Args:
        num_classes (int): the number of classes.
        down_ratio (int): the downsample ratio from images to the heatmap,
            4 by default.
        max_objs (int): the maximum number of ground-truth objects in an
            image, 500 by default.
    """

    def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
        # skip Gt2TTFTarget.__init__ on purpose and initialize the
        # BaseOperator directly
        super(Gt2TTFTarget, self).__init__()
        self.down_ratio = down_ratio
        self.num_classes = num_classes
        self.max_objs = max_objs

    def __call__(self, samples, context=None):
        for b_id, sample in enumerate(samples):
            output_h = sample['image'].shape[1] // self.down_ratio
            output_w = sample['image'].shape[2] // self.down_ratio

            heatmap = np.zeros(
                (self.num_classes, output_h, output_w), dtype='float32')
            bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
            center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
            index = np.zeros((self.max_objs, ), dtype=np.int64)
            index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
            reid = np.zeros((self.max_objs, ), dtype=np.int64)
            bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)

            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']
            gt_ide = sample['gt_ide']

            for k in range(len(gt_bbox)):
                cls_id = gt_class[k][0]
                bbox = gt_bbox[k]
                ide = gt_ide[k][0]
                # scale normalized cxcywh boxes to heatmap coordinates
                bbox[[0, 2]] = bbox[[0, 2]] * output_w
                bbox[[1, 3]] = bbox[[1, 3]] * output_h
                bbox_amodal = copy.deepcopy(bbox)
                bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
                bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
                bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
                bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
                bbox[0] = np.clip(bbox[0], 0, output_w - 1)
                bbox[1] = np.clip(bbox[1], 0, output_h - 1)
                h = bbox[3]
                w = bbox[2]

                bbox_xy = copy.deepcopy(bbox)
                bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
                bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
                bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
                bbox_xy[3] = bbox_xy[1] + bbox_xy[3]

                if h > 0 and w > 0:
                    radius = self.gaussian_radius(
                        (math.ceil(h), math.ceil(w)))
                    radius = max(0, int(radius))
                    ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
                    ct_int = ct.astype(np.int32)
                    self.draw_truncate_gaussian(heatmap[cls_id], ct_int,
                                                radius, radius)
                    bbox_size[k] = ct[0] - bbox_amodal[0], \
                        ct[1] - bbox_amodal[1], \
                        bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]

                    index[k] = ct_int[1] * output_w + ct_int[0]
                    center_offset[k] = ct - ct_int
                    index_mask[k] = 1
                    reid[k] = ide
                    bbox_xys[k] = bbox_xy

            sample['heatmap'] = heatmap
            sample['index'] = index
            sample['offset'] = center_offset
            sample['size'] = bbox_size
            sample['index_mask'] = index_mask
            sample['reid'] = reid
            sample['bbox_xys'] = bbox_xys
            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)
            sample.pop('gt_score', None)
            sample.pop('gt_ide', None)
        return samples

    def gaussian_radius(self, det_size, min_overlap=0.7):
        # smallest radius such that a box shifted by r still has at least
        # min_overlap IoU with the ground truth, checked for three
        # geometric cases of how the shifted box can overlap it
        height, width = det_size

        a1 = 1
        b1 = (height + width)
        c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
        sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
        r1 = (b1 + sq1) / 2

        a2 = 4
        b2 = 2 * (height + width)
        c2 = (1 - min_overlap) * width * height
        sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
        r2 = (b2 + sq2) / 2

        a3 = 4 * min_overlap
        b3 = -2 * min_overlap * (height + width)
        c3 = (min_overlap - 1) * width * height
        sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
        r3 = (b3 + sq3) / 2
        return min(r1, r2, r3)
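
    # Worked example (illustrative): a 10 x 10 box at the default
    # min_overlap=0.7 gives roughly r1 ~ 19.1, r2 ~ 36.7 and r3 ~ 2.73,
    # so the returned radius is ~2.73. The (b + sqrt(disc)) / 2 form
    # follows the widely used CornerNet/CenterNet reference
    # implementation.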