# mot_operators.py (25 KB)
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. try:
  18. from collections.abc import Sequence
  19. except Exception:
  20. from collections import Sequence
  21. from numbers import Integral
  22. import cv2
  23. import copy
  24. import numpy as np
  25. import random
  26. import math
  27. from .operators import BaseOperator, register_op
  28. from .batch_operators import Gt2TTFTarget
  29. from paddlex.ppdet.modeling.bbox_utils import bbox_iou_np_expand
  30. from paddlex.ppdet.core.workspace import serializable
  31. from paddlex.ppdet.utils.logger import setup_logger
  32. logger = setup_logger(__name__)
  33. __all__ = [
  34. 'RGBReverse', 'LetterBoxResize', 'MOTRandomAffine', 'Gt2JDETargetThres',
  35. 'Gt2JDETargetMax', 'Gt2FairMOTTarget'
  36. ]
  37. @register_op
  38. class RGBReverse(BaseOperator):
  39. """RGB to BGR, or BGR to RGB, sensitive to MOTRandomAffine
  40. """
  41. def __init__(self):
  42. super(RGBReverse, self).__init__()
  43. def apply(self, sample, context=None):
  44. im = sample['image']
  45. sample['image'] = np.ascontiguousarray(im[:, :, ::-1])
  46. return sample
  47. @register_op
  48. class LetterBoxResize(BaseOperator):
  49. def __init__(self, target_size):
  50. """
  51. Resize image to target size, convert normalized xywh to pixel xyxy
  52. format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
  53. Args:
  54. target_size (int|list): image target size.
  55. """
  56. super(LetterBoxResize, self).__init__()
  57. if not isinstance(target_size, (Integral, Sequence)):
  58. raise TypeError(
  59. "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
  60. format(type(target_size)))
  61. if isinstance(target_size, Integral):
  62. target_size = [target_size, target_size]
  63. self.target_size = target_size
  64. def apply_image(self, img, height, width, color=(127.5, 127.5, 127.5)):
  65. # letterbox: resize a rectangular image to a padded rectangular
  66. shape = img.shape[:2] # [height, width]
  67. ratio_h = float(height) / shape[0]
  68. ratio_w = float(width) / shape[1]
  69. ratio = min(ratio_h, ratio_w)
  70. new_shape = (round(shape[1] * ratio),
  71. round(shape[0] * ratio)) # [width, height]
  72. padw = (width - new_shape[0]) / 2
  73. padh = (height - new_shape[1]) / 2
  74. top, bottom = round(padh - 0.1), round(padh + 0.1)
  75. left, right = round(padw - 0.1), round(padw + 0.1)
  76. img = cv2.resize(
  77. img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
  78. img = cv2.copyMakeBorder(
  79. img, top, bottom, left, right, cv2.BORDER_CONSTANT,
  80. value=color) # padded rectangular
  81. return img, ratio, padw, padh
  82. def apply_bbox(self, bbox0, h, w, ratio, padw, padh):
  83. bboxes = bbox0.copy()
  84. bboxes[:, 0] = ratio * w * (bbox0[:, 0] - bbox0[:, 2] / 2) + padw
  85. bboxes[:, 1] = ratio * h * (bbox0[:, 1] - bbox0[:, 3] / 2) + padh
  86. bboxes[:, 2] = ratio * w * (bbox0[:, 0] + bbox0[:, 2] / 2) + padw
  87. bboxes[:, 3] = ratio * h * (bbox0[:, 1] + bbox0[:, 3] / 2) + padh
  88. return bboxes
  89. def apply(self, sample, context=None):
  90. """ Resize the image numpy.
  91. """
  92. im = sample['image']
  93. h, w = sample['im_shape']
  94. if not isinstance(im, np.ndarray):
  95. raise TypeError("{}: image type is not numpy.".format(self))
  96. if len(im.shape) != 3:
  97. raise ImageError('{}: image is not 3-dimensional.'.format(self))
  98. # apply image
  99. height, width = self.target_size
  100. img, ratio, padw, padh = self.apply_image(
  101. im, height=height, width=width)
  102. sample['image'] = img
  103. new_shape = (round(h * ratio), round(w * ratio))
  104. sample['im_shape'] = np.asarray(new_shape, dtype=np.float32)
  105. sample['scale_factor'] = np.asarray([ratio, ratio], dtype=np.float32)
  106. # apply bbox
  107. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  108. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], h, w, ratio,
  109. padw, padh)
  110. return sample
  111. @register_op
  112. class MOTRandomAffine(BaseOperator):
  113. """
  114. Affine transform to image and coords to achieve the rotate, scale and
  115. shift effect for training image.
  116. Args:
  117. degrees (list[2]): the rotate range to apply, transform range is [min, max]
  118. translate (list[2]): the translate range to apply, ransform range is [min, max]
  119. scale (list[2]): the scale range to apply, transform range is [min, max]
  120. shear (list[2]): the shear range to apply, transform range is [min, max]
  121. borderValue (list[3]): value used in case of a constant border when appling
  122. the perspective transformation
  123. reject_outside (bool): reject warped bounding bboxes outside of image
  124. Returns:
  125. records(dict): contain the image and coords after tranformed
  126. """
  127. def __init__(self,
  128. degrees=(-5, 5),
  129. translate=(0.10, 0.10),
  130. scale=(0.50, 1.20),
  131. shear=(-2, 2),
  132. borderValue=(127.5, 127.5, 127.5),
  133. reject_outside=True):
  134. super(MOTRandomAffine, self).__init__()
  135. self.degrees = degrees
  136. self.translate = translate
  137. self.scale = scale
  138. self.shear = shear
  139. self.borderValue = borderValue
  140. self.reject_outside = reject_outside
  141. def apply(self, sample, context=None):
  142. # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
  143. border = 0 # width of added border (optional)
  144. img = sample['image']
  145. height, width = img.shape[0], img.shape[1]
  146. # Rotation and Scale
  147. R = np.eye(3)
  148. a = random.random() * (self.degrees[1] - self.degrees[0]
  149. ) + self.degrees[0]
  150. s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
  151. R[:2] = cv2.getRotationMatrix2D(
  152. angle=a, center=(width / 2, height / 2), scale=s)
  153. # Translation
  154. T = np.eye(3)
  155. T[0, 2] = (
  156. random.random() * 2 - 1
  157. ) * self.translate[0] * height + border # x translation (pixels)
  158. T[1, 2] = (
  159. random.random() * 2 - 1
  160. ) * self.translate[1] * width + border # y translation (pixels)
  161. # Shear
  162. S = np.eye(3)
  163. S[0, 1] = math.tan((random.random() *
  164. (self.shear[1] - self.shear[0]) + self.shear[0]) *
  165. math.pi / 180) # x shear (deg)
  166. S[1, 0] = math.tan((random.random() *
  167. (self.shear[1] - self.shear[0]) + self.shear[0]) *
  168. math.pi / 180) # y shear (deg)
  169. M = S @T @R # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
  170. imw = cv2.warpPerspective(
  171. img,
  172. M,
  173. dsize=(width, height),
  174. flags=cv2.INTER_LINEAR,
  175. borderValue=self.borderValue) # BGR order borderValue
  176. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  177. targets = sample['gt_bbox']
  178. n = targets.shape[0]
  179. points = targets.copy()
  180. area0 = (points[:, 2] - points[:, 0]) * (
  181. points[:, 3] - points[:, 1])
  182. # warp points
  183. xy = np.ones((n * 4, 3))
  184. xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
  185. n * 4, 2) # x1y1, x2y2, x1y2, x2y1
  186. xy = (xy @M.T)[:, :2].reshape(n, 8)
  187. # create new boxes
  188. x = xy[:, [0, 2, 4, 6]]
  189. y = xy[:, [1, 3, 5, 7]]
  190. xy = np.concatenate(
  191. (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
  192. # apply angle-based reduction
  193. radians = a * math.pi / 180
  194. reduction = max(abs(math.sin(radians)),
  195. abs(math.cos(radians)))**0.5
  196. x = (xy[:, 2] + xy[:, 0]) / 2
  197. y = (xy[:, 3] + xy[:, 1]) / 2
  198. w = (xy[:, 2] - xy[:, 0]) * reduction
  199. h = (xy[:, 3] - xy[:, 1]) * reduction
  200. xy = np.concatenate(
  201. (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
  202. # reject warped points outside of image
  203. if self.reject_outside:
  204. np.clip(xy[:, 0], 0, width, out=xy[:, 0])
  205. np.clip(xy[:, 2], 0, width, out=xy[:, 2])
  206. np.clip(xy[:, 1], 0, height, out=xy[:, 1])
  207. np.clip(xy[:, 3], 0, height, out=xy[:, 3])
  208. w = xy[:, 2] - xy[:, 0]
  209. h = xy[:, 3] - xy[:, 1]
  210. area = w * h
  211. ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
  212. i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
  213. if sum(i) > 0:
  214. sample['gt_bbox'] = xy[i].astype(sample['gt_bbox'].dtype)
  215. sample['gt_class'] = sample['gt_class'][i]
  216. if 'difficult' in sample:
  217. sample['difficult'] = sample['difficult'][i]
  218. if 'gt_ide' in sample:
  219. sample['gt_ide'] = sample['gt_ide'][i]
  220. if 'is_crowd' in sample:
  221. sample['is_crowd'] = sample['is_crowd'][i]
  222. sample['image'] = imw
  223. return sample
  224. else:
  225. return sample
  226. @register_op
  227. class Gt2JDETargetThres(BaseOperator):
  228. __shared__ = ['num_classes']
  229. """
  230. Generate JDE targets by groud truth data when training
  231. Args:
  232. anchors (list): anchors of JDE model
  233. anchor_masks (list): anchor_masks of JDE model
  234. downsample_ratios (list): downsample ratios of JDE model
  235. ide_thresh (float): thresh of identity, higher is groud truth
  236. fg_thresh (float): thresh of foreground, higher is foreground
  237. bg_thresh (float): thresh of background, lower is background
  238. num_classes (int): number of classes
  239. """
  240. def __init__(self,
  241. anchors,
  242. anchor_masks,
  243. downsample_ratios,
  244. ide_thresh=0.5,
  245. fg_thresh=0.5,
  246. bg_thresh=0.4,
  247. num_classes=1):
  248. super(Gt2JDETargetThres, self).__init__()
  249. self.anchors = anchors
  250. self.anchor_masks = anchor_masks
  251. self.downsample_ratios = downsample_ratios
  252. self.ide_thresh = ide_thresh
  253. self.fg_thresh = fg_thresh
  254. self.bg_thresh = bg_thresh
  255. self.num_classes = num_classes
  256. def generate_anchor(self, nGh, nGw, anchor_hw):
  257. nA = len(anchor_hw)
  258. yy, xx = np.meshgrid(np.arange(nGh), np.arange(nGw))
  259. mesh = np.stack([xx.T, yy.T], axis=0) # [2, nGh, nGw]
  260. mesh = np.repeat(mesh[None, :], nA, axis=0) # [nA, 2, nGh, nGw]
  261. anchor_offset_mesh = anchor_hw[:, :, None][:, :, :, None]
  262. anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGh, axis=-2)
  263. anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGw, axis=-1)
  264. anchor_mesh = np.concatenate(
  265. [mesh, anchor_offset_mesh], axis=1) # [nA, 4, nGh, nGw]
  266. return anchor_mesh
  267. def encode_delta(self, gt_box_list, fg_anchor_list):
  268. px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:,1], \
  269. fg_anchor_list[:, 2], fg_anchor_list[:,3]
  270. gx, gy, gw, gh = gt_box_list[:, 0], gt_box_list[:, 1], \
  271. gt_box_list[:, 2], gt_box_list[:, 3]
  272. dx = (gx - px) / pw
  273. dy = (gy - py) / ph
  274. dw = np.log(gw / pw)
  275. dh = np.log(gh / ph)
  276. return np.stack([dx, dy, dw, dh], axis=1)
  277. def pad_box(self, sample, num_max):
  278. assert 'gt_bbox' in sample
  279. bbox = sample['gt_bbox']
  280. gt_num = len(bbox)
  281. pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
  282. if gt_num > 0:
  283. pad_bbox[:gt_num, :] = bbox[:gt_num, :]
  284. sample['gt_bbox'] = pad_bbox
  285. if 'gt_score' in sample:
  286. pad_score = np.zeros((num_max, ), dtype=np.float32)
  287. if gt_num > 0:
  288. pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
  289. sample['gt_score'] = pad_score
  290. if 'difficult' in sample:
  291. pad_diff = np.zeros((num_max, ), dtype=np.int32)
  292. if gt_num > 0:
  293. pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
  294. sample['difficult'] = pad_diff
  295. if 'is_crowd' in sample:
  296. pad_crowd = np.zeros((num_max, ), dtype=np.int32)
  297. if gt_num > 0:
  298. pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
  299. sample['is_crowd'] = pad_crowd
  300. if 'gt_ide' in sample:
  301. pad_ide = np.zeros((num_max, ), dtype=np.int32)
  302. if gt_num > 0:
  303. pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
  304. sample['gt_ide'] = pad_ide
  305. return sample
  306. def __call__(self, samples, context=None):
  307. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  308. "anchor_masks', and 'downsample_ratios' should have same length."
  309. h, w = samples[0]['image'].shape[1:3]
  310. num_max = 0
  311. for sample in samples:
  312. num_max = max(num_max, len(sample['gt_bbox']))
  313. for sample in samples:
  314. gt_bbox = sample['gt_bbox']
  315. gt_ide = sample['gt_ide']
  316. for i, (anchor_hw, downsample_ratio
  317. ) in enumerate(zip(self.anchors, self.downsample_ratios)):
  318. anchor_hw = np.array(
  319. anchor_hw, dtype=np.float32) / downsample_ratio
  320. nA = len(anchor_hw)
  321. nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
  322. tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
  323. tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
  324. tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)
  325. gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
  326. gxy[:, 0] = gxy[:, 0] * nGw
  327. gxy[:, 1] = gxy[:, 1] * nGh
  328. gwh[:, 0] = gwh[:, 0] * nGw
  329. gwh[:, 1] = gwh[:, 1] * nGh
  330. gxy[:, 0] = np.clip(gxy[:, 0], 0, nGw - 1)
  331. gxy[:, 1] = np.clip(gxy[:, 1], 0, nGh - 1)
  332. tboxes = np.concatenate([gxy, gwh], axis=1)
  333. anchor_mesh = self.generate_anchor(nGh, nGw, anchor_hw)
  334. anchor_list = np.transpose(anchor_mesh,
  335. (0, 2, 3, 1)).reshape(-1, 4)
  336. iou_pdist = bbox_iou_np_expand(
  337. anchor_list, tboxes, x1y1x2y2=False)
  338. iou_max = np.max(iou_pdist, axis=1)
  339. max_gt_index = np.argmax(iou_pdist, axis=1)
  340. iou_map = iou_max.reshape(nA, nGh, nGw)
  341. gt_index_map = max_gt_index.reshape(nA, nGh, nGw)
  342. id_index = iou_map > self.ide_thresh
  343. fg_index = iou_map > self.fg_thresh
  344. bg_index = iou_map < self.bg_thresh
  345. ign_index = (iou_map < self.fg_thresh) * (
  346. iou_map > self.bg_thresh)
  347. tconf[fg_index] = 1
  348. tconf[bg_index] = 0
  349. tconf[ign_index] = -1
  350. gt_index = gt_index_map[fg_index]
  351. gt_box_list = tboxes[gt_index]
  352. gt_id_list = gt_ide[gt_index_map[id_index]]
  353. if np.sum(fg_index) > 0:
  354. tid[id_index] = gt_id_list
  355. fg_anchor_list = anchor_list.reshape(nA, nGh, nGw,
  356. 4)[fg_index]
  357. delta_target = self.encode_delta(gt_box_list,
  358. fg_anchor_list)
  359. tbox[fg_index] = delta_target
  360. sample['tbox{}'.format(i)] = tbox
  361. sample['tconf{}'.format(i)] = tconf
  362. sample['tide{}'.format(i)] = tid
  363. sample.pop('gt_class')
  364. sample = self.pad_box(sample, num_max)
  365. return samples
  366. @register_op
  367. class Gt2JDETargetMax(BaseOperator):
  368. __shared__ = ['num_classes']
  369. """
  370. Generate JDE targets by groud truth data when evaluating
  371. Args:
  372. anchors (list): anchors of JDE model
  373. anchor_masks (list): anchor_masks of JDE model
  374. downsample_ratios (list): downsample ratios of JDE model
  375. max_iou_thresh (float): iou thresh for high quality anchor
  376. num_classes (int): number of classes
  377. """
  378. def __init__(self,
  379. anchors,
  380. anchor_masks,
  381. downsample_ratios,
  382. max_iou_thresh=0.60,
  383. num_classes=1):
  384. super(Gt2JDETargetMax, self).__init__()
  385. self.anchors = anchors
  386. self.anchor_masks = anchor_masks
  387. self.downsample_ratios = downsample_ratios
  388. self.max_iou_thresh = max_iou_thresh
  389. self.num_classes = num_classes
  390. def __call__(self, samples, context=None):
  391. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  392. "anchor_masks', and 'downsample_ratios' should have same length."
  393. h, w = samples[0]['image'].shape[1:3]
  394. for sample in samples:
  395. gt_bbox = sample['gt_bbox']
  396. gt_ide = sample['gt_ide']
  397. for i, (anchor_hw, downsample_ratio
  398. ) in enumerate(zip(self.anchors, self.downsample_ratios)):
  399. anchor_hw = np.array(
  400. anchor_hw, dtype=np.float32) / downsample_ratio
  401. nA = len(anchor_hw)
  402. nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
  403. tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
  404. tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
  405. tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)
  406. gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
  407. gxy[:, 0] = gxy[:, 0] * nGw
  408. gxy[:, 1] = gxy[:, 1] * nGh
  409. gwh[:, 0] = gwh[:, 0] * nGw
  410. gwh[:, 1] = gwh[:, 1] * nGh
  411. gi = np.clip(gxy[:, 0], 0, nGw - 1).astype(int)
  412. gj = np.clip(gxy[:, 1], 0, nGh - 1).astype(int)
  413. # iou of targets-anchors (using wh only)
  414. box1 = gwh
  415. box2 = anchor_hw[:, None, :]
  416. inter_area = np.minimum(box1, box2).prod(2)
  417. iou = inter_area / (
  418. box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
  419. # Select best iou_pred and anchor
  420. iou_best = iou.max(0) # best anchor [0-2] for each target
  421. a = np.argmax(iou, axis=0)
  422. # Select best unique target-anchor combinations
  423. iou_order = np.argsort(-iou_best) # best to worst
  424. # Unique anchor selection
  425. u = np.stack((gi, gj, a), 0)[:, iou_order]
  426. _, first_unique = np.unique(u, axis=1, return_index=True)
  427. mask = iou_order[first_unique]
  428. # best anchor must share significant commonality (iou) with target
  429. # TODO: examine arbitrary threshold
  430. idx = mask[iou_best[mask] > self.max_iou_thresh]
  431. if len(idx) > 0:
  432. a_i, gj_i, gi_i = a[idx], gj[idx], gi[idx]
  433. t_box = gt_bbox[idx]
  434. t_id = gt_ide[idx]
  435. if len(t_box.shape) == 1:
  436. t_box = t_box.reshape(1, 4)
  437. gxy, gwh = t_box[:, 0:2].copy(), t_box[:, 2:4].copy()
  438. gxy[:, 0] = gxy[:, 0] * nGw
  439. gxy[:, 1] = gxy[:, 1] * nGh
  440. gwh[:, 0] = gwh[:, 0] * nGw
  441. gwh[:, 1] = gwh[:, 1] * nGh
  442. # XY coordinates
  443. tbox[:, :, :, 0:2][a_i, gj_i, gi_i] = gxy - gxy.astype(int)
  444. # Width and height in yolo method
  445. tbox[:, :, :, 2:4][a_i, gj_i, gi_i] = np.log(
  446. gwh / anchor_hw[a_i])
  447. tconf[a_i, gj_i, gi_i] = 1
  448. tid[a_i, gj_i, gi_i] = t_id
  449. sample['tbox{}'.format(i)] = tbox
  450. sample['tconf{}'.format(i)] = tconf
  451. sample['tide{}'.format(i)] = tid
  452. class Gt2FairMOTTarget(Gt2TTFTarget):
  453. __shared__ = ['num_classes']
  454. """
  455. Generate FairMOT targets by ground truth data.
  456. Difference between Gt2FairMOTTarget and Gt2TTFTarget are:
  457. 1. the gaussian kernal radius to generate a heatmap.
  458. 2. the targets needed during traing.
  459. Args:
  460. num_classes(int): the number of classes.
  461. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  462. max_objs(int): the maximum number of ground truth objects in a image, 500 by default.
  463. """
  464. def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
  465. super(Gt2TTFTarget, self).__init__()
  466. self.down_ratio = down_ratio
  467. self.num_classes = num_classes
  468. self.max_objs = max_objs
  469. def __call__(self, samples, context=None):
  470. for b_id, sample in enumerate(samples):
  471. output_h = sample['image'].shape[1] // self.down_ratio
  472. output_w = sample['image'].shape[2] // self.down_ratio
  473. heatmap = np.zeros(
  474. (self.num_classes, output_h, output_w), dtype='float32')
  475. bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
  476. center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
  477. index = np.zeros((self.max_objs, ), dtype=np.int64)
  478. index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
  479. reid = np.zeros((self.max_objs, ), dtype=np.int64)
  480. bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
  481. gt_bbox = sample['gt_bbox']
  482. gt_class = sample['gt_class']
  483. gt_ide = sample['gt_ide']
  484. for k in range(len(gt_bbox)):
  485. cls_id = gt_class[k][0]
  486. bbox = gt_bbox[k]
  487. ide = gt_ide[k][0]
  488. bbox[[0, 2]] = bbox[[0, 2]] * output_w
  489. bbox[[1, 3]] = bbox[[1, 3]] * output_h
  490. bbox_amodal = copy.deepcopy(bbox)
  491. bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
  492. bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
  493. bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
  494. bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
  495. bbox[0] = np.clip(bbox[0], 0, output_w - 1)
  496. bbox[1] = np.clip(bbox[1], 0, output_h - 1)
  497. h = bbox[3]
  498. w = bbox[2]
  499. bbox_xy = copy.deepcopy(bbox)
  500. bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
  501. bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
  502. bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
  503. bbox_xy[3] = bbox_xy[1] + bbox_xy[3]
  504. if h > 0 and w > 0:
  505. radius = self.gaussian_radius((math.ceil(h), math.ceil(w)))
  506. radius = max(0, int(radius))
  507. ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
  508. ct_int = ct.astype(np.int32)
  509. self.draw_truncate_gaussian(heatmap[cls_id], ct_int,
  510. radius, radius)
  511. bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
  512. bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]
  513. index[k] = ct_int[1] * output_w + ct_int[0]
  514. center_offset[k] = ct - ct_int
  515. index_mask[k] = 1
  516. reid[k] = ide
  517. bbox_xys[k] = bbox_xy
  518. sample['heatmap'] = heatmap
  519. sample['index'] = index
  520. sample['offset'] = center_offset
  521. sample['size'] = bbox_size
  522. sample['index_mask'] = index_mask
  523. sample['reid'] = reid
  524. sample['bbox_xys'] = bbox_xys
  525. sample.pop('is_crowd', None)
  526. sample.pop('difficult', None)
  527. sample.pop('gt_class', None)
  528. sample.pop('gt_bbox', None)
  529. sample.pop('gt_score', None)
  530. sample.pop('gt_ide', None)
  531. return samples
  532. def gaussian_radius(self, det_size, min_overlap=0.7):
  533. height, width = det_size
  534. a1 = 1
  535. b1 = (height + width)
  536. c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
  537. sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
  538. r1 = (b1 + sq1) / 2
  539. a2 = 4
  540. b2 = 2 * (height + width)
  541. c2 = (1 - min_overlap) * width * height
  542. sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
  543. r2 = (b2 + sq2) / 2
  544. a3 = 4 * min_overlap
  545. b3 = -2 * min_overlap * (height + width)
  546. c3 = (min_overlap - 1) * width * height
  547. sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
  548. r3 = (b3 + sq3) / 2
  549. return min(r1, r2, r3)