| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
- from functools import partial
- import six
- import math
- import random
- import cv2
- import numpy as np
- from PIL import Image
- from paddle.vision.transforms import ColorJitter as RawColorJitter
- from .autoaugment import ImageNetPolicy
- from .functional import augmentations
- from paddlex.ppcls.utils import logger
class UnifiedResize(object):
    """Resize callable that hides the differences between cv2 and PIL.

    Args:
        interpolation: Interpolation mode. A ``str`` is looked up
            case-insensitively in the table of the selected backend; any
            other non-``None`` value is handed to the backend untouched.
            ``None`` becomes ``cv2.INTER_LINEAR`` for the cv2 backend and is
            passed as-is (``resample=None``) for the PIL backend.
        backend: ``"cv2"`` or ``"pil"`` (case-insensitive). Any other value
            logs a warning and falls back to plain ``cv2.resize``; in that
            fallback ``interpolation`` is not applied.

    Raises:
        KeyError: If ``interpolation`` is a string missing from the selected
            backend's table.
    """

    def __init__(self, interpolation=None, backend="cv2"):
        _cv2_interp_from_str = {
            'nearest': cv2.INTER_NEAREST,
            'bilinear': cv2.INTER_LINEAR,
            'area': cv2.INTER_AREA,
            'bicubic': cv2.INTER_CUBIC,
            'lanczos': cv2.INTER_LANCZOS4,
        }
        _pil_interp_from_str = {
            'nearest': Image.NEAREST,
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC,
            'box': Image.BOX,
            'lanczos': Image.LANCZOS,
            'hamming': Image.HAMMING,
        }

        def _pil_resize(src, size, resample):
            # Round-trip through PIL so callers can keep working on arrays.
            pil_img = Image.fromarray(src)
            pil_img = pil_img.resize(size, resample)
            return np.asarray(pil_img)

        backend_name = backend.lower()
        if backend_name == "cv2":
            if isinstance(interpolation, str):
                interpolation = _cv2_interp_from_str[interpolation.lower()]
            # compatible with opencv < version 4.4.0
            elif interpolation is None:
                interpolation = cv2.INTER_LINEAR
            self.resize_func = partial(cv2.resize, interpolation=interpolation)
        elif backend_name == "pil":
            if isinstance(interpolation, str):
                interpolation = _pil_interp_from_str[interpolation.lower()]
            self.resize_func = partial(_pil_resize, resample=interpolation)
        else:
            # The message previously contained a stray "f" in front of the
            # placeholder, so the reported backend name was wrong.
            logger.warning(
                f'The backend of Resize only support "cv2" or "PIL". "{backend}" is unavailable. Use "cv2" instead.'
            )
            self.resize_func = cv2.resize

    def __call__(self, src, size):
        """Resize ``src`` to ``size``; both are forwarded to the backend as-is."""
        return self.resize_func(src, size)
class OperatorParamError(ValueError):
    """Error raised when an operator is configured with invalid parameters."""
class DecodeImage(object):
    """Decode an encoded image byte string into an array with ``cv2.imdecode``.

    Args:
        to_rgb: When true, reverse the last (channel) axis of the decoded
            array. cv2 decodes colour images as BGR, so this yields RGB.
        to_np: Stored on the instance; not consulted by ``__call__``.
        channel_first: When true, transpose the result from HWC to CHW.
    """

    def __init__(self, to_rgb=True, to_np=False, channel_first=False):
        self.to_rgb = to_rgb
        self.to_np = to_np  # to numpy
        self.channel_first = channel_first  # only enabled when to_np is True

    def __call__(self, img):
        """Decode ``img`` (``bytes``; ``str`` on Python 2) and return the array.

        Raises:
            AssertionError: If ``img`` has the wrong type or is empty, or if
                the decoded array does not have three channels while
                ``to_rgb`` is set.
        """
        if six.PY2:
            assert type(img) is str and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert type(img) is bytes and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        data = np.frombuffer(img, dtype='uint8')
        # NOTE(review): cv2.imdecode yields None for data it cannot decode;
        # that case is not handled here and surfaces as an AttributeError
        # below (or as a None return when both flags are off).
        img = cv2.imdecode(data, 1)
        if self.to_rgb:
            # The shape must be wrapped in a 1-tuple: formatting with the bare
            # shape tuple would raise TypeError instead of reporting it.
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
                img.shape, )
            img = img[:, :, ::-1]

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        return img
class ResizeImage(object):
    """Resize an image either to a fixed size or by its shorter side.

    Args:
        size: Target size, used when ``resize_short`` is not a positive
            number. An ``int`` gives a square target; otherwise ``size[0]``
            is the width and ``size[1]`` the height.
        resize_short: If positive, scale the image so that its shorter side
            equals this value while keeping the aspect ratio. Takes
            precedence over ``size``.
        interpolation: Forwarded to ``UnifiedResize``.
        backend: Forwarded to ``UnifiedResize``.

    Raises:
        OperatorParamError: If ``resize_short`` is not positive and ``size``
            is ``None``.
    """

    def __init__(self,
                 size=None,
                 resize_short=None,
                 interpolation=None,
                 backend="cv2"):
        if resize_short is not None and resize_short > 0:
            self.resize_short = resize_short
            self.w = None
            self.h = None
        elif size is not None:
            self.resize_short = None
            self.w = size if type(size) is int else size[0]
            self.h = size if type(size) is int else size[1]
        else:
            # The previous literal was split with a backslash inside the
            # string and misspelled the class name, producing a garbled text.
            raise OperatorParamError(
                "invalid params for ResizeImage: neither a positive "
                "'resize_short' nor 'size' was given")

        self._resize_func = UnifiedResize(
            interpolation=interpolation, backend=backend)

    def __call__(self, img):
        """Return ``img`` resized; the target is passed on as ``(w, h)``."""
        img_h, img_w = img.shape[:2]
        if self.resize_short is not None:
            # Same factor on both axes keeps the aspect ratio.
            percent = float(self.resize_short) / min(img_w, img_h)
            w = int(round(img_w * percent))
            h = int(round(img_h * percent))
        else:
            w = self.w
            h = self.h
        return self._resize_func(img, (w, h))
class CropImage(object):
    """Center-crop an image array.

    Args:
        size: Crop size. An ``int`` gives a square crop; otherwise a
            two-item sequence that is unpacked as ``(w, h)``.
    """

    def __init__(self, size):
        if type(size) is int:
            self.size = (size, size)
        else:
            # Order is (w, h): that is how __call__ unpacks it.
            self.size = size

    def __call__(self, img):
        """Return the centred ``h`` x ``w`` window of ``img``.

        Works on 2-D (H, W) and 3-D (H, W, C) arrays. If the requested crop
        exceeds the image along an axis, the whole axis is returned.
        """
        w, h = self.size
        img_h, img_w = img.shape[:2]
        # Clamp at zero: a negative start would be interpreted as an offset
        # from the end of the axis and silently yield an off-centre slice.
        w_start = max((img_w - w) // 2, 0)
        h_start = max((img_h - h) // 2, 0)

        w_end = w_start + w
        h_end = h_start + h
        # No trailing ":" so that arrays without a channel axis also work.
        return img[h_start:h_end, w_start:w_end]
class RandCropImage(object):
    """Cut a randomly placed, randomly shaped window and resize it.

    Args:
        size: Output size handed to the resize function. An ``int`` is
            expanded to a pair of equal values.
        scale: ``[min, max]`` fraction of the image area the window may
            cover. Defaults to ``[0.08, 1.0]``.
        ratio: ``[min, max]`` range from which the window's aspect ratio is
            drawn. Defaults to ``[3/4, 4/3]``.
        interpolation: Forwarded to ``UnifiedResize``.
        backend: Forwarded to ``UnifiedResize``.
    """

    def __init__(self,
                 size,
                 scale=None,
                 ratio=None,
                 interpolation=None,
                 backend="cv2"):
        self.size = (size, size) if type(size) is int else size
        self.scale = scale if scale is not None else [0.08, 1.0]
        self.ratio = ratio if ratio is not None else [3. / 4., 4. / 3.]
        self._resize_func = UnifiedResize(
            interpolation=interpolation, backend=backend)

    def __call__(self, img):
        """Return a random crop of ``img`` resized to ``self.size``."""
        # Width and height multipliers whose product is 1, so the window
        # area is set by the scale alone.
        aspect_ratio = math.sqrt(random.uniform(*self.ratio))
        w_factor = 1. * aspect_ratio
        h_factor = 1. / aspect_ratio

        img_h, img_w = img.shape[:2]

        # Largest area fraction for which the window still fits the image.
        bound = min((float(img_w) / img_h) / (w_factor**2),
                    (float(img_h) / img_w) / (h_factor**2))
        scale_max = min(self.scale[1], bound)
        scale_min = min(self.scale[0], bound)

        target_area = img_w * img_h * random.uniform(scale_min, scale_max)
        side = math.sqrt(target_area)
        crop_w = int(side * w_factor)
        crop_h = int(side * h_factor)

        # Top-left corner, chosen so the window stays inside the image.
        left = random.randint(0, img_w - crop_w)
        top = random.randint(0, img_h - crop_h)

        window = img[top:top + crop_h, left:left + crop_w, :]
        return self._resize_func(window, self.size)
class RandFlipImage(object):
    """Flip the image with probability one half.

    flip_code:
        1: Flipped Horizontally
        0: Flipped Vertically
        -1: Flipped Horizontally & Vertically
    """

    def __init__(self, flip_code=1):
        assert flip_code in (-1, 0, 1), \
            "flip_code should be a value in [-1, 0, 1]"
        self.flip_code = flip_code

    def __call__(self, img):
        """Return ``img`` flipped via ``cv2.flip``, or untouched."""
        # One coin toss per call decides whether the flip happens.
        if random.randint(0, 1) != 1:
            return img
        return cv2.flip(img, self.flip_code)
class AutoAugment(object):
    """Apply an ``ImageNetPolicy`` to an image array."""

    def __init__(self):
        self.policy = ImageNetPolicy()

    def __call__(self, img):
        """Run the policy on ``img`` and return the result as an array.

        The input is converted to a PIL image for the policy and converted
        back to an array afterwards.
        """
        # Image.fromarray needs a contiguous buffer.
        img = np.ascontiguousarray(img)
        img = Image.fromarray(img)
        img = self.policy(img)
        img = np.asarray(img)
        # The result used to be dropped here, so every call returned None.
        return img
class NormalizeImage(object):
    """Normalize an image: multiply by a scale, subtract mean, divide by std.

    Args:
        scale: Multiplier applied to the pixel values first. A ``str`` is
            evaluated as a Python expression. Defaults to ``1.0 / 255.0``.
        mean: Three per-channel means. Defaults to
            ``[0.485, 0.456, 0.406]``.
        std: Three per-channel standard deviations. Defaults to
            ``[0.229, 0.224, 0.225]``.
        order: ``'chw'`` when the channel axis comes first; any other value
            is treated as channel-last.
        output_fp16: Return ``float16`` instead of ``float32``.
        channel_num: 3 or 4. With 4, an all-zero channel is appended after
            normalization.

    Raises:
        AssertionError: If ``channel_num`` is not 3 or 4.
    """

    def __init__(self,
                 scale=None,
                 mean=None,
                 std=None,
                 order='chw',
                 output_fp16=False,
                 channel_num=3):
        if isinstance(scale, str):
            # SECURITY NOTE(review): eval() executes arbitrary code. It is
            # kept for compatibility (presumably so expressions such as
            # "1.0/255.0" can be given as text), so `scale` strings must
            # only ever come from trusted configuration.
            scale = eval(scale)
        assert channel_num in [
            3, 4
        ], "channel number of input image should be set to 3 or 4."
        self.channel_num = channel_num
        self.output_dtype = 'float16' if output_fp16 else 'float32'
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        self.order = order
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        # Shape the statistics so they broadcast along the channel axis.
        shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, img):
        """Return the normalized image as an array of the output dtype.

        Raises:
            AssertionError: If ``img`` is neither an ndarray nor a PIL image.
        """
        # Arrays are the common case; only look at PIL for everything else.
        if not isinstance(img, np.ndarray) and isinstance(img, Image.Image):
            img = np.array(img)

        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"

        img = (img.astype('float32') * self.scale - self.mean) / self.std

        if self.channel_num == 4:
            channel_axis = 0 if self.order == 'chw' else 2
            pad_shape = list(img.shape)
            pad_shape[channel_axis] = 1
            # Same dtype as the image, so concatenate does not promote the
            # whole array to float64 just to be cast back down below.
            pad_zeros = np.zeros(pad_shape, dtype=img.dtype)
            img = np.concatenate((img, pad_zeros), axis=channel_axis)
        return img.astype(self.output_dtype)
class ToCHWImage(object):
    """Convert an HWC image to CHW by moving the last axis to the front."""

    def __init__(self):
        pass

    def __call__(self, img):
        """Return ``img`` transposed with axes ``(2, 0, 1)``.

        PIL images are converted to arrays first; any other input is
        transposed directly.
        """
        # Arrays are the common case; only look at PIL for everything else.
        # This relies on the module-level ``Image`` import, as the other
        # operators in this file do, instead of re-importing on every call.
        if not isinstance(img, np.ndarray) and isinstance(img, Image.Image):
            img = np.array(img)

        return img.transpose((2, 0, 1))
class AugMix(object):
    """ Perform AugMix augmentation and compute mixture.
    """

    def __init__(self,
                 prob=0.5,
                 aug_prob_coeff=0.1,
                 mixture_width=3,
                 mixture_depth=1,
                 aug_severity=1):
        """
        Args:
            prob: Probability of taking augmix
            aug_prob_coeff: Probability distribution coefficients.
            mixture_width: Number of augmentation chains to mix per augmented example.
            mixture_depth: Depth of augmentation chains. Any value <= 0
                (e.g. -1) selects a random depth in [1, 3] per chain.
            aug_severity: Severity of underlying augmentation operators (between 1 to 10).
        """
        # fmt: off
        self.prob = prob
        self.aug_prob_coeff = aug_prob_coeff
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth
        self.aug_severity = aug_severity
        self.augmentations = augmentations
        # fmt: on

    def __call__(self, image):
        """Perform AugMix augmentations and compute mixture.

        Args:
            image: Image array indexed as (H, W, 3) on the augmentation path.

        Returns:
            mixed: Augmented and mixed image as ``uint8``, or an unmodified
                copy of the input when augmentation is skipped.
        """
        if random.random() > self.prob:
            # Avoid the warning: the given NumPy array is not writeable
            return np.asarray(image).copy()

        # Convex weights for the chains, and the blend factor between the
        # original image and the mixture.
        ws = np.float32(
            np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
        m = np.float32(
            np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))

        # Arrays are indexed (H, W, C). The accumulator used to be allocated
        # as (W, H, 3), which only broadcasts against the augmented images
        # when the input is square.
        mix = np.zeros([image.shape[0], image.shape[1], 3])
        for i in range(self.mixture_width):
            # Each chain starts from a fresh PIL copy of the input.
            image_aug = Image.fromarray(image.copy())
            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(
                1, 4)
            for _ in range(depth):
                op = np.random.choice(self.augmentations)
                image_aug = op(image_aug, self.aug_severity)
            mix += ws[i] * np.asarray(image_aug)

        mixed = (1 - m) * image + m * mix
        return mixed.astype(np.uint8)
class ColorJitter(RawColorJitter):
    """Wrapper around the base ColorJitter that takes and returns arrays."""

    def __init__(self, *args, **kwargs):
        # All arguments are handed to the base class unchanged.
        super(ColorJitter, self).__init__(*args, **kwargs)

    def __call__(self, img):
        """Jitter ``img`` and return the result, as an array if it is a PIL image."""
        if isinstance(img, Image.Image):
            pil_img = img
        else:
            # The base transform is fed a PIL image built from a contiguous
            # copy of the array.
            pil_img = Image.fromarray(np.ascontiguousarray(img))

        jittered = super(ColorJitter, self)._apply_image(pil_img)

        if not isinstance(jittered, Image.Image):
            return jittered
        return np.asarray(jittered)
|