# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
from numbers import Integral

import cv2
import copy
import numpy as np
import random
import math

from .operators import BaseOperator, register_op
from .batch_operators import Gt2TTFTarget
from paddlex.ppdet.modeling.bbox_utils import bbox_iou_np_expand
from paddlex.ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

__all__ = [
    'RGBReverse', 'LetterBoxResize', 'MOTRandomAffine', 'Gt2JDETargetThres',
    'Gt2JDETargetMax', 'Gt2FairMOTTarget'
]

@register_op
class RGBReverse(BaseOperator):
    """Reverse channel order (RGB to BGR, or BGR to RGB). MOTRandomAffine is
    sensitive to channel order, so this op is typically applied around it.
    """

    def __init__(self):
        super(RGBReverse, self).__init__()

    def apply(self, sample, context=None):
        im = sample['image']
        sample['image'] = np.ascontiguousarray(im[:, :, ::-1])
        return sample

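
# A usage sketch (hypothetical pipeline order, not taken from a shipped
# config): RGBReverse is typically paired around the channel-order-sensitive
# affine op, e.g.
# [Decode(), RGBReverse(), MOTRandomAffine(), RGBReverse(), LetterBoxResize(608)].
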
@register_op
class LetterBoxResize(BaseOperator):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if not isinstance(target_size, (Integral, Sequence)):
            raise TypeError(
                "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
                format(type(target_size)))
        if isinstance(target_size, Integral):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def apply_image(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image to a padded rectangular
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio),
                     round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(
            img, new_shape,
            interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
            value=color)  # padded rectangular
        return img, ratio, padw, padh
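
    # Worked example (illustrative numbers, not from the source): a 1080x1920
    # frame letterboxed to height=608, width=1088 gives
    # ratio = min(608/1080, 1088/1920) ~= 0.5630, a resized shape of
    # (1081, 608) [width, height], padw = 3.5 and padh = 0, i.e. 3 and 4 gray
    # padding columns on the left and right.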

    def apply_bbox(self, bbox0, h, w, ratio, padw, padh):
        bboxes = bbox0.copy()
        bboxes[:, 0] = ratio * w * (bbox0[:, 0] - bbox0[:, 2] / 2) + padw
        bboxes[:, 1] = ratio * h * (bbox0[:, 1] - bbox0[:, 3] / 2) + padh
        bboxes[:, 2] = ratio * w * (bbox0[:, 0] + bbox0[:, 2] / 2) + padw
        bboxes[:, 3] = ratio * h * (bbox0[:, 1] + bbox0[:, 3] / 2) + padh
        return bboxes

    def apply(self, sample, context=None):
        """Apply letterbox resize to the image and its bboxes."""
        im = sample['image']
        h, w = sample['im_shape']
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image type is not numpy.".format(self))
        if len(im.shape) != 3:
            from PIL import UnidentifiedImageError
            raise UnidentifiedImageError(
                '{}: image is not 3-dimensional.'.format(self))

        # apply image
        height, width = self.target_size
        img, ratio, padw, padh = self.apply_image(
            im, height=height, width=width)
        sample['image'] = img
        new_shape = (round(h * ratio), round(w * ratio))
        sample['im_shape'] = np.asarray(new_shape, dtype=np.float32)
        sample['scale_factor'] = np.asarray([ratio, ratio], dtype=np.float32)

        # apply bbox
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], h, w,
                                                ratio, padw, padh)
        return sample

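
# A minimal standalone sketch (the sample dict below is made up for
# illustration; real samples come from the MOT dataset reader):
#
#     op = LetterBoxResize(target_size=[608, 1088])
#     sample = {
#         'image': np.zeros((1080, 1920, 3), dtype=np.uint8),
#         'im_shape': np.array([1080., 1920.], dtype=np.float32),
#         'gt_bbox': np.array([[0.5, 0.5, 0.2, 0.4]], dtype=np.float32),
#     }
#     sample = op.apply(sample)  # image is now 608x1088, gt_bbox in pixel xyxy
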
@register_op
class MOTRandomAffine(BaseOperator):
    """
    Affine transform to image and coords to achieve the rotate, scale and
    shift effect for training image.

    Args:
        degrees (list[2]): the rotate range to apply, transform range is [min, max]
        translate (list[2]): the translate range to apply, transform range is [min, max]
        scale (list[2]): the scale range to apply, transform range is [min, max]
        shear (list[2]): the shear range to apply, transform range is [min, max]
        borderValue (list[3]): value used in case of a constant border when applying
            the perspective transformation
        reject_outside (bool): reject warped bounding boxes outside of image

    Returns:
        records (dict): sample containing the image and coords after the transform
    """

    def __init__(self,
                 degrees=(-5, 5),
                 translate=(0.10, 0.10),
                 scale=(0.50, 1.20),
                 shear=(-2, 2),
                 borderValue=(127.5, 127.5, 127.5),
                 reject_outside=True):
        super(MOTRandomAffine, self).__init__()
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.borderValue = borderValue
        self.reject_outside = reject_outside

    def apply(self, sample, context=None):
        # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
        border = 0  # width of added border (optional)

        img = sample['image']
        height, width = img.shape[0], img.shape[1]

        # Rotation and Scale
        R = np.eye(3)
        a = random.random() * (self.degrees[1] - self.degrees[0]
                               ) + self.degrees[0]
        s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
        R[:2] = cv2.getRotationMatrix2D(
            angle=a, center=(width / 2, height / 2), scale=s)

        # Translation
        T = np.eye(3)
        T[0, 2] = (
            random.random() * 2 - 1
        ) * self.translate[0] * height + border  # x translation (pixels)
        T[1, 2] = (
            random.random() * 2 - 1
        ) * self.translate[1] * width + border  # y translation (pixels)

        # Shear
        S = np.eye(3)
        S[0, 1] = math.tan((random.random() *
                            (self.shear[1] - self.shear[0]) + self.shear[0]) *
                           math.pi / 180)  # x shear (deg)
        S[1, 0] = math.tan((random.random() *
                            (self.shear[1] - self.shear[0]) + self.shear[0]) *
                           math.pi / 180)  # y shear (deg)

        M = S @ T @ R  # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
        imw = cv2.warpPerspective(
            img,
            M,
            dsize=(width, height),
            flags=cv2.INTER_LINEAR,
            borderValue=self.borderValue)  # BGR order borderValue

        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            targets = sample['gt_bbox']
            n = targets.shape[0]
            points = targets.copy()
            area0 = (points[:, 2] - points[:, 0]) * (
                points[:, 3] - points[:, 1])

            # warp points
            xy = np.ones((n * 4, 3))
            xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
                n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = (xy @ M.T)[:, :2].reshape(n, 8)

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            xy = np.concatenate(
                (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # apply angle-based reduction
            radians = a * math.pi / 180
            reduction = max(abs(math.sin(radians)),
                            abs(math.cos(radians)))**0.5
            x = (xy[:, 2] + xy[:, 0]) / 2
            y = (xy[:, 3] + xy[:, 1]) / 2
            w = (xy[:, 2] - xy[:, 0]) * reduction
            h = (xy[:, 3] - xy[:, 1]) * reduction
            xy = np.concatenate(
                (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

            # reject warped points outside of image
            if self.reject_outside:
                np.clip(xy[:, 0], 0, width, out=xy[:, 0])
                np.clip(xy[:, 2], 0, width, out=xy[:, 2])
                np.clip(xy[:, 1], 0, height, out=xy[:, 1])
                np.clip(xy[:, 3], 0, height, out=xy[:, 3])
            w = xy[:, 2] - xy[:, 0]
            h = xy[:, 3] - xy[:, 1]
            area = w * h
            ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
            i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)

            if sum(i) > 0:
                sample['gt_bbox'] = xy[i].astype(sample['gt_bbox'].dtype)
                sample['gt_class'] = sample['gt_class'][i]
                if 'difficult' in sample:
                    sample['difficult'] = sample['difficult'][i]
                if 'gt_ide' in sample:
                    sample['gt_ide'] = sample['gt_ide'][i]
                if 'is_crowd' in sample:
                    sample['is_crowd'] = sample['is_crowd'][i]
                sample['image'] = imw
                return sample
            else:
                return sample
        else:
            # no gt boxes: return the sample unchanged rather than None
            return sample

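
# Note on composition order: corner points are warped as p' = M @ p with
# M = S @ T @ R, so rotation/scale applies first, then translation, then
# shear. Swapping the order changes the result because these matrices do
# not commute.
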
@register_op
class Gt2JDETargetThres(BaseOperator):
    __shared__ = ['num_classes']
    """
    Generate JDE targets by ground truth data when training
    Args:
        anchors (list): anchors of JDE model
        anchor_masks (list): anchor_masks of JDE model
        downsample_ratios (list): downsample ratios of JDE model
        ide_thresh (float): thresh of identity, higher is ground truth
        fg_thresh (float): thresh of foreground, higher is foreground
        bg_thresh (float): thresh of background, lower is background
        num_classes (int): number of classes
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 ide_thresh=0.5,
                 fg_thresh=0.5,
                 bg_thresh=0.4,
                 num_classes=1):
        super(Gt2JDETargetThres, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.ide_thresh = ide_thresh
        self.fg_thresh = fg_thresh
        self.bg_thresh = bg_thresh
        self.num_classes = num_classes

    def generate_anchor(self, nGh, nGw, anchor_hw):
        nA = len(anchor_hw)
        yy, xx = np.meshgrid(np.arange(nGh), np.arange(nGw))

        mesh = np.stack([xx.T, yy.T], axis=0)  # [2, nGh, nGw]
        mesh = np.repeat(mesh[None, :], nA, axis=0)  # [nA, 2, nGh, nGw]

        anchor_offset_mesh = anchor_hw[:, :, None][:, :, :, None]
        anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGh, axis=-2)
        anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGw, axis=-1)

        anchor_mesh = np.concatenate(
            [mesh, anchor_offset_mesh], axis=1)  # [nA, 4, nGh, nGw]
        return anchor_mesh
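
    # A shape sketch for generate_anchor (illustrative numbers): with nA=4
    # anchors on an nGh=19, nGw=34 grid it returns a [4, 4, 19, 34] array,
    # i.e. per anchor and per cell a (cell_x, cell_y, anchor_size_0,
    # anchor_size_1) box in grid units, which __call__ flattens to an
    # anchor_list of shape [4 * 19 * 34, 4].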

    def encode_delta(self, gt_box_list, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        gx, gy, gw, gh = gt_box_list[:, 0], gt_box_list[:, 1], \
            gt_box_list[:, 2], gt_box_list[:, 3]
        dx = (gx - px) / pw
        dy = (gy - py) / ph
        dw = np.log(gw / pw)
        dh = np.log(gh / ph)
        return np.stack([dx, dy, dw, dh], axis=1)
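
    # The regression target is the standard box delta: center offsets are
    # expressed relative to the anchor size (dx = (gx - px) / pw,
    # dy = (gy - py) / ph) and sizes as log-space ratios (dw = log(gw / pw),
    # dh = log(gh / ph)), so the network predicts scale-invariant corrections.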

    def pad_box(self, sample, num_max):
        assert 'gt_bbox' in sample
        bbox = sample['gt_bbox']
        gt_num = len(bbox)
        pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
        if gt_num > 0:
            pad_bbox[:gt_num, :] = bbox[:gt_num, :]
        sample['gt_bbox'] = pad_bbox
        if 'gt_score' in sample:
            pad_score = np.zeros((num_max, ), dtype=np.float32)
            if gt_num > 0:
                pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
            sample['gt_score'] = pad_score
        if 'difficult' in sample:
            pad_diff = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
            sample['difficult'] = pad_diff
        if 'is_crowd' in sample:
            pad_crowd = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
            sample['is_crowd'] = pad_crowd
        if 'gt_ide' in sample:
            pad_ide = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
            sample['gt_ide'] = pad_ide
        return sample

    def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have the same length."
        h, w = samples[0]['image'].shape[1:3]
        num_max = 0
        for sample in samples:
            num_max = max(num_max, len(sample['gt_bbox']))
        for sample in samples:
            gt_bbox = sample['gt_bbox']
            gt_ide = sample['gt_ide']
            for i, (anchor_hw, downsample_ratio
                    ) in enumerate(zip(self.anchors, self.downsample_ratios)):
                anchor_hw = np.array(
                    anchor_hw, dtype=np.float32) / downsample_ratio
                nA = len(anchor_hw)
                nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
                tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
                tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
                tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)

                gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
                gxy[:, 0] = gxy[:, 0] * nGw
                gxy[:, 1] = gxy[:, 1] * nGh
                gwh[:, 0] = gwh[:, 0] * nGw
                gwh[:, 1] = gwh[:, 1] * nGh
                gxy[:, 0] = np.clip(gxy[:, 0], 0, nGw - 1)
                gxy[:, 1] = np.clip(gxy[:, 1], 0, nGh - 1)
                tboxes = np.concatenate([gxy, gwh], axis=1)

                anchor_mesh = self.generate_anchor(nGh, nGw, anchor_hw)
                anchor_list = np.transpose(anchor_mesh,
                                           (0, 2, 3, 1)).reshape(-1, 4)
                iou_pdist = bbox_iou_np_expand(
                    anchor_list, tboxes, x1y1x2y2=False)

                iou_max = np.max(iou_pdist, axis=1)
                max_gt_index = np.argmax(iou_pdist, axis=1)
                iou_map = iou_max.reshape(nA, nGh, nGw)
                gt_index_map = max_gt_index.reshape(nA, nGh, nGw)

                id_index = iou_map > self.ide_thresh
                fg_index = iou_map > self.fg_thresh
                bg_index = iou_map < self.bg_thresh
                ign_index = (iou_map < self.fg_thresh) * (
                    iou_map > self.bg_thresh)
                tconf[fg_index] = 1
                tconf[bg_index] = 0
                tconf[ign_index] = -1

                gt_index = gt_index_map[fg_index]
                gt_box_list = tboxes[gt_index]
                gt_id_list = gt_ide[gt_index_map[id_index]]

                if np.sum(fg_index) > 0:
                    tid[id_index] = gt_id_list
                    fg_anchor_list = anchor_list.reshape(nA, nGh, nGw,
                                                         4)[fg_index]
                    delta_target = self.encode_delta(gt_box_list,
                                                     fg_anchor_list)
                    tbox[fg_index] = delta_target

                sample['tbox{}'.format(i)] = tbox
                sample['tconf{}'.format(i)] = tconf
                sample['tide{}'.format(i)] = tid
            sample.pop('gt_class')
            sample = self.pad_box(sample, num_max)
        return samples

@register_op
class Gt2JDETargetMax(BaseOperator):
    __shared__ = ['num_classes']
    """
    Generate JDE targets by ground truth data when evaluating
    Args:
        anchors (list): anchors of JDE model
        anchor_masks (list): anchor_masks of JDE model
        downsample_ratios (list): downsample ratios of JDE model
        max_iou_thresh (float): iou thresh for high quality anchor
        num_classes (int): number of classes
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 max_iou_thresh=0.60,
                 num_classes=1):
        super(Gt2JDETargetMax, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.max_iou_thresh = max_iou_thresh
        self.num_classes = num_classes

    def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have the same length."
        h, w = samples[0]['image'].shape[1:3]
        for sample in samples:
            gt_bbox = sample['gt_bbox']
            gt_ide = sample['gt_ide']
            for i, (anchor_hw, downsample_ratio
                    ) in enumerate(zip(self.anchors, self.downsample_ratios)):
                anchor_hw = np.array(
                    anchor_hw, dtype=np.float32) / downsample_ratio
                nA = len(anchor_hw)
                nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
                tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
                tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
                tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)

                gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
                gxy[:, 0] = gxy[:, 0] * nGw
                gxy[:, 1] = gxy[:, 1] * nGh
                gwh[:, 0] = gwh[:, 0] * nGw
                gwh[:, 1] = gwh[:, 1] * nGh
                gi = np.clip(gxy[:, 0], 0, nGw - 1).astype(int)
                gj = np.clip(gxy[:, 1], 0, nGh - 1).astype(int)

                # iou of targets-anchors (using wh only)
                box1 = gwh
                box2 = anchor_hw[:, None, :]
                inter_area = np.minimum(box1, box2).prod(2)
                iou = inter_area / (
                    box1.prod(1) + box2.prod(2) - inter_area + 1e-16)

                # Select best iou_pred and anchor
                iou_best = iou.max(0)  # best anchor [0-2] for each target
                a = np.argmax(iou, axis=0)

                # Select best unique target-anchor combinations
                iou_order = np.argsort(-iou_best)  # best to worst

                # Unique anchor selection
                u = np.stack((gi, gj, a), 0)[:, iou_order]
                _, first_unique = np.unique(u, axis=1, return_index=True)
                mask = iou_order[first_unique]
                # best anchor must share significant commonality (iou) with target
                # TODO: examine arbitrary threshold
                idx = mask[iou_best[mask] > self.max_iou_thresh]

                if len(idx) > 0:
                    a_i, gj_i, gi_i = a[idx], gj[idx], gi[idx]
                    t_box = gt_bbox[idx]
                    t_id = gt_ide[idx]
                    if len(t_box.shape) == 1:
                        t_box = t_box.reshape(1, 4)

                    gxy, gwh = t_box[:, 0:2].copy(), t_box[:, 2:4].copy()
                    gxy[:, 0] = gxy[:, 0] * nGw
                    gxy[:, 1] = gxy[:, 1] * nGh
                    gwh[:, 0] = gwh[:, 0] * nGw
                    gwh[:, 1] = gwh[:, 1] * nGh

                    # XY coordinates
                    tbox[:, :, :, 0:2][a_i, gj_i, gi_i] = gxy - gxy.astype(int)
                    # Width and height in yolo method
                    tbox[:, :, :, 2:4][a_i, gj_i, gi_i] = np.log(
                        gwh / anchor_hw[a_i])
                    tconf[a_i, gj_i, gi_i] = 1
                    tid[a_i, gj_i, gi_i] = t_id

                sample['tbox{}'.format(i)] = tbox
                sample['tconf{}'.format(i)] = tconf
                sample['tide{}'.format(i)] = tid
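
# The wh-only IoU above matches each ground-truth box to the anchor whose
# shape it overlaps best, ignoring position: e.g. (illustrative numbers) a
# 6x8 target against a 4x10 anchor gives inter = min(6, 4) * min(8, 10) = 32
# and iou = 32 / (48 + 40 - 32) ~= 0.57.
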
@register_op
class Gt2FairMOTTarget(Gt2TTFTarget):
    __shared__ = ['num_classes']
    """
    Generate FairMOT targets by ground truth data.
    Differences between Gt2FairMOTTarget and Gt2TTFTarget are:
        1. the Gaussian kernel radius used to generate the heatmap.
        2. the targets needed during training.

    Args:
        num_classes (int): the number of classes.
        down_ratio (int): the down ratio from images to heatmap, 4 by default.
        max_objs (int): the maximum number of ground truth objects in an image, 500 by default.
    """

    def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
        # Skip Gt2TTFTarget.__init__ on purpose: only BaseOperator's
        # initialization is wanted here.
        super(Gt2TTFTarget, self).__init__()
        self.down_ratio = down_ratio
        self.num_classes = num_classes
        self.max_objs = max_objs

    def __call__(self, samples, context=None):
        for b_id, sample in enumerate(samples):
            output_h = sample['image'].shape[1] // self.down_ratio
            output_w = sample['image'].shape[2] // self.down_ratio

            heatmap = np.zeros(
                (self.num_classes, output_h, output_w), dtype='float32')
            bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
            center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
            index = np.zeros((self.max_objs, ), dtype=np.int64)
            index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
            reid = np.zeros((self.max_objs, ), dtype=np.int64)
            bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
            if self.num_classes > 1:
                # each category corresponds to a set of track ids
                cls_tr_ids = np.zeros(
                    (self.num_classes, output_h, output_w), dtype=np.int64)
                cls_id_map = np.full((output_h, output_w), -1, dtype=np.int64)

            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']
            gt_ide = sample['gt_ide']

            for k in range(len(gt_bbox)):
                cls_id = gt_class[k][0]
                bbox = gt_bbox[k]
                ide = gt_ide[k][0]
                bbox[[0, 2]] = bbox[[0, 2]] * output_w
                bbox[[1, 3]] = bbox[[1, 3]] * output_h
                bbox_amodal = copy.deepcopy(bbox)
                bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
                bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
                bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
                bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
                bbox[0] = np.clip(bbox[0], 0, output_w - 1)
                bbox[1] = np.clip(bbox[1], 0, output_h - 1)
                h = bbox[3]
                w = bbox[2]

                bbox_xy = copy.deepcopy(bbox)
                bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
                bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
                bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
                bbox_xy[3] = bbox_xy[1] + bbox_xy[3]

                if h > 0 and w > 0:
                    radius = gaussian_radius((math.ceil(h), math.ceil(w)),
                                             0.7)
                    radius = max(0, int(radius))
                    ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
                    ct_int = ct.astype(np.int32)
                    self.draw_truncate_gaussian(heatmap[cls_id], ct_int,
                                                radius, radius)
                    bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
                        bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]

                    index[k] = ct_int[1] * output_w + ct_int[0]
                    center_offset[k] = ct - ct_int
                    index_mask[k] = 1
                    reid[k] = ide
                    bbox_xys[k] = bbox_xy
                    if self.num_classes > 1:
                        cls_id_map[ct_int[1], ct_int[0]] = cls_id
                        # track ids start from 0
                        cls_tr_ids[cls_id][ct_int[1]][ct_int[0]] = ide - 1

            sample['heatmap'] = heatmap
            sample['index'] = index
            sample['offset'] = center_offset
            sample['size'] = bbox_size
            sample['index_mask'] = index_mask
            sample['reid'] = reid
            if self.num_classes > 1:
                sample['cls_id_map'] = cls_id_map
                sample['cls_tr_ids'] = cls_tr_ids
            sample['bbox_xys'] = bbox_xys

            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)
            sample.pop('gt_score', None)
            sample.pop('gt_ide', None)
        return samples
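
# A sketch of the resulting FairMOT targets (shapes assume a 608x1088 input
# and down_ratio=4, i.e. a 152x272 heatmap; the numbers are illustrative):
#   sample['heatmap']    -> [num_classes, 152, 272], Gaussian peaks at centers
#   sample['index']      -> [max_objs], flattened center cell, y * 272 + x
#   sample['offset']     -> [max_objs, 2], sub-cell center offset in [0, 1)
#   sample['size']       -> [max_objs, 4], distances from center to box sides
#   sample['index_mask'] -> [max_objs], 1 for real objects, 0 for padding
#   sample['reid']       -> [max_objs], track id for the ReID branch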