infer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections.abc
  15. from itertools import combinations
  16. import numpy as np
  17. import cv2
  18. import paddle
  19. import paddle.nn.functional as F
  20. def get_reverse_list(ori_shape, transforms):
  21. """
  22. get reverse list of transform.
  23. Args:
  24. ori_shape (list): Origin shape of image.
  25. transforms (list): List of transform.
  26. Returns:
  27. list: List of tuple, there are two format:
  28. ('resize', (h, w)) The image shape before resize,
  29. ('padding', (h, w)) The image shape before padding.
  30. """
  31. reverse_list = []
  32. h, w = ori_shape[0], ori_shape[1]
  33. for op in transforms:
  34. if op.__class__.__name__ in ['Resize']:
  35. reverse_list.append(('resize', (h, w)))
  36. h, w = op.target_size[0], op.target_size[1]
  37. if op.__class__.__name__ in ['ResizeByLong']:
  38. reverse_list.append(('resize', (h, w)))
  39. long_edge = max(h, w)
  40. short_edge = min(h, w)
  41. short_edge = int(round(short_edge * op.long_size / long_edge))
  42. long_edge = op.long_size
  43. if h > w:
  44. h = long_edge
  45. w = short_edge
  46. else:
  47. w = long_edge
  48. h = short_edge
  49. if op.__class__.__name__ in ['Padding']:
  50. reverse_list.append(('padding', (h, w)))
  51. w, h = op.target_size[0], op.target_size[1]
  52. if op.__class__.__name__ in ['PaddingByAspectRatio']:
  53. reverse_list.append(('padding', (h, w)))
  54. ratio = w / h
  55. if ratio == op.aspect_ratio:
  56. pass
  57. elif ratio > op.aspect_ratio:
  58. h = int(w / op.aspect_ratio)
  59. else:
  60. w = int(h * op.aspect_ratio)
  61. if op.__class__.__name__ in ['LimitLong']:
  62. long_edge = max(h, w)
  63. short_edge = min(h, w)
  64. if ((op.max_long is not None) and (long_edge > op.max_long)):
  65. reverse_list.append(('resize', (h, w)))
  66. long_edge = op.max_long
  67. short_edge = int(round(short_edge * op.max_long / long_edge))
  68. elif ((op.min_long is not None) and (long_edge < op.min_long)):
  69. reverse_list.append(('resize', (h, w)))
  70. long_edge = op.min_long
  71. short_edge = int(round(short_edge * op.min_long / long_edge))
  72. if h > w:
  73. h = long_edge
  74. w = short_edge
  75. else:
  76. w = long_edge
  77. h = short_edge
  78. return reverse_list
  79. def reverse_transform(pred, ori_shape, transforms, mode='nearest'):
  80. """recover pred to origin shape"""
  81. reverse_list = get_reverse_list(ori_shape, transforms)
  82. for item in reverse_list[::-1]:
  83. if item[0] == 'resize':
  84. h, w = item[1][0], item[1][1]
  85. if paddle.get_device() == 'cpu':
  86. pred = paddle.cast(pred, 'uint8')
  87. pred = F.interpolate(pred, (h, w), mode=mode)
  88. pred = paddle.cast(pred, 'int32')
  89. else:
  90. pred = F.interpolate(pred, (h, w), mode=mode)
  91. elif item[0] == 'padding':
  92. h, w = item[1][0], item[1][1]
  93. pred = pred[:, :, 0:h, 0:w]
  94. else:
  95. raise Exception("Unexpected info '{}' in im_info".format(item[0]))
  96. return pred
  97. def flip_combination(flip_horizontal=False, flip_vertical=False):
  98. """
  99. Get flip combination.
  100. Args:
  101. flip_horizontal (bool): Whether to flip horizontally. Default: False.
  102. flip_vertical (bool): Whether to flip vertically. Default: False.
  103. Returns:
  104. list: List of tuple. The first element of tuple is whether to flip horizontally,
  105. and the second is whether to flip vertically.
  106. """
  107. flip_comb = [(False, False)]
  108. if flip_horizontal:
  109. flip_comb.append((True, False))
  110. if flip_vertical:
  111. flip_comb.append((False, True))
  112. if flip_horizontal:
  113. flip_comb.append((True, True))
  114. return flip_comb
  115. def tensor_flip(x, flip):
  116. """Flip tensor according directions"""
  117. if flip[0]:
  118. x = x[:, :, :, ::-1]
  119. if flip[1]:
  120. x = x[:, :, ::-1, :]
  121. return x
  122. def slide_inference(model, im, crop_size, stride):
  123. """
  124. Infer by sliding window.
  125. Args:
  126. model (paddle.nn.Layer): model to get logits of image.
  127. im (Tensor): the input image.
  128. crop_size (tuple|list). The size of sliding window, (w, h).
  129. stride (tuple|list). The size of stride, (w, h).
  130. Return:
  131. Tensor: The logit of input image.
  132. """
  133. h_im, w_im = im.shape[-2:]
  134. w_crop, h_crop = crop_size
  135. w_stride, h_stride = stride
  136. # calculate the crop nums
  137. rows = np.int(np.ceil(1.0 * (h_im - h_crop) / h_stride)) + 1
  138. cols = np.int(np.ceil(1.0 * (w_im - w_crop) / w_stride)) + 1
  139. # prevent negative sliding rounds when imgs after scaling << crop_size
  140. rows = 1 if h_im <= h_crop else rows
  141. cols = 1 if w_im <= w_crop else cols
  142. # TODO 'Tensor' object does not support item assignment. If support, use tensor to calculation.
  143. final_logit = None
  144. count = np.zeros([1, 1, h_im, w_im])
  145. for r in range(rows):
  146. for c in range(cols):
  147. h1 = r * h_stride
  148. w1 = c * w_stride
  149. h2 = min(h1 + h_crop, h_im)
  150. w2 = min(w1 + w_crop, w_im)
  151. h1 = max(h2 - h_crop, 0)
  152. w1 = max(w2 - w_crop, 0)
  153. im_crop = im[:, :, h1:h2, w1:w2]
  154. logits = model(im_crop)
  155. if not isinstance(logits, collections.abc.Sequence):
  156. raise TypeError(
  157. "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
  158. .format(type(logits)))
  159. logit = logits[0].numpy()
  160. if final_logit is None:
  161. final_logit = np.zeros([1, logit.shape[1], h_im, w_im])
  162. final_logit[:, :, h1:h2, w1:w2] += logit[:, :, :h2 - h1, :w2 - w1]
  163. count[:, :, h1:h2, w1:w2] += 1
  164. if np.sum(count == 0) != 0:
  165. raise RuntimeError(
  166. 'There are pixel not predicted. It is possible that stride is greater than crop_size'
  167. )
  168. final_logit = final_logit / count
  169. final_logit = paddle.to_tensor(final_logit)
  170. return final_logit
  171. def inference(model,
  172. im,
  173. ori_shape=None,
  174. transforms=None,
  175. is_slide=False,
  176. stride=None,
  177. crop_size=None):
  178. """
  179. Inference for image.
  180. Args:
  181. model (paddle.nn.Layer): model to get logits of image.
  182. im (Tensor): the input image.
  183. ori_shape (list): Origin shape of image.
  184. transforms (list): Transforms for image.
  185. is_slide (bool): Whether to infer by sliding window. Default: False.
  186. crop_size (tuple|list). The size of sliding window, (w, h). It should be probided if is_slide is True.
  187. stride (tuple|list). The size of stride, (w, h). It should be probided if is_slide is True.
  188. Returns:
  189. Tensor: If ori_shape is not None, a prediction with shape (1, 1, h, w) is returned.
  190. If ori_shape is None, a logit with shape (1, num_classes, h, w) is returned.
  191. """
  192. if not is_slide:
  193. logits = model(im)
  194. if not isinstance(logits, collections.abc.Sequence):
  195. raise TypeError(
  196. "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
  197. .format(type(logits)))
  198. logit = logits[0]
  199. else:
  200. logit = slide_inference(model, im, crop_size=crop_size, stride=stride)
  201. if ori_shape is not None:
  202. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  203. pred = reverse_transform(pred, ori_shape, transforms)
  204. return pred
  205. else:
  206. return logit
  207. def aug_inference(model,
  208. im,
  209. ori_shape,
  210. transforms,
  211. scales=1.0,
  212. flip_horizontal=False,
  213. flip_vertical=False,
  214. is_slide=False,
  215. stride=None,
  216. crop_size=None):
  217. """
  218. Infer with augmentation.
  219. Args:
  220. model (paddle.nn.Layer): model to get logits of image.
  221. im (Tensor): the input image.
  222. ori_shape (list): Origin shape of image.
  223. transforms (list): Transforms for image.
  224. scales (float|tuple|list): Scales for resize. Default: 1.
  225. flip_horizontal (bool): Whether to flip horizontally. Default: False.
  226. flip_vertical (bool): Whether to flip vertically. Default: False.
  227. is_slide (bool): Whether to infer by sliding wimdow. Default: False.
  228. crop_size (tuple|list). The size of sliding window, (w, h). It should be probided if is_slide is True.
  229. stride (tuple|list). The size of stride, (w, h). It should be probided if is_slide is True.
  230. Returns:
  231. Tensor: Prediction of image with shape (1, 1, h, w) is returned.
  232. """
  233. if isinstance(scales, float):
  234. scales = [scales]
  235. elif not isinstance(scales, (tuple, list)):
  236. raise TypeError(
  237. '`scales` expects float/tuple/list type, but received {}'.format(
  238. type(scales)))
  239. final_logit = 0
  240. h_input, w_input = im.shape[-2], im.shape[-1]
  241. flip_comb = flip_combination(flip_horizontal, flip_vertical)
  242. for scale in scales:
  243. h = int(h_input * scale + 0.5)
  244. w = int(w_input * scale + 0.5)
  245. im = F.interpolate(im, (h, w), mode='bilinear')
  246. for flip in flip_comb:
  247. im_flip = tensor_flip(im, flip)
  248. logit = inference(
  249. model,
  250. im_flip,
  251. is_slide=is_slide,
  252. crop_size=crop_size,
  253. stride=stride)
  254. logit = tensor_flip(logit, flip)
  255. logit = F.interpolate(logit, (h_input, w_input), mode='bilinear')
  256. logit = F.softmax(logit, axis=1)
  257. final_logit = final_logit + logit
  258. pred = paddle.argmax(final_logit, axis=1, keepdim=True, dtype='int32')
  259. pred = reverse_transform(pred, ori_shape, transforms)
  260. return pred