batch_operators.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. import multiprocessing as mp
  16. import random
  17. import numpy as np
  18. try:
  19. from collections.abc import Sequence
  20. except Exception:
  21. from collections import Sequence
  22. from paddle.fluid.dataloader.collate import default_collate_fn
  23. from .operators import Transform, Resize, ResizeByShort, _Permute, interp_dict
  24. from .box_utils import jaccard_overlap
  25. from paddlex.utils import logging
  26. class BatchCompose(Transform):
  27. def __init__(self, batch_transforms=None):
  28. super(BatchCompose, self).__init__()
  29. self.batch_transforms = batch_transforms
  30. self.lock = mp.Lock()
  31. def __call__(self, samples):
  32. if self.batch_transforms is not None:
  33. for op in self.batch_transforms:
  34. try:
  35. samples = op(samples)
  36. except Exception as e:
  37. stack_info = traceback.format_exc()
  38. logging.warning("fail to map batch transform [{}] "
  39. "with error: {} and stack:\n{}".format(
  40. op, e, str(stack_info)))
  41. raise e
  42. samples = _Permute()(samples)
  43. batch_data = default_collate_fn(samples)
  44. return batch_data
  45. class BatchRandomResize(Transform):
  46. """
  47. Resize a batch of input to random sizes.
  48. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  49. Args:
  50. target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
  51. Multiple target sizes, each target size is an int or list/tuple of length 2.
  52. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  53. Interpolation method of resize. Defaults to 'LINEAR'.
  54. Raises:
  55. TypeError: Invalid type of target_size.
  56. ValueError: Invalid interpolation method.
  57. See Also:
  58. RandomResize: Resize input to random sizes.
  59. """
  60. def __init__(self, target_sizes, interp='NEAREST'):
  61. super(BatchRandomResize, self).__init__()
  62. if not (interp == "RANDOM" or interp in interp_dict):
  63. raise ValueError("interp should be one of {}".format(
  64. interp_dict.keys()))
  65. self.interp = interp
  66. assert isinstance(target_sizes, list), \
  67. "target_size must be List"
  68. for i, item in enumerate(target_sizes):
  69. if isinstance(item, int):
  70. target_sizes[i] = (item, item)
  71. self.target_size = target_sizes
  72. def __call__(self, samples):
  73. height, width = random.choice(self.target_size)
  74. resizer = Resize((height, width), interp=self.interp)
  75. samples = resizer(samples)
  76. return samples
  77. class BatchRandomResizeByShort(Transform):
  78. """Resize a batch of input to random sizes with keeping the aspect ratio.
  79. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  80. Args:
  81. short_sizes (List[int], Tuple[int]): Target sizes of the shorter side of the image(s).
  82. max_size (int, optional): The upper bound of longer side of the image(s).
  83. If max_size is -1, no upper bound is applied. Defaults to -1.
  84. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  85. Interpolation method of resize. Defaults to 'LINEAR'.
  86. Raises:
  87. TypeError: Invalid type of target_size.
  88. ValueError: Invalid interpolation method.
  89. See Also:
  90. RandomResizeByShort: Resize input to random sizes with keeping the aspect ratio.
  91. """
  92. def __init__(self, short_sizes, max_size=-1, interp='NEAREST'):
  93. super(BatchRandomResizeByShort, self).__init__()
  94. if not (interp == "RANDOM" or interp in interp_dict):
  95. raise ValueError("interp should be one of {}".format(
  96. interp_dict.keys()))
  97. self.interp = interp
  98. assert isinstance(short_sizes, list), \
  99. "short_sizes must be List"
  100. self.short_sizes = short_sizes
  101. self.max_size = max_size
  102. def __call__(self, samples):
  103. short_size = random.choice(self.short_sizes)
  104. resizer = ResizeByShort(
  105. short_size=short_size, max_size=self.max_size, interp=self.interp)
  106. samples = resizer(samples)
  107. return samples
  108. class _BatchPadding(Transform):
  109. def __init__(self, pad_to_stride=0, pad_gt=False):
  110. super(_BatchPadding, self).__init__()
  111. self.pad_to_stride = pad_to_stride
  112. self.pad_gt = pad_gt
  113. def __call__(self, samples):
  114. coarsest_stride = self.pad_to_stride
  115. max_shape = np.array([data['image'].shape for data in samples]).max(
  116. axis=0)
  117. if coarsest_stride > 0:
  118. max_shape[0] = int(
  119. np.ceil(max_shape[0] / coarsest_stride) * coarsest_stride)
  120. max_shape[1] = int(
  121. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  122. for data in samples:
  123. im = data['image']
  124. im_h, im_w, im_c = im.shape[:]
  125. padding_im = np.zeros(
  126. (max_shape[0], max_shape[1], im_c), dtype=np.float32)
  127. padding_im[:im_h, :im_w, :] = im
  128. data['image'] = padding_im
  129. if self.pad_gt:
  130. gt_num = []
  131. if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
  132. 'gt_poly']) > 0:
  133. pad_mask = True
  134. else:
  135. pad_mask = False
  136. if pad_mask:
  137. poly_num = []
  138. poly_part_num = []
  139. point_num = []
  140. for data in samples:
  141. gt_num.append(data['gt_bbox'].shape[0])
  142. if pad_mask:
  143. poly_num.append(len(data['gt_poly']))
  144. for poly in data['gt_poly']:
  145. poly_part_num.append(int(len(poly)))
  146. for p_p in poly:
  147. point_num.append(int(len(p_p) / 2))
  148. gt_num_max = max(gt_num)
  149. for i, data in enumerate(samples):
  150. gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
  151. gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
  152. is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
  153. if pad_mask:
  154. poly_num_max = max(poly_num)
  155. poly_part_num_max = max(poly_part_num)
  156. point_num_max = max(point_num)
  157. gt_masks_data = -np.ones(
  158. [poly_num_max, poly_part_num_max, point_num_max, 2],
  159. dtype=np.float32)
  160. gt_num = data['gt_bbox'].shape[0]
  161. gt_box_data[0:gt_num, :] = data['gt_bbox']
  162. gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
  163. if 'is_crowd' in data:
  164. is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
  165. data['is_crowd'] = is_crowd_data
  166. data['gt_bbox'] = gt_box_data
  167. data['gt_class'] = gt_class_data
  168. if pad_mask:
  169. for j, poly in enumerate(data['gt_poly']):
  170. for k, p_p in enumerate(poly):
  171. pp_np = np.array(p_p).reshape(-1, 2)
  172. gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
  173. data['gt_poly'] = gt_masks_data
  174. if 'gt_score' in data:
  175. gt_score_data = np.zeros([gt_num_max], dtype=np.float32)
  176. gt_score_data[0:gt_num] = data['gt_score'][:gt_num, 0]
  177. data['gt_score'] = gt_score_data
  178. if 'difficult' in data:
  179. diff_data = np.zeros([gt_num_max], dtype=np.int32)
  180. diff_data[0:gt_num] = data['difficult'][:gt_num, 0]
  181. data['difficult'] = diff_data
  182. return samples
  183. class _Gt2YoloTarget(Transform):
  184. """
  185. Generate YOLOv3 targets by groud truth data, this operator is only used in
  186. fine grained YOLOv3 loss mode
  187. """
  188. def __init__(self,
  189. anchors,
  190. anchor_masks,
  191. downsample_ratios,
  192. num_classes=80,
  193. iou_thresh=1.):
  194. super(_Gt2YoloTarget, self).__init__()
  195. self.anchors = anchors
  196. self.anchor_masks = anchor_masks
  197. self.downsample_ratios = downsample_ratios
  198. self.num_classes = num_classes
  199. self.iou_thresh = iou_thresh
  200. def __call__(self, samples, context=None):
  201. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  202. "anchor_masks', and 'downsample_ratios' should have same length."
  203. h, w = samples[0]['image'].shape[:2]
  204. an_hw = np.array(self.anchors) / np.array([[w, h]])
  205. for sample in samples:
  206. gt_bbox = sample['gt_bbox']
  207. gt_class = sample['gt_class']
  208. if 'gt_score' not in sample:
  209. sample['gt_score'] = np.ones(
  210. (gt_bbox.shape[0], 1), dtype=np.float32)
  211. gt_score = sample['gt_score']
  212. for i, (
  213. mask, downsample_ratio
  214. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  215. grid_h = int(h / downsample_ratio)
  216. grid_w = int(w / downsample_ratio)
  217. target = np.zeros(
  218. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  219. dtype=np.float32)
  220. for b in range(gt_bbox.shape[0]):
  221. gx, gy, gw, gh = gt_bbox[b, :]
  222. cls = gt_class[b]
  223. score = gt_score[b]
  224. if gw <= 0. or gh <= 0. or score <= 0.:
  225. continue
  226. # find best match anchor index
  227. best_iou = 0.
  228. best_idx = -1
  229. for an_idx in range(an_hw.shape[0]):
  230. iou = jaccard_overlap(
  231. [0., 0., gw, gh],
  232. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  233. if iou > best_iou:
  234. best_iou = iou
  235. best_idx = an_idx
  236. gi = int(gx * grid_w)
  237. gj = int(gy * grid_h)
  238. # gtbox should be regresed in this layes if best match
  239. # anchor index in anchor mask of this layer
  240. if best_idx in mask:
  241. best_n = mask.index(best_idx)
  242. # x, y, w, h, scale
  243. target[best_n, 0, gj, gi] = gx * grid_w - gi
  244. target[best_n, 1, gj, gi] = gy * grid_h - gj
  245. target[best_n, 2, gj, gi] = np.log(
  246. gw * w / self.anchors[best_idx][0])
  247. target[best_n, 3, gj, gi] = np.log(
  248. gh * h / self.anchors[best_idx][1])
  249. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  250. # objectness record gt_score
  251. target[best_n, 5, gj, gi] = score
  252. # classification
  253. target[best_n, 6 + cls, gj, gi] = 1.
  254. # For non-matched anchors, calculate the target if the iou
  255. # between anchor and gt is larger than iou_thresh
  256. if self.iou_thresh < 1:
  257. for idx, mask_i in enumerate(mask):
  258. if mask_i == best_idx: continue
  259. iou = jaccard_overlap(
  260. [0., 0., gw, gh],
  261. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  262. if iou > self.iou_thresh and target[idx, 5, gj,
  263. gi] == 0.:
  264. # x, y, w, h, scale
  265. target[idx, 0, gj, gi] = gx * grid_w - gi
  266. target[idx, 1, gj, gi] = gy * grid_h - gj
  267. target[idx, 2, gj, gi] = np.log(
  268. gw * w / self.anchors[mask_i][0])
  269. target[idx, 3, gj, gi] = np.log(
  270. gh * h / self.anchors[mask_i][1])
  271. target[idx, 4, gj, gi] = 2.0 - gw * gh
  272. # objectness record gt_score
  273. target[idx, 5, gj, gi] = score
  274. # classification
  275. target[idx, 5 + cls, gj, gi] = 1.
  276. sample['target{}'.format(i)] = target
  277. # remove useless gt_class and gt_score after target calculated
  278. sample.pop('gt_class')
  279. sample.pop('gt_score')
  280. return samples