operators.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. from functools import partial
  19. import six
  20. import math
  21. import random
  22. import cv2
  23. import numpy as np
  24. from PIL import Image
  25. from paddle.vision.transforms import ColorJitter as RawColorJitter
  26. from .autoaugment import ImageNetPolicy
  27. from .functional import augmentations
  28. from paddlex.ppcls.utils import logger
  29. class UnifiedResize(object):
  30. def __init__(self, interpolation=None, backend="cv2"):
  31. _cv2_interp_from_str = {
  32. 'nearest': cv2.INTER_NEAREST,
  33. 'bilinear': cv2.INTER_LINEAR,
  34. 'area': cv2.INTER_AREA,
  35. 'bicubic': cv2.INTER_CUBIC,
  36. 'lanczos': cv2.INTER_LANCZOS4
  37. }
  38. _pil_interp_from_str = {
  39. 'nearest': Image.NEAREST,
  40. 'bilinear': Image.BILINEAR,
  41. 'bicubic': Image.BICUBIC,
  42. 'box': Image.BOX,
  43. 'lanczos': Image.LANCZOS,
  44. 'hamming': Image.HAMMING
  45. }
  46. def _pil_resize(src, size, resample):
  47. pil_img = Image.fromarray(src)
  48. pil_img = pil_img.resize(size, resample)
  49. return np.asarray(pil_img)
  50. if backend.lower() == "cv2":
  51. if isinstance(interpolation, str):
  52. interpolation = _cv2_interp_from_str[interpolation.lower()]
  53. # compatible with opencv < version 4.4.0
  54. elif interpolation is None:
  55. interpolation = cv2.INTER_LINEAR
  56. self.resize_func = partial(cv2.resize, interpolation=interpolation)
  57. elif backend.lower() == "pil":
  58. if isinstance(interpolation, str):
  59. interpolation = _pil_interp_from_str[interpolation.lower()]
  60. self.resize_func = partial(_pil_resize, resample=interpolation)
  61. else:
  62. logger.warning(
  63. f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead."
  64. )
  65. self.resize_func = cv2.resize
  66. def __call__(self, src, size):
  67. return self.resize_func(src, size)
  68. class OperatorParamError(ValueError):
  69. """ OperatorParamError
  70. """
  71. pass
  72. class DecodeImage(object):
  73. """ decode image """
  74. def __init__(self, to_rgb=True, to_np=False, channel_first=False):
  75. self.to_rgb = to_rgb
  76. self.to_np = to_np # to numpy
  77. self.channel_first = channel_first # only enabled when to_np is True
  78. def __call__(self, img):
  79. if six.PY2:
  80. assert type(img) is str and len(
  81. img) > 0, "invalid input 'img' in DecodeImage"
  82. else:
  83. assert type(img) is bytes and len(
  84. img) > 0, "invalid input 'img' in DecodeImage"
  85. data = np.frombuffer(img, dtype='uint8')
  86. img = cv2.imdecode(data, 1)
  87. if self.to_rgb:
  88. assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
  89. img.shape)
  90. img = img[:, :, ::-1]
  91. if self.channel_first:
  92. img = img.transpose((2, 0, 1))
  93. return img
  94. class ResizeImage(object):
  95. """ resize image """
  96. def __init__(self,
  97. size=None,
  98. resize_short=None,
  99. interpolation=None,
  100. backend="cv2"):
  101. if resize_short is not None and resize_short > 0:
  102. self.resize_short = resize_short
  103. self.w = None
  104. self.h = None
  105. elif size is not None:
  106. self.resize_short = None
  107. self.w = size if type(size) is int else size[0]
  108. self.h = size if type(size) is int else size[1]
  109. else:
  110. raise OperatorParamError("invalid params for ReisizeImage for '\
  111. 'both 'size' and 'resize_short' are None")
  112. self._resize_func = UnifiedResize(
  113. interpolation=interpolation, backend=backend)
  114. def __call__(self, img):
  115. img_h, img_w = img.shape[:2]
  116. if self.resize_short is not None:
  117. percent = float(self.resize_short) / min(img_w, img_h)
  118. w = int(round(img_w * percent))
  119. h = int(round(img_h * percent))
  120. else:
  121. w = self.w
  122. h = self.h
  123. return self._resize_func(img, (w, h))
  124. class CropImage(object):
  125. """ crop image """
  126. def __init__(self, size):
  127. if type(size) is int:
  128. self.size = (size, size)
  129. else:
  130. self.size = size # (h, w)
  131. def __call__(self, img):
  132. w, h = self.size
  133. img_h, img_w = img.shape[:2]
  134. w_start = (img_w - w) // 2
  135. h_start = (img_h - h) // 2
  136. w_end = w_start + w
  137. h_end = h_start + h
  138. return img[h_start:h_end, w_start:w_end, :]
  139. class RandCropImage(object):
  140. """ random crop image """
  141. def __init__(self,
  142. size,
  143. scale=None,
  144. ratio=None,
  145. interpolation=None,
  146. backend="cv2"):
  147. if type(size) is int:
  148. self.size = (size, size) # (h, w)
  149. else:
  150. self.size = size
  151. self.scale = [0.08, 1.0] if scale is None else scale
  152. self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
  153. self._resize_func = UnifiedResize(
  154. interpolation=interpolation, backend=backend)
  155. def __call__(self, img):
  156. size = self.size
  157. scale = self.scale
  158. ratio = self.ratio
  159. aspect_ratio = math.sqrt(random.uniform(*ratio))
  160. w = 1. * aspect_ratio
  161. h = 1. / aspect_ratio
  162. img_h, img_w = img.shape[:2]
  163. bound = min((float(img_w) / img_h) / (w**2),
  164. (float(img_h) / img_w) / (h**2))
  165. scale_max = min(scale[1], bound)
  166. scale_min = min(scale[0], bound)
  167. target_area = img_w * img_h * random.uniform(scale_min, scale_max)
  168. target_size = math.sqrt(target_area)
  169. w = int(target_size * w)
  170. h = int(target_size * h)
  171. i = random.randint(0, img_w - w)
  172. j = random.randint(0, img_h - h)
  173. img = img[j:j + h, i:i + w, :]
  174. return self._resize_func(img, size)
  175. class RandFlipImage(object):
  176. """ random flip image
  177. flip_code:
  178. 1: Flipped Horizontally
  179. 0: Flipped Vertically
  180. -1: Flipped Horizontally & Vertically
  181. """
  182. def __init__(self, flip_code=1):
  183. assert flip_code in [-1, 0, 1
  184. ], "flip_code should be a value in [-1, 0, 1]"
  185. self.flip_code = flip_code
  186. def __call__(self, img):
  187. if random.randint(0, 1) == 1:
  188. return cv2.flip(img, self.flip_code)
  189. else:
  190. return img
  191. class AutoAugment(object):
  192. def __init__(self):
  193. self.policy = ImageNetPolicy()
  194. def __call__(self, img):
  195. from PIL import Image
  196. img = np.ascontiguousarray(img)
  197. img = Image.fromarray(img)
  198. img = self.policy(img)
  199. img = np.asarray(img)
  200. class NormalizeImage(object):
  201. """ normalize image such as substract mean, divide std
  202. """
  203. def __init__(self,
  204. scale=None,
  205. mean=None,
  206. std=None,
  207. order='chw',
  208. output_fp16=False,
  209. channel_num=3):
  210. if isinstance(scale, str):
  211. scale = eval(scale)
  212. assert channel_num in [
  213. 3, 4
  214. ], "channel number of input image should be set to 3 or 4."
  215. self.channel_num = channel_num
  216. self.output_dtype = 'float16' if output_fp16 else 'float32'
  217. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  218. self.order = order
  219. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  220. std = std if std is not None else [0.229, 0.224, 0.225]
  221. shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
  222. self.mean = np.array(mean).reshape(shape).astype('float32')
  223. self.std = np.array(std).reshape(shape).astype('float32')
  224. def __call__(self, img):
  225. from PIL import Image
  226. if isinstance(img, Image.Image):
  227. img = np.array(img)
  228. assert isinstance(img,
  229. np.ndarray), "invalid input 'img' in NormalizeImage"
  230. img = (img.astype('float32') * self.scale - self.mean) / self.std
  231. if self.channel_num == 4:
  232. img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
  233. img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
  234. pad_zeros = np.zeros(
  235. (1, img_h, img_w)) if self.order == 'chw' else np.zeros(
  236. (img_h, img_w, 1))
  237. img = (np.concatenate(
  238. (img, pad_zeros), axis=0)
  239. if self.order == 'chw' else np.concatenate(
  240. (img, pad_zeros), axis=2))
  241. return img.astype(self.output_dtype)
  242. class ToCHWImage(object):
  243. """ convert hwc image to chw image
  244. """
  245. def __init__(self):
  246. pass
  247. def __call__(self, img):
  248. from PIL import Image
  249. if isinstance(img, Image.Image):
  250. img = np.array(img)
  251. return img.transpose((2, 0, 1))
  252. class AugMix(object):
  253. """ Perform AugMix augmentation and compute mixture.
  254. """
  255. def __init__(self,
  256. prob=0.5,
  257. aug_prob_coeff=0.1,
  258. mixture_width=3,
  259. mixture_depth=1,
  260. aug_severity=1):
  261. """
  262. Args:
  263. prob: Probability of taking augmix
  264. aug_prob_coeff: Probability distribution coefficients.
  265. mixture_width: Number of augmentation chains to mix per augmented example.
  266. mixture_depth: Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]'
  267. aug_severity: Severity of underlying augmentation operators (between 1 to 10).
  268. """
  269. # fmt: off
  270. self.prob = prob
  271. self.aug_prob_coeff = aug_prob_coeff
  272. self.mixture_width = mixture_width
  273. self.mixture_depth = mixture_depth
  274. self.aug_severity = aug_severity
  275. self.augmentations = augmentations
  276. # fmt: on
  277. def __call__(self, image):
  278. """Perform AugMix augmentations and compute mixture.
  279. Returns:
  280. mixed: Augmented and mixed image.
  281. """
  282. if random.random() > self.prob:
  283. # Avoid the warning: the given NumPy array is not writeable
  284. return np.asarray(image).copy()
  285. ws = np.float32(
  286. np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
  287. m = np.float32(
  288. np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
  289. # image = Image.fromarray(image)
  290. mix = np.zeros([image.shape[1], image.shape[0], 3])
  291. for i in range(self.mixture_width):
  292. image_aug = image.copy()
  293. image_aug = Image.fromarray(image_aug)
  294. depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(
  295. 1, 4)
  296. for _ in range(depth):
  297. op = np.random.choice(self.augmentations)
  298. image_aug = op(image_aug, self.aug_severity)
  299. mix += ws[i] * np.asarray(image_aug)
  300. mixed = (1 - m) * image + m * mix
  301. return mixed.astype(np.uint8)
  302. class ColorJitter(RawColorJitter):
  303. """ColorJitter.
  304. """
  305. def __init__(self, *args, **kwargs):
  306. super().__init__(*args, **kwargs)
  307. def __call__(self, img):
  308. if not isinstance(img, Image.Image):
  309. img = np.ascontiguousarray(img)
  310. img = Image.fromarray(img)
  311. img = super()._apply_image(img)
  312. if isinstance(img, Image.Image):
  313. img = np.asarray(img)
  314. return img