batch_operators.py 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. try:
  18. from collections.abc import Sequence
  19. except Exception:
  20. from collections import Sequence
  21. import cv2
  22. import math
  23. import numpy as np
  24. from .operators import register_op, BaseOperator, Resize
  25. from .op_helper import jaccard_overlap, gaussian2D
  26. from .atss_assigner import ATSSAssigner
  27. from scipy import ndimage
  28. from paddlex.ppdet.modeling import bbox_utils
  29. from paddlex.ppdet.utils.logger import setup_logger
  30. logger = setup_logger(__name__)
  31. __all__ = [
  32. 'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
  33. 'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget', 'PadMaskBatch',
  34. 'Gt2GFLTarget'
  35. ]
  36. @register_op
  37. class PadBatch(BaseOperator):
  38. """
  39. Pad a batch of samples so they can be divisible by a stride.
  40. The layout of each image should be 'CHW'.
  41. Args:
  42. pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
  43. height and width is divisible by `pad_to_stride`.
  44. """
  45. def __init__(self, pad_to_stride=0):
  46. super(PadBatch, self).__init__()
  47. self.pad_to_stride = pad_to_stride
  48. def __call__(self, samples, context=None):
  49. """
  50. Args:
  51. samples (list): a batch of sample, each is dict.
  52. """
  53. coarsest_stride = self.pad_to_stride
  54. max_shape = np.array([data['image'].shape for data in samples]).max(
  55. axis=0)
  56. if coarsest_stride > 0:
  57. max_shape[1] = int(
  58. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  59. max_shape[2] = int(
  60. np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
  61. for data in samples:
  62. im = data['image']
  63. im_c, im_h, im_w = im.shape[:]
  64. padding_im = np.zeros(
  65. (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
  66. padding_im[:, :im_h, :im_w] = im
  67. data['image'] = padding_im
  68. if 'semantic' in data and data['semantic'] is not None:
  69. semantic = data['semantic']
  70. padding_sem = np.zeros(
  71. (1, max_shape[1], max_shape[2]), dtype=np.float32)
  72. padding_sem[:, :im_h, :im_w] = semantic
  73. data['semantic'] = padding_sem
  74. if 'gt_segm' in data and data['gt_segm'] is not None:
  75. gt_segm = data['gt_segm']
  76. padding_segm = np.zeros(
  77. (gt_segm.shape[0], max_shape[1], max_shape[2]),
  78. dtype=np.uint8)
  79. padding_segm[:, :im_h, :im_w] = gt_segm
  80. data['gt_segm'] = padding_segm
  81. if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
  82. # ploy to rbox
  83. polys = data['gt_rbox2poly']
  84. rbox = bbox_utils.poly2rbox(polys)
  85. data['gt_rbox'] = rbox
  86. return samples
  87. @register_op
  88. class BatchRandomResize(BaseOperator):
  89. """
  90. Resize image to target size randomly. random target_size and interpolation method
  91. Args:
  92. target_size (int, list, tuple): image target size, if random size is True, must be list or tuple
  93. keep_ratio (bool): whether keep_raio or not, default true
  94. interp (int): the interpolation method
  95. random_size (bool): whether random select target size of image
  96. random_interp (bool): whether random select interpolation method
  97. """
  98. def __init__(self,
  99. target_size,
  100. keep_ratio,
  101. interp=cv2.INTER_NEAREST,
  102. random_size=True,
  103. random_interp=False):
  104. super(BatchRandomResize, self).__init__()
  105. self.keep_ratio = keep_ratio
  106. self.interps = [
  107. cv2.INTER_NEAREST,
  108. cv2.INTER_LINEAR,
  109. cv2.INTER_AREA,
  110. cv2.INTER_CUBIC,
  111. cv2.INTER_LANCZOS4,
  112. ]
  113. self.interp = interp
  114. assert isinstance(target_size, (
  115. int, Sequence)), "target_size must be int, list or tuple"
  116. if random_size and not isinstance(target_size, list):
  117. raise TypeError(
  118. "Type of target_size is invalid when random_size is True. Must be List, now is {}".
  119. format(type(target_size)))
  120. self.target_size = target_size
  121. self.random_size = random_size
  122. self.random_interp = random_interp
  123. def __call__(self, samples, context=None):
  124. if self.random_size:
  125. index = np.random.choice(len(self.target_size))
  126. target_size = self.target_size[index]
  127. else:
  128. target_size = self.target_size
  129. if self.random_interp:
  130. interp = np.random.choice(self.interps)
  131. else:
  132. interp = self.interp
  133. resizer = Resize(
  134. target_size, keep_ratio=self.keep_ratio, interp=interp)
  135. return resizer(samples, context=context)
  136. @register_op
  137. class Gt2YoloTarget(BaseOperator):
  138. """
  139. Generate YOLOv3 targets by groud truth data, this operator is only used in
  140. fine grained YOLOv3 loss mode
  141. """
  142. def __init__(self,
  143. anchors,
  144. anchor_masks,
  145. downsample_ratios,
  146. num_classes=80,
  147. iou_thresh=1.):
  148. super(Gt2YoloTarget, self).__init__()
  149. self.anchors = anchors
  150. self.anchor_masks = anchor_masks
  151. self.downsample_ratios = downsample_ratios
  152. self.num_classes = num_classes
  153. self.iou_thresh = iou_thresh
  154. def __call__(self, samples, context=None):
  155. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  156. "anchor_masks', and 'downsample_ratios' should have same length."
  157. h, w = samples[0]['image'].shape[1:3]
  158. an_hw = np.array(self.anchors) / np.array([[w, h]])
  159. for sample in samples:
  160. gt_bbox = sample['gt_bbox']
  161. gt_class = sample['gt_class']
  162. if 'gt_score' not in sample:
  163. sample['gt_score'] = np.ones(
  164. (gt_bbox.shape[0], 1), dtype=np.float32)
  165. gt_score = sample['gt_score']
  166. for i, (
  167. mask, downsample_ratio
  168. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  169. grid_h = int(h / downsample_ratio)
  170. grid_w = int(w / downsample_ratio)
  171. target = np.zeros(
  172. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  173. dtype=np.float32)
  174. for b in range(gt_bbox.shape[0]):
  175. gx, gy, gw, gh = gt_bbox[b, :]
  176. cls = gt_class[b]
  177. score = gt_score[b]
  178. if gw <= 0. or gh <= 0. or score <= 0.:
  179. continue
  180. # find best match anchor index
  181. best_iou = 0.
  182. best_idx = -1
  183. for an_idx in range(an_hw.shape[0]):
  184. iou = jaccard_overlap(
  185. [0., 0., gw, gh],
  186. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  187. if iou > best_iou:
  188. best_iou = iou
  189. best_idx = an_idx
  190. gi = int(gx * grid_w)
  191. gj = int(gy * grid_h)
  192. # gtbox should be regresed in this layes if best match
  193. # anchor index in anchor mask of this layer
  194. if best_idx in mask:
  195. best_n = mask.index(best_idx)
  196. # x, y, w, h, scale
  197. target[best_n, 0, gj, gi] = gx * grid_w - gi
  198. target[best_n, 1, gj, gi] = gy * grid_h - gj
  199. target[best_n, 2, gj, gi] = np.log(
  200. gw * w / self.anchors[best_idx][0])
  201. target[best_n, 3, gj, gi] = np.log(
  202. gh * h / self.anchors[best_idx][1])
  203. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  204. # objectness record gt_score
  205. target[best_n, 5, gj, gi] = score
  206. # classification
  207. target[best_n, 6 + cls, gj, gi] = 1.
  208. # For non-matched anchors, calculate the target if the iou
  209. # between anchor and gt is larger than iou_thresh
  210. if self.iou_thresh < 1:
  211. for idx, mask_i in enumerate(mask):
  212. if mask_i == best_idx: continue
  213. iou = jaccard_overlap(
  214. [0., 0., gw, gh],
  215. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  216. if iou > self.iou_thresh and target[idx, 5, gj,
  217. gi] == 0.:
  218. # x, y, w, h, scale
  219. target[idx, 0, gj, gi] = gx * grid_w - gi
  220. target[idx, 1, gj, gi] = gy * grid_h - gj
  221. target[idx, 2, gj, gi] = np.log(
  222. gw * w / self.anchors[mask_i][0])
  223. target[idx, 3, gj, gi] = np.log(
  224. gh * h / self.anchors[mask_i][1])
  225. target[idx, 4, gj, gi] = 2.0 - gw * gh
  226. # objectness record gt_score
  227. target[idx, 5, gj, gi] = score
  228. # classification
  229. target[idx, 6 + cls, gj, gi] = 1.
  230. sample['target{}'.format(i)] = target
  231. # remove useless gt_class and gt_score after target calculated
  232. sample.pop('gt_class')
  233. sample.pop('gt_score')
  234. return samples
  235. @register_op
  236. class Gt2FCOSTarget(BaseOperator):
  237. """
  238. Generate FCOS targets by groud truth data
  239. """
  240. def __init__(self,
  241. object_sizes_boundary,
  242. center_sampling_radius,
  243. downsample_ratios,
  244. norm_reg_targets=False):
  245. super(Gt2FCOSTarget, self).__init__()
  246. self.center_sampling_radius = center_sampling_radius
  247. self.downsample_ratios = downsample_ratios
  248. self.INF = np.inf
  249. self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF]
  250. object_sizes_of_interest = []
  251. for i in range(len(self.object_sizes_boundary) - 1):
  252. object_sizes_of_interest.append([
  253. self.object_sizes_boundary[i],
  254. self.object_sizes_boundary[i + 1]
  255. ])
  256. self.object_sizes_of_interest = object_sizes_of_interest
  257. self.norm_reg_targets = norm_reg_targets
  258. def _compute_points(self, w, h):
  259. """
  260. compute the corresponding points in each feature map
  261. :param h: image height
  262. :param w: image width
  263. :return: points from all feature map
  264. """
  265. locations = []
  266. for stride in self.downsample_ratios:
  267. shift_x = np.arange(0, w, stride).astype(np.float32)
  268. shift_y = np.arange(0, h, stride).astype(np.float32)
  269. shift_x, shift_y = np.meshgrid(shift_x, shift_y)
  270. shift_x = shift_x.flatten()
  271. shift_y = shift_y.flatten()
  272. location = np.stack([shift_x, shift_y], axis=1) + stride // 2
  273. locations.append(location)
  274. num_points_each_level = [len(location) for location in locations]
  275. locations = np.concatenate(locations, axis=0)
  276. return locations, num_points_each_level
  277. def _convert_xywh2xyxy(self, gt_bbox, w, h):
  278. """
  279. convert the bounding box from style xywh to xyxy
  280. :param gt_bbox: bounding boxes normalized into [0, 1]
  281. :param w: image width
  282. :param h: image height
  283. :return: bounding boxes in xyxy style
  284. """
  285. bboxes = gt_bbox.copy()
  286. bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w
  287. bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h
  288. bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
  289. bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
  290. return bboxes
  291. def _check_inside_boxes_limited(self, gt_bbox, xs, ys,
  292. num_points_each_level):
  293. """
  294. check if points is within the clipped boxes
  295. :param gt_bbox: bounding boxes
  296. :param xs: horizontal coordinate of points
  297. :param ys: vertical coordinate of points
  298. :return: the mask of points is within gt_box or not
  299. """
  300. bboxes = np.reshape(
  301. gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]])
  302. bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1])
  303. ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2
  304. ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2
  305. beg = 0
  306. clipped_box = bboxes.copy()
  307. for lvl, stride in enumerate(self.downsample_ratios):
  308. end = beg + num_points_each_level[lvl]
  309. stride_exp = self.center_sampling_radius * stride
  310. clipped_box[beg:end, :, 0] = np.maximum(
  311. bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp)
  312. clipped_box[beg:end, :, 1] = np.maximum(
  313. bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp)
  314. clipped_box[beg:end, :, 2] = np.minimum(
  315. bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp)
  316. clipped_box[beg:end, :, 3] = np.minimum(
  317. bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp)
  318. beg = end
  319. l_res = xs - clipped_box[:, :, 0]
  320. r_res = clipped_box[:, :, 2] - xs
  321. t_res = ys - clipped_box[:, :, 1]
  322. b_res = clipped_box[:, :, 3] - ys
  323. clipped_box_reg_targets = np.stack(
  324. [l_res, t_res, r_res, b_res], axis=2)
  325. inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
  326. return inside_gt_box
  327. def __call__(self, samples, context=None):
  328. assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
  329. "object_sizes_of_interest', and 'downsample_ratios' should have same length."
  330. for sample in samples:
  331. im = sample['image']
  332. bboxes = sample['gt_bbox']
  333. gt_class = sample['gt_class']
  334. # calculate the locations
  335. h, w = im.shape[1:3]
  336. points, num_points_each_level = self._compute_points(w, h)
  337. object_scale_exp = []
  338. for i, num_pts in enumerate(num_points_each_level):
  339. object_scale_exp.append(
  340. np.tile(
  341. np.array([self.object_sizes_of_interest[i]]),
  342. reps=[num_pts, 1]))
  343. object_scale_exp = np.concatenate(object_scale_exp, axis=0)
  344. gt_area = (bboxes[:, 2] - bboxes[:, 0]) * (
  345. bboxes[:, 3] - bboxes[:, 1])
  346. xs, ys = points[:, 0], points[:, 1]
  347. xs = np.reshape(xs, newshape=[xs.shape[0], 1])
  348. xs = np.tile(xs, reps=[1, bboxes.shape[0]])
  349. ys = np.reshape(ys, newshape=[ys.shape[0], 1])
  350. ys = np.tile(ys, reps=[1, bboxes.shape[0]])
  351. l_res = xs - bboxes[:, 0]
  352. r_res = bboxes[:, 2] - xs
  353. t_res = ys - bboxes[:, 1]
  354. b_res = bboxes[:, 3] - ys
  355. reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
  356. if self.center_sampling_radius > 0:
  357. is_inside_box = self._check_inside_boxes_limited(
  358. bboxes, xs, ys, num_points_each_level)
  359. else:
  360. is_inside_box = np.min(reg_targets, axis=2) > 0
  361. # check if the targets is inside the corresponding level
  362. max_reg_targets = np.max(reg_targets, axis=2)
  363. lower_bound = np.tile(
  364. np.expand_dims(
  365. object_scale_exp[:, 0], axis=1),
  366. reps=[1, max_reg_targets.shape[1]])
  367. high_bound = np.tile(
  368. np.expand_dims(
  369. object_scale_exp[:, 1], axis=1),
  370. reps=[1, max_reg_targets.shape[1]])
  371. is_match_current_level = \
  372. (max_reg_targets > lower_bound) & \
  373. (max_reg_targets < high_bound)
  374. points2gtarea = np.tile(
  375. np.expand_dims(
  376. gt_area, axis=0), reps=[xs.shape[0], 1])
  377. points2gtarea[is_inside_box == 0] = self.INF
  378. points2gtarea[is_match_current_level == 0] = self.INF
  379. points2min_area = points2gtarea.min(axis=1)
  380. points2min_area_ind = points2gtarea.argmin(axis=1)
  381. labels = gt_class[points2min_area_ind] + 1
  382. labels[points2min_area == self.INF] = 0
  383. reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind]
  384. ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \
  385. reg_targets[:, [0, 2]].max(axis=1)) * \
  386. (reg_targets[:, [1, 3]].min(axis=1) / \
  387. reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32)
  388. ctn_targets = np.reshape(
  389. ctn_targets, newshape=[ctn_targets.shape[0], 1])
  390. ctn_targets[labels <= 0] = 0
  391. pos_ind = np.nonzero(labels != 0)
  392. reg_targets_pos = reg_targets[pos_ind[0], :]
  393. split_sections = []
  394. beg = 0
  395. for lvl in range(len(num_points_each_level)):
  396. end = beg + num_points_each_level[lvl]
  397. split_sections.append(end)
  398. beg = end
  399. labels_by_level = np.split(labels, split_sections, axis=0)
  400. reg_targets_by_level = np.split(
  401. reg_targets, split_sections, axis=0)
  402. ctn_targets_by_level = np.split(
  403. ctn_targets, split_sections, axis=0)
  404. for lvl in range(len(self.downsample_ratios)):
  405. grid_w = int(np.ceil(w / self.downsample_ratios[lvl]))
  406. grid_h = int(np.ceil(h / self.downsample_ratios[lvl]))
  407. if self.norm_reg_targets:
  408. sample['reg_target{}'.format(lvl)] = \
  409. np.reshape(
  410. reg_targets_by_level[lvl] / \
  411. self.downsample_ratios[lvl],
  412. newshape=[grid_h, grid_w, 4])
  413. else:
  414. sample['reg_target{}'.format(lvl)] = np.reshape(
  415. reg_targets_by_level[lvl],
  416. newshape=[grid_h, grid_w, 4])
  417. sample['labels{}'.format(lvl)] = np.reshape(
  418. labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
  419. sample['centerness{}'.format(lvl)] = np.reshape(
  420. ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
  421. sample.pop('is_crowd', None)
  422. sample.pop('difficult', None)
  423. sample.pop('gt_class', None)
  424. sample.pop('gt_bbox', None)
  425. return samples
  426. @register_op
  427. class Gt2GFLTarget(BaseOperator):
  428. """
  429. Generate GFocal loss targets by groud truth data
  430. """
  431. def __init__(self,
  432. num_classes=80,
  433. downsample_ratios=[8, 16, 32, 64, 128],
  434. grid_cell_scale=4,
  435. cell_offset=0):
  436. super(Gt2GFLTarget, self).__init__()
  437. self.num_classes = num_classes
  438. self.downsample_ratios = downsample_ratios
  439. self.grid_cell_scale = grid_cell_scale
  440. self.cell_offset = cell_offset
  441. self.assigner = ATSSAssigner()
  442. def get_grid_cells(self, featmap_size, scale, stride, offset=0):
  443. """
  444. Generate grid cells of a feature map for target assignment.
  445. Args:
  446. featmap_size: Size of a single level feature map.
  447. scale: Grid cell scale.
  448. stride: Down sample stride of the feature map.
  449. offset: Offset of grid cells.
  450. return:
  451. Grid_cells xyxy position. Size should be [feat_w * feat_h, 4]
  452. """
  453. cell_size = stride * scale
  454. h, w = featmap_size
  455. x_range = (np.arange(w, dtype=np.float32) + offset) * stride
  456. y_range = (np.arange(h, dtype=np.float32) + offset) * stride
  457. x, y = np.meshgrid(x_range, y_range)
  458. y = y.flatten()
  459. x = x.flatten()
  460. grid_cells = np.stack(
  461. [
  462. x - 0.5 * cell_size, y - 0.5 * cell_size, x + 0.5 * cell_size,
  463. y + 0.5 * cell_size
  464. ],
  465. axis=-1)
  466. return grid_cells
  467. def get_sample(self, assign_gt_inds, gt_bboxes):
  468. pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
  469. neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
  470. pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1
  471. if gt_bboxes.size == 0:
  472. # hack for index error case
  473. assert pos_assigned_gt_inds.size == 0
  474. pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
  475. else:
  476. if len(gt_bboxes.shape) < 2:
  477. gt_bboxes = gt_bboxes.resize(-1, 4)
  478. pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
  479. return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds
  480. def __call__(self, samples, context=None):
  481. assert len(samples) > 0
  482. batch_size = len(samples)
  483. # get grid cells of image
  484. h, w = samples[0]['image'].shape[1:3]
  485. multi_level_grid_cells = []
  486. for stride in self.downsample_ratios:
  487. featmap_size = (int(math.ceil(h / stride)),
  488. int(math.ceil(w / stride)))
  489. multi_level_grid_cells.append(
  490. self.get_grid_cells(featmap_size, self.grid_cell_scale, stride,
  491. self.cell_offset))
  492. mlvl_grid_cells_list = [
  493. multi_level_grid_cells for i in range(batch_size)
  494. ]
  495. # pixel cell number of multi-level feature maps
  496. num_level_cells = [
  497. grid_cells.shape[0] for grid_cells in mlvl_grid_cells_list[0]
  498. ]
  499. num_level_cells_list = [num_level_cells] * batch_size
  500. # concat all level cells and to a single array
  501. for i in range(batch_size):
  502. mlvl_grid_cells_list[i] = np.concatenate(mlvl_grid_cells_list[i])
  503. # target assign on all images
  504. for sample, grid_cells, num_level_cells in zip(
  505. samples, mlvl_grid_cells_list, num_level_cells_list):
  506. gt_bboxes = sample['gt_bbox']
  507. gt_labels = sample['gt_class'].squeeze()
  508. if gt_labels.size == 1:
  509. gt_labels = np.array([gt_labels]).astype(np.int32)
  510. gt_bboxes_ignore = None
  511. assign_gt_inds, _ = self.assigner(grid_cells, num_level_cells,
  512. gt_bboxes, gt_bboxes_ignore,
  513. gt_labels)
  514. pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample(
  515. assign_gt_inds, gt_bboxes)
  516. num_cells = grid_cells.shape[0]
  517. bbox_targets = np.zeros_like(grid_cells)
  518. bbox_weights = np.zeros_like(grid_cells)
  519. labels = np.ones([num_cells], dtype=np.int64) * self.num_classes
  520. label_weights = np.zeros([num_cells], dtype=np.float32)
  521. if len(pos_inds) > 0:
  522. pos_bbox_targets = pos_gt_bboxes
  523. bbox_targets[pos_inds, :] = pos_bbox_targets
  524. bbox_weights[pos_inds, :] = 1.0
  525. if not np.any(gt_labels):
  526. labels[pos_inds] = 0
  527. else:
  528. labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
  529. label_weights[pos_inds] = 1.0
  530. if len(neg_inds) > 0:
  531. label_weights[neg_inds] = 1.0
  532. sample['grid_cells'] = grid_cells
  533. sample['labels'] = labels
  534. sample['label_weights'] = label_weights
  535. sample['bbox_targets'] = bbox_targets
  536. sample['pos_num'] = max(pos_inds.size, 1)
  537. sample.pop('is_crowd', None)
  538. sample.pop('difficult', None)
  539. sample.pop('gt_class', None)
  540. sample.pop('gt_bbox', None)
  541. sample.pop('gt_score', None)
  542. return samples
  543. @register_op
  544. class Gt2TTFTarget(BaseOperator):
  545. __shared__ = ['num_classes']
  546. """
  547. Gt2TTFTarget
  548. Generate TTFNet targets by ground truth data
  549. Args:
  550. num_classes(int): the number of classes.
  551. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  552. alpha(float): the alpha parameter to generate gaussian target.
  553. 0.54 by default.
  554. """
  555. def __init__(self, num_classes=80, down_ratio=4, alpha=0.54):
  556. super(Gt2TTFTarget, self).__init__()
  557. self.down_ratio = down_ratio
  558. self.num_classes = num_classes
  559. self.alpha = alpha
  560. def __call__(self, samples, context=None):
  561. output_size = samples[0]['image'].shape[1]
  562. feat_size = output_size // self.down_ratio
  563. for sample in samples:
  564. heatmap = np.zeros(
  565. (self.num_classes, feat_size, feat_size), dtype='float32')
  566. box_target = np.ones(
  567. (4, feat_size, feat_size), dtype='float32') * -1
  568. reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')
  569. gt_bbox = sample['gt_bbox']
  570. gt_class = sample['gt_class']
  571. bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
  572. bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
  573. area = bbox_w * bbox_h
  574. boxes_areas_log = np.log(area)
  575. boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
  576. boxes_area_topk_log = boxes_areas_log[boxes_ind]
  577. gt_bbox = gt_bbox[boxes_ind]
  578. gt_class = gt_class[boxes_ind]
  579. feat_gt_bbox = gt_bbox / self.down_ratio
  580. feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
  581. feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
  582. feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])
  583. ct_inds = np.stack(
  584. [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
  585. (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
  586. axis=1) / self.down_ratio
  587. h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
  588. w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')
  589. for k in range(len(gt_bbox)):
  590. cls_id = gt_class[k]
  591. fake_heatmap = np.zeros(
  592. (feat_size, feat_size), dtype='float32')
  593. self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
  594. h_radiuses_alpha[k],
  595. w_radiuses_alpha[k])
  596. heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
  597. box_target_inds = fake_heatmap > 0
  598. box_target[:, box_target_inds] = gt_bbox[k][:, None]
  599. local_heatmap = fake_heatmap[box_target_inds]
  600. ct_div = np.sum(local_heatmap)
  601. local_heatmap *= boxes_area_topk_log[k]
  602. reg_weight[0, box_target_inds] = local_heatmap / ct_div
  603. sample['ttf_heatmap'] = heatmap
  604. sample['ttf_box_target'] = box_target
  605. sample['ttf_reg_weight'] = reg_weight
  606. sample.pop('is_crowd', None)
  607. sample.pop('difficult', None)
  608. sample.pop('gt_class', None)
  609. sample.pop('gt_bbox', None)
  610. sample.pop('gt_score', None)
  611. return samples
  612. def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
  613. h, w = 2 * h_radius + 1, 2 * w_radius + 1
  614. sigma_x = w / 6
  615. sigma_y = h / 6
  616. gaussian = gaussian2D((h, w), sigma_x, sigma_y)
  617. x, y = int(center[0]), int(center[1])
  618. height, width = heatmap.shape[0:2]
  619. left, right = min(x, w_radius), min(width - x, w_radius + 1)
  620. top, bottom = min(y, h_radius), min(height - y, h_radius + 1)
  621. masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
  622. masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius -
  623. left:w_radius + right]
  624. if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
  625. heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
  626. masked_heatmap, masked_gaussian)
  627. return heatmap
  628. @register_op
  629. class Gt2Solov2Target(BaseOperator):
  630. """Assign mask target and labels in SOLOv2 network.
  631. Args:
  632. num_grids (list): The list of feature map grids size.
  633. scale_ranges (list): The list of mask boundary range.
  634. coord_sigma (float): The coefficient of coordinate area length.
  635. sampling_ratio (float): The ratio of down sampling.
  636. """
  637. def __init__(self,
  638. num_grids=[40, 36, 24, 16, 12],
  639. scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
  640. [384, 2048]],
  641. coord_sigma=0.2,
  642. sampling_ratio=4.0):
  643. super(Gt2Solov2Target, self).__init__()
  644. self.num_grids = num_grids
  645. self.scale_ranges = scale_ranges
  646. self.coord_sigma = coord_sigma
  647. self.sampling_ratio = sampling_ratio
  648. def _scale_size(self, im, scale):
  649. h, w = im.shape[:2]
  650. new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
  651. resized_img = cv2.resize(
  652. im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
  653. return resized_img
  654. def __call__(self, samples, context=None):
  655. sample_id = 0
  656. max_ins_num = [0] * len(self.num_grids)
  657. for sample in samples:
  658. gt_bboxes_raw = sample['gt_bbox']
  659. gt_labels_raw = sample['gt_class'] + 1
  660. im_c, im_h, im_w = sample['image'].shape[:]
  661. gt_masks_raw = sample['gt_segm'].astype(np.uint8)
  662. mask_feat_size = [
  663. int(im_h / self.sampling_ratio),
  664. int(im_w / self.sampling_ratio)
  665. ]
  666. gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
  667. (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
  668. ins_ind_label_list = []
  669. idx = 0
  670. for (lower_bound, upper_bound), num_grid \
  671. in zip(self.scale_ranges, self.num_grids):
  672. hit_indices = ((gt_areas >= lower_bound) &
  673. (gt_areas <= upper_bound)).nonzero()[0]
  674. num_ins = len(hit_indices)
  675. ins_label = []
  676. grid_order = []
  677. cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
  678. ins_ind_label = np.zeros([num_grid**2], dtype=np.bool)
  679. if num_ins == 0:
  680. ins_label = np.zeros(
  681. [1, mask_feat_size[0], mask_feat_size[1]],
  682. dtype=np.uint8)
  683. ins_ind_label_list.append(ins_ind_label)
  684. sample['cate_label{}'.format(idx)] = cate_label.flatten()
  685. sample['ins_label{}'.format(idx)] = ins_label
  686. sample['grid_order{}'.format(idx)] = np.asarray(
  687. [sample_id * num_grid * num_grid + 0], dtype=np.int32)
  688. idx += 1
  689. continue
  690. gt_bboxes = gt_bboxes_raw[hit_indices]
  691. gt_labels = gt_labels_raw[hit_indices]
  692. gt_masks = gt_masks_raw[hit_indices, ...]
  693. half_ws = 0.5 * (
  694. gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
  695. half_hs = 0.5 * (
  696. gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma
  697. for seg_mask, gt_label, half_h, half_w in zip(
  698. gt_masks, gt_labels, half_hs, half_ws):
  699. if seg_mask.sum() == 0:
  700. continue
  701. # mass center
  702. upsampled_size = (mask_feat_size[0] * 4,
  703. mask_feat_size[1] * 4)
  704. center_h, center_w = ndimage.measurements.center_of_mass(
  705. seg_mask)
  706. coord_w = int(
  707. (center_w / upsampled_size[1]) // (1. / num_grid))
  708. coord_h = int(
  709. (center_h / upsampled_size[0]) // (1. / num_grid))
  710. # left, top, right, down
  711. top_box = max(0,
  712. int(((center_h - half_h) / upsampled_size[0])
  713. // (1. / num_grid)))
  714. down_box = min(
  715. num_grid - 1,
  716. int(((center_h + half_h) / upsampled_size[0]) //
  717. (1. / num_grid)))
  718. left_box = max(
  719. 0,
  720. int(((center_w - half_w) / upsampled_size[1]) //
  721. (1. / num_grid)))
  722. right_box = min(num_grid - 1,
  723. int(((center_w + half_w) /
  724. upsampled_size[1]) //
  725. (1. / num_grid)))
  726. top = max(top_box, coord_h - 1)
  727. down = min(down_box, coord_h + 1)
  728. left = max(coord_w - 1, left_box)
  729. right = min(right_box, coord_w + 1)
  730. cate_label[top:(down + 1), left:(right + 1)] = gt_label
  731. seg_mask = self._scale_size(
  732. seg_mask, scale=1. / self.sampling_ratio)
  733. for i in range(top, down + 1):
  734. for j in range(left, right + 1):
  735. label = int(i * num_grid + j)
  736. cur_ins_label = np.zeros(
  737. [mask_feat_size[0], mask_feat_size[1]],
  738. dtype=np.uint8)
  739. cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
  740. 1]] = seg_mask
  741. ins_label.append(cur_ins_label)
  742. ins_ind_label[label] = True
  743. grid_order.append(sample_id * num_grid * num_grid +
  744. label)
  745. if ins_label == []:
  746. ins_label = np.zeros(
  747. [1, mask_feat_size[0], mask_feat_size[1]],
  748. dtype=np.uint8)
  749. ins_ind_label_list.append(ins_ind_label)
  750. sample['cate_label{}'.format(idx)] = cate_label.flatten()
  751. sample['ins_label{}'.format(idx)] = ins_label
  752. sample['grid_order{}'.format(idx)] = np.asarray(
  753. [sample_id * num_grid * num_grid + 0], dtype=np.int32)
  754. else:
  755. ins_label = np.stack(ins_label, axis=0)
  756. ins_ind_label_list.append(ins_ind_label)
  757. sample['cate_label{}'.format(idx)] = cate_label.flatten()
  758. sample['ins_label{}'.format(idx)] = ins_label
  759. sample['grid_order{}'.format(idx)] = np.asarray(
  760. grid_order, dtype=np.int32)
  761. assert len(grid_order) > 0
  762. max_ins_num[idx] = max(
  763. max_ins_num[idx],
  764. sample['ins_label{}'.format(idx)].shape[0])
  765. idx += 1
  766. ins_ind_labels = np.concatenate([
  767. ins_ind_labels_level_img
  768. for ins_ind_labels_level_img in ins_ind_label_list
  769. ])
  770. fg_num = np.sum(ins_ind_labels)
  771. sample['fg_num'] = fg_num
  772. sample_id += 1
  773. sample.pop('is_crowd')
  774. sample.pop('gt_class')
  775. sample.pop('gt_bbox')
  776. sample.pop('gt_poly')
  777. sample.pop('gt_segm')
  778. # padding batch
  779. for data in samples:
  780. for idx in range(len(self.num_grids)):
  781. gt_ins_data = np.zeros(
  782. [
  783. max_ins_num[idx],
  784. data['ins_label{}'.format(idx)].shape[1],
  785. data['ins_label{}'.format(idx)].shape[2]
  786. ],
  787. dtype=np.uint8)
  788. gt_ins_data[0:data['ins_label{}'.format(idx)].shape[
  789. 0], :, :] = data['ins_label{}'.format(idx)]
  790. gt_grid_order = np.zeros([max_ins_num[idx]], dtype=np.int32)
  791. gt_grid_order[0:data['grid_order{}'.format(idx)].shape[
  792. 0]] = data['grid_order{}'.format(idx)]
  793. data['ins_label{}'.format(idx)] = gt_ins_data
  794. data['grid_order{}'.format(idx)] = gt_grid_order
  795. return samples
  796. @register_op
  797. class Gt2SparseRCNNTarget(BaseOperator):
  798. '''
  799. Generate SparseRCNN targets by groud truth data
  800. '''
  801. def __init__(self):
  802. super(Gt2SparseRCNNTarget, self).__init__()
  803. def __call__(self, samples, context=None):
  804. for sample in samples:
  805. im = sample["image"]
  806. h, w = im.shape[1:3]
  807. img_whwh = np.array([w, h, w, h], dtype=np.int32)
  808. sample["img_whwh"] = img_whwh
  809. if "scale_factor" in sample:
  810. sample["scale_factor_wh"] = np.array(
  811. [sample["scale_factor"][1], sample["scale_factor"][0]],
  812. dtype=np.float32)
  813. else:
  814. sample["scale_factor_wh"] = np.array(
  815. [1.0, 1.0], dtype=np.float32)
  816. return samples
  817. @register_op
  818. class PadMaskBatch(BaseOperator):
  819. """
  820. Pad a batch of samples so they can be divisible by a stride.
  821. The layout of each image should be 'CHW'.
  822. Args:
  823. pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
  824. height and width is divisible by `pad_to_stride`.
  825. return_pad_mask (bool): If `return_pad_mask = True`, return
  826. `pad_mask` for transformer.
  827. """
  828. def __init__(self, pad_to_stride=0, return_pad_mask=False):
  829. super(PadMaskBatch, self).__init__()
  830. self.pad_to_stride = pad_to_stride
  831. self.return_pad_mask = return_pad_mask
  832. def __call__(self, samples, context=None):
  833. """
  834. Args:
  835. samples (list): a batch of sample, each is dict.
  836. """
  837. coarsest_stride = self.pad_to_stride
  838. max_shape = np.array([data['image'].shape for data in samples]).max(
  839. axis=0)
  840. if coarsest_stride > 0:
  841. max_shape[1] = int(
  842. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  843. max_shape[2] = int(
  844. np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
  845. for data in samples:
  846. im = data['image']
  847. im_c, im_h, im_w = im.shape[:]
  848. padding_im = np.zeros(
  849. (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
  850. padding_im[:, :im_h, :im_w] = im
  851. data['image'] = padding_im
  852. if 'semantic' in data and data['semantic'] is not None:
  853. semantic = data['semantic']
  854. padding_sem = np.zeros(
  855. (1, max_shape[1], max_shape[2]), dtype=np.float32)
  856. padding_sem[:, :im_h, :im_w] = semantic
  857. data['semantic'] = padding_sem
  858. if 'gt_segm' in data and data['gt_segm'] is not None:
  859. gt_segm = data['gt_segm']
  860. padding_segm = np.zeros(
  861. (gt_segm.shape[0], max_shape[1], max_shape[2]),
  862. dtype=np.uint8)
  863. padding_segm[:, :im_h, :im_w] = gt_segm
  864. data['gt_segm'] = padding_segm
  865. if self.return_pad_mask:
  866. padding_mask = np.zeros(
  867. (max_shape[1], max_shape[2]), dtype=np.float32)
  868. padding_mask[:im_h, :im_w] = 1.
  869. data['pad_mask'] = padding_mask
  870. if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
  871. # ploy to rbox
  872. polys = data['gt_rbox2poly']
  873. rbox = bbox_utils.poly2rbox(polys)
  874. data['gt_rbox'] = rbox
  875. return samples