batch_operators.py 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import typing
  18. try:
  19. from collections.abc import Sequence
  20. except Exception:
  21. from collections import Sequence
  22. import cv2
  23. import math
  24. import numpy as np
  25. from .operators import register_op, BaseOperator, Resize
  26. from .op_helper import jaccard_overlap, gaussian2D, gaussian_radius, draw_umich_gaussian
  27. from .atss_assigner import ATSSAssigner
  28. from scipy import ndimage
  29. from paddlex.ppdet.modeling import bbox_utils
  30. from paddlex.ppdet.utils.logger import setup_logger
  31. from paddlex.ppdet.modeling.keypoint_utils import get_affine_transform, affine_transform
  32. logger = setup_logger(__name__)
  33. __all__ = [
  34. 'PadBatch',
  35. 'BatchRandomResize',
  36. 'Gt2YoloTarget',
  37. 'Gt2FCOSTarget',
  38. 'Gt2TTFTarget',
  39. 'Gt2Solov2Target',
  40. 'Gt2SparseRCNNTarget',
  41. 'PadMaskBatch',
  42. 'Gt2GFLTarget',
  43. 'Gt2CenterNetTarget',
  44. 'PadGT',
  45. ]
  46. @register_op
  47. class PadBatch(BaseOperator):
  48. """
  49. Pad a batch of samples so they can be divisible by a stride.
  50. The layout of each image should be 'CHW'.
  51. Args:
  52. pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
  53. height and width is divisible by `pad_to_stride`.
  54. """
  55. def __init__(self, pad_to_stride=0):
  56. super(PadBatch, self).__init__()
  57. self.pad_to_stride = pad_to_stride
  58. def __call__(self, samples, context=None):
  59. """
  60. Args:
  61. samples (list): a batch of sample, each is dict.
  62. """
  63. coarsest_stride = self.pad_to_stride
  64. # multi scale input is nested list
  65. if isinstance(samples,
  66. typing.Sequence) and len(samples) > 0 and isinstance(
  67. samples[0], typing.Sequence):
  68. inner_samples = samples[0]
  69. else:
  70. inner_samples = samples
  71. max_shape = np.array(
  72. [data['image'].shape for data in inner_samples]).max(axis=0)
  73. if coarsest_stride > 0:
  74. max_shape[1] = int(
  75. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  76. max_shape[2] = int(
  77. np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
  78. for data in inner_samples:
  79. im = data['image']
  80. im_c, im_h, im_w = im.shape[:]
  81. padding_im = np.zeros(
  82. (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
  83. padding_im[:, :im_h, :im_w] = im
  84. data['image'] = padding_im
  85. if 'semantic' in data and data['semantic'] is not None:
  86. semantic = data['semantic']
  87. padding_sem = np.zeros(
  88. (1, max_shape[1], max_shape[2]), dtype=np.float32)
  89. padding_sem[:, :im_h, :im_w] = semantic
  90. data['semantic'] = padding_sem
  91. if 'gt_segm' in data and data['gt_segm'] is not None:
  92. gt_segm = data['gt_segm']
  93. padding_segm = np.zeros(
  94. (gt_segm.shape[0], max_shape[1], max_shape[2]),
  95. dtype=np.uint8)
  96. padding_segm[:, :im_h, :im_w] = gt_segm
  97. data['gt_segm'] = padding_segm
  98. if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
  99. # ploy to rbox
  100. polys = data['gt_rbox2poly']
  101. rbox = bbox_utils.poly2rbox(polys)
  102. data['gt_rbox'] = rbox
  103. return samples
  104. @register_op
  105. class BatchRandomResize(BaseOperator):
  106. """
  107. Resize image to target size randomly. random target_size and interpolation method
  108. Args:
  109. target_size (int, list, tuple): image target size, if random size is True, must be list or tuple
  110. keep_ratio (bool): whether keep_raio or not, default true
  111. interp (int): the interpolation method
  112. random_size (bool): whether random select target size of image
  113. random_interp (bool): whether random select interpolation method
  114. """
  115. def __init__(self,
  116. target_size,
  117. keep_ratio,
  118. interp=cv2.INTER_NEAREST,
  119. random_size=True,
  120. random_interp=False):
  121. super(BatchRandomResize, self).__init__()
  122. self.keep_ratio = keep_ratio
  123. self.interps = [
  124. cv2.INTER_NEAREST,
  125. cv2.INTER_LINEAR,
  126. cv2.INTER_AREA,
  127. cv2.INTER_CUBIC,
  128. cv2.INTER_LANCZOS4,
  129. ]
  130. self.interp = interp
  131. assert isinstance(target_size, (
  132. int, Sequence)), "target_size must be int, list or tuple"
  133. if random_size and not isinstance(target_size, list):
  134. raise TypeError(
  135. "Type of target_size is invalid when random_size is True. Must be List, now is {}".
  136. format(type(target_size)))
  137. self.target_size = target_size
  138. self.random_size = random_size
  139. self.random_interp = random_interp
  140. def __call__(self, samples, context=None):
  141. if self.random_size:
  142. index = np.random.choice(len(self.target_size))
  143. target_size = self.target_size[index]
  144. else:
  145. target_size = self.target_size
  146. if self.random_interp:
  147. interp = np.random.choice(self.interps)
  148. else:
  149. interp = self.interp
  150. resizer = Resize(
  151. target_size, keep_ratio=self.keep_ratio, interp=interp)
  152. return resizer(samples, context=context)
  153. @register_op
  154. class Gt2YoloTarget(BaseOperator):
  155. """
  156. Generate YOLOv3 targets by groud truth data, this operator is only used in
  157. fine grained YOLOv3 loss mode
  158. """
  159. def __init__(self,
  160. anchors,
  161. anchor_masks,
  162. downsample_ratios,
  163. num_classes=80,
  164. iou_thresh=1.):
  165. super(Gt2YoloTarget, self).__init__()
  166. self.anchors = anchors
  167. self.anchor_masks = anchor_masks
  168. self.downsample_ratios = downsample_ratios
  169. self.num_classes = num_classes
  170. self.iou_thresh = iou_thresh
  171. def __call__(self, samples, context=None):
  172. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  173. "anchor_masks', and 'downsample_ratios' should have same length."
  174. h, w = samples[0]['image'].shape[1:3]
  175. an_hw = np.array(self.anchors) / np.array([[w, h]])
  176. for sample in samples:
  177. gt_bbox = sample['gt_bbox']
  178. gt_class = sample['gt_class']
  179. if 'gt_score' not in sample:
  180. sample['gt_score'] = np.ones(
  181. (gt_bbox.shape[0], 1), dtype=np.float32)
  182. gt_score = sample['gt_score']
  183. for i, (
  184. mask, downsample_ratio
  185. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  186. grid_h = int(h / downsample_ratio)
  187. grid_w = int(w / downsample_ratio)
  188. target = np.zeros(
  189. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  190. dtype=np.float32)
  191. for b in range(gt_bbox.shape[0]):
  192. gx, gy, gw, gh = gt_bbox[b, :]
  193. cls = gt_class[b]
  194. score = gt_score[b]
  195. if gw <= 0. or gh <= 0. or score <= 0.:
  196. continue
  197. # find best match anchor index
  198. best_iou = 0.
  199. best_idx = -1
  200. for an_idx in range(an_hw.shape[0]):
  201. iou = jaccard_overlap(
  202. [0., 0., gw, gh],
  203. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  204. if iou > best_iou:
  205. best_iou = iou
  206. best_idx = an_idx
  207. gi = int(gx * grid_w)
  208. gj = int(gy * grid_h)
  209. # gtbox should be regresed in this layes if best match
  210. # anchor index in anchor mask of this layer
  211. if best_idx in mask:
  212. best_n = mask.index(best_idx)
  213. # x, y, w, h, scale
  214. target[best_n, 0, gj, gi] = gx * grid_w - gi
  215. target[best_n, 1, gj, gi] = gy * grid_h - gj
  216. target[best_n, 2, gj, gi] = np.log(
  217. gw * w / self.anchors[best_idx][0])
  218. target[best_n, 3, gj, gi] = np.log(
  219. gh * h / self.anchors[best_idx][1])
  220. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  221. # objectness record gt_score
  222. target[best_n, 5, gj, gi] = score
  223. # classification
  224. target[best_n, 6 + cls, gj, gi] = 1.
  225. # For non-matched anchors, calculate the target if the iou
  226. # between anchor and gt is larger than iou_thresh
  227. if self.iou_thresh < 1:
  228. for idx, mask_i in enumerate(mask):
  229. if mask_i == best_idx: continue
  230. iou = jaccard_overlap(
  231. [0., 0., gw, gh],
  232. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  233. if iou > self.iou_thresh and target[idx, 5, gj,
  234. gi] == 0.:
  235. # x, y, w, h, scale
  236. target[idx, 0, gj, gi] = gx * grid_w - gi
  237. target[idx, 1, gj, gi] = gy * grid_h - gj
  238. target[idx, 2, gj, gi] = np.log(
  239. gw * w / self.anchors[mask_i][0])
  240. target[idx, 3, gj, gi] = np.log(
  241. gh * h / self.anchors[mask_i][1])
  242. target[idx, 4, gj, gi] = 2.0 - gw * gh
  243. # objectness record gt_score
  244. target[idx, 5, gj, gi] = score
  245. # classification
  246. target[idx, 6 + cls, gj, gi] = 1.
  247. sample['target{}'.format(i)] = target
  248. # remove useless gt_class and gt_score after target calculated
  249. sample.pop('gt_class')
  250. sample.pop('gt_score')
  251. return samples
  252. @register_op
  253. class Gt2FCOSTarget(BaseOperator):
  254. """
  255. Generate FCOS targets by groud truth data
  256. """
  257. def __init__(self,
  258. object_sizes_boundary,
  259. center_sampling_radius,
  260. downsample_ratios,
  261. norm_reg_targets=False):
  262. super(Gt2FCOSTarget, self).__init__()
  263. self.center_sampling_radius = center_sampling_radius
  264. self.downsample_ratios = downsample_ratios
  265. self.INF = np.inf
  266. self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF]
  267. object_sizes_of_interest = []
  268. for i in range(len(self.object_sizes_boundary) - 1):
  269. object_sizes_of_interest.append([
  270. self.object_sizes_boundary[i],
  271. self.object_sizes_boundary[i + 1]
  272. ])
  273. self.object_sizes_of_interest = object_sizes_of_interest
  274. self.norm_reg_targets = norm_reg_targets
  275. def _compute_points(self, w, h):
  276. """
  277. compute the corresponding points in each feature map
  278. :param h: image height
  279. :param w: image width
  280. :return: points from all feature map
  281. """
  282. locations = []
  283. for stride in self.downsample_ratios:
  284. shift_x = np.arange(0, w, stride).astype(np.float32)
  285. shift_y = np.arange(0, h, stride).astype(np.float32)
  286. shift_x, shift_y = np.meshgrid(shift_x, shift_y)
  287. shift_x = shift_x.flatten()
  288. shift_y = shift_y.flatten()
  289. location = np.stack([shift_x, shift_y], axis=1) + stride // 2
  290. locations.append(location)
  291. num_points_each_level = [len(location) for location in locations]
  292. locations = np.concatenate(locations, axis=0)
  293. return locations, num_points_each_level
  294. def _convert_xywh2xyxy(self, gt_bbox, w, h):
  295. """
  296. convert the bounding box from style xywh to xyxy
  297. :param gt_bbox: bounding boxes normalized into [0, 1]
  298. :param w: image width
  299. :param h: image height
  300. :return: bounding boxes in xyxy style
  301. """
  302. bboxes = gt_bbox.copy()
  303. bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w
  304. bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h
  305. bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
  306. bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
  307. return bboxes
  308. def _check_inside_boxes_limited(self, gt_bbox, xs, ys,
  309. num_points_each_level):
  310. """
  311. check if points is within the clipped boxes
  312. :param gt_bbox: bounding boxes
  313. :param xs: horizontal coordinate of points
  314. :param ys: vertical coordinate of points
  315. :return: the mask of points is within gt_box or not
  316. """
  317. bboxes = np.reshape(
  318. gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]])
  319. bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1])
  320. ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2
  321. ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2
  322. beg = 0
  323. clipped_box = bboxes.copy()
  324. for lvl, stride in enumerate(self.downsample_ratios):
  325. end = beg + num_points_each_level[lvl]
  326. stride_exp = self.center_sampling_radius * stride
  327. clipped_box[beg:end, :, 0] = np.maximum(
  328. bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp)
  329. clipped_box[beg:end, :, 1] = np.maximum(
  330. bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp)
  331. clipped_box[beg:end, :, 2] = np.minimum(
  332. bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp)
  333. clipped_box[beg:end, :, 3] = np.minimum(
  334. bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp)
  335. beg = end
  336. l_res = xs - clipped_box[:, :, 0]
  337. r_res = clipped_box[:, :, 2] - xs
  338. t_res = ys - clipped_box[:, :, 1]
  339. b_res = clipped_box[:, :, 3] - ys
  340. clipped_box_reg_targets = np.stack(
  341. [l_res, t_res, r_res, b_res], axis=2)
  342. inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
  343. return inside_gt_box
  344. def __call__(self, samples, context=None):
  345. assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
  346. "object_sizes_of_interest', and 'downsample_ratios' should have same length."
  347. for sample in samples:
  348. im = sample['image']
  349. bboxes = sample['gt_bbox']
  350. gt_class = sample['gt_class']
  351. # calculate the locations
  352. h, w = im.shape[1:3]
  353. points, num_points_each_level = self._compute_points(w, h)
  354. object_scale_exp = []
  355. for i, num_pts in enumerate(num_points_each_level):
  356. object_scale_exp.append(
  357. np.tile(
  358. np.array([self.object_sizes_of_interest[i]]),
  359. reps=[num_pts, 1]))
  360. object_scale_exp = np.concatenate(object_scale_exp, axis=0)
  361. gt_area = (bboxes[:, 2] - bboxes[:, 0]) * (
  362. bboxes[:, 3] - bboxes[:, 1])
  363. xs, ys = points[:, 0], points[:, 1]
  364. xs = np.reshape(xs, newshape=[xs.shape[0], 1])
  365. xs = np.tile(xs, reps=[1, bboxes.shape[0]])
  366. ys = np.reshape(ys, newshape=[ys.shape[0], 1])
  367. ys = np.tile(ys, reps=[1, bboxes.shape[0]])
  368. l_res = xs - bboxes[:, 0]
  369. r_res = bboxes[:, 2] - xs
  370. t_res = ys - bboxes[:, 1]
  371. b_res = bboxes[:, 3] - ys
  372. reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
  373. if self.center_sampling_radius > 0:
  374. is_inside_box = self._check_inside_boxes_limited(
  375. bboxes, xs, ys, num_points_each_level)
  376. else:
  377. is_inside_box = np.min(reg_targets, axis=2) > 0
  378. # check if the targets is inside the corresponding level
  379. max_reg_targets = np.max(reg_targets, axis=2)
  380. lower_bound = np.tile(
  381. np.expand_dims(
  382. object_scale_exp[:, 0], axis=1),
  383. reps=[1, max_reg_targets.shape[1]])
  384. high_bound = np.tile(
  385. np.expand_dims(
  386. object_scale_exp[:, 1], axis=1),
  387. reps=[1, max_reg_targets.shape[1]])
  388. is_match_current_level = \
  389. (max_reg_targets > lower_bound) & \
  390. (max_reg_targets < high_bound)
  391. points2gtarea = np.tile(
  392. np.expand_dims(
  393. gt_area, axis=0), reps=[xs.shape[0], 1])
  394. points2gtarea[is_inside_box == 0] = self.INF
  395. points2gtarea[is_match_current_level == 0] = self.INF
  396. points2min_area = points2gtarea.min(axis=1)
  397. points2min_area_ind = points2gtarea.argmin(axis=1)
  398. labels = gt_class[points2min_area_ind] + 1
  399. labels[points2min_area == self.INF] = 0
  400. reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind]
  401. ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \
  402. reg_targets[:, [0, 2]].max(axis=1)) * \
  403. (reg_targets[:, [1, 3]].min(axis=1) / \
  404. reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32)
  405. ctn_targets = np.reshape(
  406. ctn_targets, newshape=[ctn_targets.shape[0], 1])
  407. ctn_targets[labels <= 0] = 0
  408. pos_ind = np.nonzero(labels != 0)
  409. reg_targets_pos = reg_targets[pos_ind[0], :]
  410. split_sections = []
  411. beg = 0
  412. for lvl in range(len(num_points_each_level)):
  413. end = beg + num_points_each_level[lvl]
  414. split_sections.append(end)
  415. beg = end
  416. labels_by_level = np.split(labels, split_sections, axis=0)
  417. reg_targets_by_level = np.split(
  418. reg_targets, split_sections, axis=0)
  419. ctn_targets_by_level = np.split(
  420. ctn_targets, split_sections, axis=0)
  421. for lvl in range(len(self.downsample_ratios)):
  422. grid_w = int(np.ceil(w / self.downsample_ratios[lvl]))
  423. grid_h = int(np.ceil(h / self.downsample_ratios[lvl]))
  424. if self.norm_reg_targets:
  425. sample['reg_target{}'.format(lvl)] = \
  426. np.reshape(
  427. reg_targets_by_level[lvl] / \
  428. self.downsample_ratios[lvl],
  429. newshape=[grid_h, grid_w, 4])
  430. else:
  431. sample['reg_target{}'.format(lvl)] = np.reshape(
  432. reg_targets_by_level[lvl],
  433. newshape=[grid_h, grid_w, 4])
  434. sample['labels{}'.format(lvl)] = np.reshape(
  435. labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
  436. sample['centerness{}'.format(lvl)] = np.reshape(
  437. ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
  438. sample.pop('is_crowd', None)
  439. sample.pop('difficult', None)
  440. sample.pop('gt_class', None)
  441. sample.pop('gt_bbox', None)
  442. return samples
  443. @register_op
  444. class Gt2GFLTarget(BaseOperator):
  445. """
  446. Generate GFocal loss targets by groud truth data
  447. """
  448. def __init__(self,
  449. num_classes=80,
  450. downsample_ratios=[8, 16, 32, 64, 128],
  451. grid_cell_scale=4,
  452. cell_offset=0):
  453. super(Gt2GFLTarget, self).__init__()
  454. self.num_classes = num_classes
  455. self.downsample_ratios = downsample_ratios
  456. self.grid_cell_scale = grid_cell_scale
  457. self.cell_offset = cell_offset
  458. self.assigner = ATSSAssigner()
  459. def get_grid_cells(self, featmap_size, scale, stride, offset=0):
  460. """
  461. Generate grid cells of a feature map for target assignment.
  462. Args:
  463. featmap_size: Size of a single level feature map.
  464. scale: Grid cell scale.
  465. stride: Down sample stride of the feature map.
  466. offset: Offset of grid cells.
  467. return:
  468. Grid_cells xyxy position. Size should be [feat_w * feat_h, 4]
  469. """
  470. cell_size = stride * scale
  471. h, w = featmap_size
  472. x_range = (np.arange(w, dtype=np.float32) + offset) * stride
  473. y_range = (np.arange(h, dtype=np.float32) + offset) * stride
  474. x, y = np.meshgrid(x_range, y_range)
  475. y = y.flatten()
  476. x = x.flatten()
  477. grid_cells = np.stack(
  478. [
  479. x - 0.5 * cell_size, y - 0.5 * cell_size, x + 0.5 * cell_size,
  480. y + 0.5 * cell_size
  481. ],
  482. axis=-1)
  483. return grid_cells
  484. def get_sample(self, assign_gt_inds, gt_bboxes):
  485. pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
  486. neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
  487. pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1
  488. if gt_bboxes.size == 0:
  489. # hack for index error case
  490. assert pos_assigned_gt_inds.size == 0
  491. pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
  492. else:
  493. if len(gt_bboxes.shape) < 2:
  494. gt_bboxes = gt_bboxes.resize(-1, 4)
  495. pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
  496. return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds
  497. def __call__(self, samples, context=None):
  498. assert len(samples) > 0
  499. batch_size = len(samples)
  500. # get grid cells of image
  501. h, w = samples[0]['image'].shape[1:3]
  502. multi_level_grid_cells = []
  503. for stride in self.downsample_ratios:
  504. featmap_size = (int(math.ceil(h / stride)),
  505. int(math.ceil(w / stride)))
  506. multi_level_grid_cells.append(
  507. self.get_grid_cells(featmap_size, self.grid_cell_scale, stride,
  508. self.cell_offset))
  509. mlvl_grid_cells_list = [
  510. multi_level_grid_cells for i in range(batch_size)
  511. ]
  512. # pixel cell number of multi-level feature maps
  513. num_level_cells = [
  514. grid_cells.shape[0] for grid_cells in mlvl_grid_cells_list[0]
  515. ]
  516. num_level_cells_list = [num_level_cells] * batch_size
  517. # concat all level cells and to a single array
  518. for i in range(batch_size):
  519. mlvl_grid_cells_list[i] = np.concatenate(mlvl_grid_cells_list[i])
  520. # target assign on all images
  521. for sample, grid_cells, num_level_cells in zip(
  522. samples, mlvl_grid_cells_list, num_level_cells_list):
  523. gt_bboxes = sample['gt_bbox']
  524. gt_labels = sample['gt_class'].squeeze()
  525. if gt_labels.size == 1:
  526. gt_labels = np.array([gt_labels]).astype(np.int32)
  527. gt_bboxes_ignore = None
  528. assign_gt_inds, _ = self.assigner(grid_cells, num_level_cells,
  529. gt_bboxes, gt_bboxes_ignore,
  530. gt_labels)
  531. pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample(
  532. assign_gt_inds, gt_bboxes)
  533. num_cells = grid_cells.shape[0]
  534. bbox_targets = np.zeros_like(grid_cells)
  535. bbox_weights = np.zeros_like(grid_cells)
  536. labels = np.ones([num_cells], dtype=np.int64) * self.num_classes
  537. label_weights = np.zeros([num_cells], dtype=np.float32)
  538. if len(pos_inds) > 0:
  539. pos_bbox_targets = pos_gt_bboxes
  540. bbox_targets[pos_inds, :] = pos_bbox_targets
  541. bbox_weights[pos_inds, :] = 1.0
  542. if not np.any(gt_labels):
  543. labels[pos_inds] = 0
  544. else:
  545. labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
  546. label_weights[pos_inds] = 1.0
  547. if len(neg_inds) > 0:
  548. label_weights[neg_inds] = 1.0
  549. sample['grid_cells'] = grid_cells
  550. sample['labels'] = labels
  551. sample['label_weights'] = label_weights
  552. sample['bbox_targets'] = bbox_targets
  553. sample['pos_num'] = max(pos_inds.size, 1)
  554. sample.pop('is_crowd', None)
  555. sample.pop('difficult', None)
  556. sample.pop('gt_class', None)
  557. sample.pop('gt_bbox', None)
  558. sample.pop('gt_score', None)
  559. return samples
  560. @register_op
  561. class Gt2TTFTarget(BaseOperator):
  562. __shared__ = ['num_classes']
  563. """
  564. Gt2TTFTarget
  565. Generate TTFNet targets by ground truth data
  566. Args:
  567. num_classes(int): the number of classes.
  568. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  569. alpha(float): the alpha parameter to generate gaussian target.
  570. 0.54 by default.
  571. """
  572. def __init__(self, num_classes=80, down_ratio=4, alpha=0.54):
  573. super(Gt2TTFTarget, self).__init__()
  574. self.down_ratio = down_ratio
  575. self.num_classes = num_classes
  576. self.alpha = alpha
  577. def __call__(self, samples, context=None):
  578. output_size = samples[0]['image'].shape[1]
  579. feat_size = output_size // self.down_ratio
  580. for sample in samples:
  581. heatmap = np.zeros(
  582. (self.num_classes, feat_size, feat_size), dtype='float32')
  583. box_target = np.ones(
  584. (4, feat_size, feat_size), dtype='float32') * -1
  585. reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')
  586. gt_bbox = sample['gt_bbox']
  587. gt_class = sample['gt_class']
  588. bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
  589. bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
  590. area = bbox_w * bbox_h
  591. boxes_areas_log = np.log(area)
  592. boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
  593. boxes_area_topk_log = boxes_areas_log[boxes_ind]
  594. gt_bbox = gt_bbox[boxes_ind]
  595. gt_class = gt_class[boxes_ind]
  596. feat_gt_bbox = gt_bbox / self.down_ratio
  597. feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
  598. feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
  599. feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])
  600. ct_inds = np.stack(
  601. [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
  602. (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
  603. axis=1) / self.down_ratio
  604. h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
  605. w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')
  606. for k in range(len(gt_bbox)):
  607. cls_id = gt_class[k]
  608. fake_heatmap = np.zeros(
  609. (feat_size, feat_size), dtype='float32')
  610. self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
  611. h_radiuses_alpha[k],
  612. w_radiuses_alpha[k])
  613. heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
  614. box_target_inds = fake_heatmap > 0
  615. box_target[:, box_target_inds] = gt_bbox[k][:, None]
  616. local_heatmap = fake_heatmap[box_target_inds]
  617. ct_div = np.sum(local_heatmap)
  618. local_heatmap *= boxes_area_topk_log[k]
  619. reg_weight[0, box_target_inds] = local_heatmap / ct_div
  620. sample['ttf_heatmap'] = heatmap
  621. sample['ttf_box_target'] = box_target
  622. sample['ttf_reg_weight'] = reg_weight
  623. sample.pop('is_crowd', None)
  624. sample.pop('difficult', None)
  625. sample.pop('gt_class', None)
  626. sample.pop('gt_bbox', None)
  627. sample.pop('gt_score', None)
  628. return samples
  629. def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
  630. h, w = 2 * h_radius + 1, 2 * w_radius + 1
  631. sigma_x = w / 6
  632. sigma_y = h / 6
  633. gaussian = gaussian2D((h, w), sigma_x, sigma_y)
  634. x, y = int(center[0]), int(center[1])
  635. height, width = heatmap.shape[0:2]
  636. left, right = min(x, w_radius), min(width - x, w_radius + 1)
  637. top, bottom = min(y, h_radius), min(height - y, h_radius + 1)
  638. masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
  639. masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius -
  640. left:w_radius + right]
  641. if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
  642. heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
  643. masked_heatmap, masked_gaussian)
  644. return heatmap
  645. @register_op
  646. class Gt2Solov2Target(BaseOperator):
  647. """Assign mask target and labels in SOLOv2 network.
  648. The code of this function is based on:
  649. https://github.com/WXinlong/SOLO/blob/master/mmdet/models/anchor_heads/solov2_head.py#L271
  650. Args:
  651. num_grids (list): The list of feature map grids size.
  652. scale_ranges (list): The list of mask boundary range.
  653. coord_sigma (float): The coefficient of coordinate area length.
  654. sampling_ratio (float): The ratio of down sampling.
  655. """
  656. def __init__(self,
  657. num_grids=[40, 36, 24, 16, 12],
  658. scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
  659. [384, 2048]],
  660. coord_sigma=0.2,
  661. sampling_ratio=4.0):
  662. super(Gt2Solov2Target, self).__init__()
  663. self.num_grids = num_grids
  664. self.scale_ranges = scale_ranges
  665. self.coord_sigma = coord_sigma
  666. self.sampling_ratio = sampling_ratio
  667. def _scale_size(self, im, scale):
  668. h, w = im.shape[:2]
  669. new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
  670. resized_img = cv2.resize(
  671. im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
  672. return resized_img
  673. def __call__(self, samples, context=None):
  674. sample_id = 0
  675. max_ins_num = [0] * len(self.num_grids)
  676. for sample in samples:
  677. gt_bboxes_raw = sample['gt_bbox']
  678. gt_labels_raw = sample['gt_class'] + 1
  679. im_c, im_h, im_w = sample['image'].shape[:]
  680. gt_masks_raw = sample['gt_segm'].astype(np.uint8)
  681. mask_feat_size = [
  682. int(im_h / self.sampling_ratio),
  683. int(im_w / self.sampling_ratio)
  684. ]
  685. gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
  686. (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
  687. ins_ind_label_list = []
  688. idx = 0
  689. for (lower_bound, upper_bound), num_grid \
  690. in zip(self.scale_ranges, self.num_grids):
  691. hit_indices = ((gt_areas >= lower_bound) &
  692. (gt_areas <= upper_bound)).nonzero()[0]
  693. num_ins = len(hit_indices)
  694. ins_label = []
  695. grid_order = []
  696. cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
  697. ins_ind_label = np.zeros([num_grid**2], dtype=np.bool)
  698. if num_ins == 0:
  699. ins_label = np.zeros(
  700. [1, mask_feat_size[0], mask_feat_size[1]],
  701. dtype=np.uint8)
  702. ins_ind_label_list.append(ins_ind_label)
  703. sample['cate_label{}'.format(idx)] = cate_label.flatten()
  704. sample['ins_label{}'.format(idx)] = ins_label
  705. sample['grid_order{}'.format(idx)] = np.asarray(
  706. [sample_id * num_grid * num_grid + 0], dtype=np.int32)
  707. idx += 1
  708. continue
  709. gt_bboxes = gt_bboxes_raw[hit_indices]
  710. gt_labels = gt_labels_raw[hit_indices]
  711. gt_masks = gt_masks_raw[hit_indices, ...]
  712. half_ws = 0.5 * (
  713. gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
  714. half_hs = 0.5 * (
  715. gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma
  716. for seg_mask, gt_label, half_h, half_w in zip(
  717. gt_masks, gt_labels, half_hs, half_ws):
  718. if seg_mask.sum() == 0:
  719. continue
  720. # mass center
  721. upsampled_size = (mask_feat_size[0] * 4,
  722. mask_feat_size[1] * 4)
  723. center_h, center_w = ndimage.measurements.center_of_mass(
  724. seg_mask)
  725. coord_w = int(
  726. (center_w / upsampled_size[1]) // (1. / num_grid))
  727. coord_h = int(
  728. (center_h / upsampled_size[0]) // (1. / num_grid))
  729. # left, top, right, down
  730. top_box = max(0,
  731. int(((center_h - half_h) / upsampled_size[0])
  732. // (1. / num_grid)))
  733. down_box = min(
  734. num_grid - 1,
  735. int(((center_h + half_h) / upsampled_size[0]) //
  736. (1. / num_grid)))
  737. left_box = max(
  738. 0,
  739. int(((center_w - half_w) / upsampled_size[1]) //
  740. (1. / num_grid)))
  741. right_box = min(num_grid - 1,
  742. int(((center_w + half_w) /
  743. upsampled_size[1]) //
  744. (1. / num_grid)))
  745. top = max(top_box, coord_h - 1)
  746. down = min(down_box, coord_h + 1)
  747. left = max(coord_w - 1, left_box)
  748. right = min(right_box, coord_w + 1)
  749. cate_label[top:(down + 1), left:(right + 1)] = gt_label
  750. seg_mask = self._scale_size(
  751. seg_mask, scale=1. / self.sampling_ratio)
  752. for i in range(top, down + 1):
  753. for j in range(left, right + 1):
  754. label = int(i * num_grid + j)
  755. cur_ins_label = np.zeros(
  756. [mask_feat_size[0], mask_feat_size[1]],
  757. dtype=np.uint8)
  758. cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
  759. 1]] = seg_mask
  760. ins_label.append(cur_ins_label)
  761. ins_ind_label[label] = True
  762. grid_order.append(sample_id * num_grid * num_grid +
  763. label)
  764. if ins_label == []:
  765. ins_label = np.zeros(
  766. [1, mask_feat_size[0], mask_feat_size[1]],
  767. dtype=np.uint8)
  768. ins_ind_label_list.append(ins_ind_label)
  769. sample['cate_label{}'.format(idx)] = cate_label.flatten()
  770. sample['ins_label{}'.format(idx)] = ins_label
  771. sample['grid_order{}'.format(idx)] = np.asarray(
  772. [sample_id * num_grid * num_grid + 0], dtype=np.int32)
  773. else:
  774. ins_label = np.stack(ins_label, axis=0)
  775. ins_ind_label_list.append(ins_ind_label)
  776. sample['cate_label{}'.format(idx)] = cate_label.flatten()
  777. sample['ins_label{}'.format(idx)] = ins_label
  778. sample['grid_order{}'.format(idx)] = np.asarray(
  779. grid_order, dtype=np.int32)
  780. assert len(grid_order) > 0
  781. max_ins_num[idx] = max(
  782. max_ins_num[idx],
  783. sample['ins_label{}'.format(idx)].shape[0])
  784. idx += 1
  785. ins_ind_labels = np.concatenate([
  786. ins_ind_labels_level_img
  787. for ins_ind_labels_level_img in ins_ind_label_list
  788. ])
  789. fg_num = np.sum(ins_ind_labels)
  790. sample['fg_num'] = fg_num
  791. sample_id += 1
  792. sample.pop('is_crowd')
  793. sample.pop('gt_class')
  794. sample.pop('gt_bbox')
  795. sample.pop('gt_poly')
  796. sample.pop('gt_segm')
  797. # padding batch
  798. for data in samples:
  799. for idx in range(len(self.num_grids)):
  800. gt_ins_data = np.zeros(
  801. [
  802. max_ins_num[idx],
  803. data['ins_label{}'.format(idx)].shape[1],
  804. data['ins_label{}'.format(idx)].shape[2]
  805. ],
  806. dtype=np.uint8)
  807. gt_ins_data[0:data['ins_label{}'.format(idx)].shape[
  808. 0], :, :] = data['ins_label{}'.format(idx)]
  809. gt_grid_order = np.zeros([max_ins_num[idx]], dtype=np.int32)
  810. gt_grid_order[0:data['grid_order{}'.format(idx)].shape[
  811. 0]] = data['grid_order{}'.format(idx)]
  812. data['ins_label{}'.format(idx)] = gt_ins_data
  813. data['grid_order{}'.format(idx)] = gt_grid_order
  814. return samples
  815. @register_op
  816. class Gt2SparseRCNNTarget(BaseOperator):
  817. '''
  818. Generate SparseRCNN targets by groud truth data
  819. '''
  820. def __init__(self):
  821. super(Gt2SparseRCNNTarget, self).__init__()
  822. def __call__(self, samples, context=None):
  823. for sample in samples:
  824. im = sample["image"]
  825. h, w = im.shape[1:3]
  826. img_whwh = np.array([w, h, w, h], dtype=np.int32)
  827. sample["img_whwh"] = img_whwh
  828. if "scale_factor" in sample:
  829. sample["scale_factor_wh"] = np.array(
  830. [sample["scale_factor"][1], sample["scale_factor"][0]],
  831. dtype=np.float32)
  832. else:
  833. sample["scale_factor_wh"] = np.array(
  834. [1.0, 1.0], dtype=np.float32)
  835. return samples
  836. @register_op
  837. class PadMaskBatch(BaseOperator):
  838. """
  839. Pad a batch of samples so they can be divisible by a stride.
  840. The layout of each image should be 'CHW'.
  841. Args:
  842. pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
  843. height and width is divisible by `pad_to_stride`.
  844. return_pad_mask (bool): If `return_pad_mask = True`, return
  845. `pad_mask` for transformer.
  846. """
  847. def __init__(self, pad_to_stride=0, return_pad_mask=False):
  848. super(PadMaskBatch, self).__init__()
  849. self.pad_to_stride = pad_to_stride
  850. self.return_pad_mask = return_pad_mask
  851. def __call__(self, samples, context=None):
  852. """
  853. Args:
  854. samples (list): a batch of sample, each is dict.
  855. """
  856. coarsest_stride = self.pad_to_stride
  857. max_shape = np.array([data['image'].shape for data in samples]).max(
  858. axis=0)
  859. if coarsest_stride > 0:
  860. max_shape[1] = int(
  861. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  862. max_shape[2] = int(
  863. np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
  864. for data in samples:
  865. im = data['image']
  866. im_c, im_h, im_w = im.shape[:]
  867. padding_im = np.zeros(
  868. (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
  869. padding_im[:, :im_h, :im_w] = im
  870. data['image'] = padding_im
  871. if 'semantic' in data and data['semantic'] is not None:
  872. semantic = data['semantic']
  873. padding_sem = np.zeros(
  874. (1, max_shape[1], max_shape[2]), dtype=np.float32)
  875. padding_sem[:, :im_h, :im_w] = semantic
  876. data['semantic'] = padding_sem
  877. if 'gt_segm' in data and data['gt_segm'] is not None:
  878. gt_segm = data['gt_segm']
  879. padding_segm = np.zeros(
  880. (gt_segm.shape[0], max_shape[1], max_shape[2]),
  881. dtype=np.uint8)
  882. padding_segm[:, :im_h, :im_w] = gt_segm
  883. data['gt_segm'] = padding_segm
  884. if self.return_pad_mask:
  885. padding_mask = np.zeros(
  886. (max_shape[1], max_shape[2]), dtype=np.float32)
  887. padding_mask[:im_h, :im_w] = 1.
  888. data['pad_mask'] = padding_mask
  889. if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
  890. # ploy to rbox
  891. polys = data['gt_rbox2poly']
  892. rbox = bbox_utils.poly2rbox(polys)
  893. data['gt_rbox'] = rbox
  894. return samples
  895. @register_op
  896. class Gt2CenterNetTarget(BaseOperator):
  897. """Gt2CenterNetTarget
  898. Genterate CenterNet targets by ground-truth
  899. Args:
  900. down_ratio (int): The down sample ratio between output feature and
  901. input image.
  902. num_classes (int): The number of classes, 80 by default.
  903. max_objs (int): The maximum objects detected, 128 by default.
  904. """
  905. def __init__(self, down_ratio, num_classes=80, max_objs=128):
  906. super(Gt2CenterNetTarget, self).__init__()
  907. self.down_ratio = down_ratio
  908. self.num_classes = num_classes
  909. self.max_objs = max_objs
  910. def __call__(self, sample, context=None):
  911. input_h, input_w = sample['image'].shape[1:]
  912. output_h = input_h // self.down_ratio
  913. output_w = input_w // self.down_ratio
  914. num_classes = self.num_classes
  915. c = sample['center']
  916. s = sample['scale']
  917. gt_bbox = sample['gt_bbox']
  918. gt_class = sample['gt_class']
  919. hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
  920. wh = np.zeros((self.max_objs, 2), dtype=np.float32)
  921. dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
  922. reg = np.zeros((self.max_objs, 2), dtype=np.float32)
  923. ind = np.zeros((self.max_objs), dtype=np.int64)
  924. reg_mask = np.zeros((self.max_objs), dtype=np.int32)
  925. cat_spec_wh = np.zeros(
  926. (self.max_objs, num_classes * 2), dtype=np.float32)
  927. cat_spec_mask = np.zeros(
  928. (self.max_objs, num_classes * 2), dtype=np.int32)
  929. trans_output = get_affine_transform(c, [s, s], 0, [output_w, output_h])
  930. gt_det = []
  931. for i, (bbox, cls) in enumerate(zip(gt_bbox, gt_class)):
  932. cls = int(cls)
  933. bbox[:2] = affine_transform(bbox[:2], trans_output)
  934. bbox[2:] = affine_transform(bbox[2:], trans_output)
  935. bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
  936. bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
  937. h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
  938. if h > 0 and w > 0:
  939. radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
  940. radius = max(0, int(radius))
  941. ct = np.array(
  942. [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
  943. dtype=np.float32)
  944. ct_int = ct.astype(np.int32)
  945. draw_umich_gaussian(hm[cls], ct_int, radius)
  946. wh[i] = 1. * w, 1. * h
  947. ind[i] = ct_int[1] * output_w + ct_int[0]
  948. reg[i] = ct - ct_int
  949. reg_mask[i] = 1
  950. cat_spec_wh[i, cls * 2:cls * 2 + 2] = wh[i]
  951. cat_spec_mask[i, cls * 2:cls * 2 + 2] = 1
  952. gt_det.append([
  953. ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2,
  954. 1, cls
  955. ])
  956. sample.pop('gt_bbox', None)
  957. sample.pop('gt_class', None)
  958. sample.pop('center', None)
  959. sample.pop('scale', None)
  960. sample.pop('is_crowd', None)
  961. sample.pop('difficult', None)
  962. sample['heatmap'] = hm
  963. sample['index_mask'] = reg_mask
  964. sample['index'] = ind
  965. sample['size'] = wh
  966. sample['offset'] = reg
  967. return sample
  968. @register_op
  969. class PadGT(BaseOperator):
  970. """
  971. Pad 0 to `gt_class`, `gt_bbox`, `gt_score`...
  972. The num_max_boxes is the largest for batch.
  973. Args:
  974. return_gt_mask (bool): If true, return `pad_gt_mask`,
  975. 1 means bbox, 0 means no bbox.
  976. """
  977. def __init__(self, return_gt_mask=True):
  978. super(PadGT, self).__init__()
  979. self.return_gt_mask = return_gt_mask
  980. def __call__(self, samples, context=None):
  981. num_max_boxes = max([len(s['gt_bbox']) for s in samples])
  982. for sample in samples:
  983. if self.return_gt_mask:
  984. sample['pad_gt_mask'] = np.zeros(
  985. (num_max_boxes, 1), dtype=np.float32)
  986. if num_max_boxes == 0:
  987. continue
  988. num_gt = len(sample['gt_bbox'])
  989. pad_gt_class = np.zeros((num_max_boxes, 1), dtype=np.int32)
  990. pad_gt_bbox = np.zeros((num_max_boxes, 4), dtype=np.float32)
  991. if num_gt > 0:
  992. pad_gt_class[:num_gt] = sample['gt_class']
  993. pad_gt_bbox[:num_gt] = sample['gt_bbox']
  994. sample['gt_class'] = pad_gt_class
  995. sample['gt_bbox'] = pad_gt_bbox
  996. # pad_gt_mask
  997. if 'pad_gt_mask' in sample:
  998. sample['pad_gt_mask'][:num_gt] = 1
  999. # gt_score
  1000. if 'gt_score' in sample:
  1001. pad_gt_score = np.zeros((num_max_boxes, 1), dtype=np.float32)
  1002. if num_gt > 0:
  1003. pad_gt_score[:num_gt] = sample['gt_score']
  1004. sample['gt_score'] = pad_gt_score
  1005. if 'is_crowd' in sample:
  1006. pad_is_crowd = np.zeros((num_max_boxes, 1), dtype=np.int32)
  1007. if num_gt > 0:
  1008. pad_is_crowd[:num_gt] = sample['is_crowd']
  1009. sample['is_crowd'] = pad_is_crowd
  1010. if 'difficult' in sample:
  1011. pad_diff = np.zeros((num_max_boxes, 1), dtype=np.int32)
  1012. if num_gt > 0:
  1013. pad_diff[:num_gt] = sample['difficult']
  1014. sample['difficult'] = pad_diff
  1015. return samples