| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Code was heavily based on https://github.com/rwightman/pytorch-image-models
- import random
- import math
- import re
- from PIL import Image, ImageOps, ImageEnhance, ImageChops
- import PIL
- import numpy as np
# Per-channel mean expressed in the [0, 1] range; used as the default `mean`
# of RawTimmAutoAugment, which rescales it to 0-255.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)

# (major, minor) of the installed Pillow, used to gate version-specific paths.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Fill colour used when hparams carries no 'img_mean'.
_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.

# Hyper-parameters used whenever a caller passes none.
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Pool of resampling filters; one is drawn at random per op call when no fixed
# interpolation is configured.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
def _pil_interp(method):
    """Translate an interpolation name into a PIL resampling constant.

    'bicubic', 'lanczos' and 'hamming' map to their PIL filters; any other
    value falls back to bilinear.
    """
    if method == 'bicubic':
        return Image.BICUBIC
    if method == 'lanczos':
        return Image.LANCZOS
    if method == 'hamming':
        return Image.HAMMING
    # default bilinear, do we want to allow nearest?
    return Image.BILINEAR
def _interpolation(kwargs):
    """Pop 'resample' from `kwargs` and resolve it to a single filter.

    A list or tuple is treated as a pool and one entry is drawn at random;
    any other value is returned unchanged. A missing key means bilinear.
    """
    requested = kwargs.pop('resample', Image.BILINEAR)
    if not isinstance(requested, (list, tuple)):
        return requested
    return random.choice(requested)
def _check_args_tf(kwargs):
    """Normalise transform keyword arguments in place.

    Drops 'fillcolor' when running on Pillow older than 5.0 and replaces
    'resample' with a single concrete filter.
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
    """Apply an affine transform with `factor` as the x-shear coefficient; output size is unchanged."""
    _check_args_tf(kwargs)
    affine = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
def shear_y(img, factor, **kwargs):
    """Apply an affine transform with `factor` as the y-shear coefficient; output size is unchanged."""
    _check_args_tf(kwargs)
    affine = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
def translate_x_rel(img, pct, **kwargs):
    """Translate horizontally by `pct` times `img.size[0]` pixels."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    affine = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
def translate_y_rel(img, pct, **kwargs):
    """Translate vertically by `pct` times `img.size[1]` pixels."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    affine = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
def translate_x_abs(img, pixels, **kwargs):
    """Translate horizontally by an absolute number of `pixels`."""
    _check_args_tf(kwargs)
    affine = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
def translate_y_abs(img, pixels, **kwargs):
    """Translate vertically by an absolute number of `pixels`."""
    _check_args_tf(kwargs)
    affine = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
def rotate(img, degrees, **kwargs):
    """Rotate `img` by `degrees`, picking the code path by Pillow version.

    * Pillow >= 5.2: all keyword arguments are forwarded to `Image.rotate`.
    * Pillow 5.0 / 5.1: the rotation-about-centre affine matrix is built by
      hand and applied through `Image.transform`.
    * Older Pillow: `Image.rotate` is called with only the resample filter.
    """
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if not _PIL_VER >= (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    width, height = img.size
    shift_x, shift_y = 0, 0
    centre_x, centre_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    matrix = [
        round(math.cos(theta), 15),
        round(math.sin(theta), 15),
        0.0,
        round(-math.sin(theta), 15),
        round(math.cos(theta), 15),
        0.0,
    ]

    # Map the negated centre through the linear part to get the offsets that
    # keep the rotation anchored on the image centre.
    a, b, c, d, e, f = matrix
    x = -centre_x - shift_x
    y = -centre_y - shift_y
    matrix[2] = a * x + b * y + c
    matrix[5] = d * x + e * y + f
    matrix[2] += centre_x
    matrix[5] += centre_y
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def auto_contrast(img, **_unused):
    """Return `ImageOps.autocontrast(img)`; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
def invert(img, **_unused):
    """Return `ImageOps.invert(img)`; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
def equalize(img, **_unused):
    """Return `ImageOps.equalize(img)`; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
def solarize(img, thresh, **_unused):
    """Return `ImageOps.solarize(img, thresh)`; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
def solarize_add(img, add, thresh=128, **_unused):
    """Add `add` (capped at 255) to every pixel value below `thresh`.

    Only images whose mode is "L" or "RGB" are processed, through a
    256-entry lookup table passed to `img.point`; the table is repeated
    three times for "RGB". Images in any other mode are returned untouched.
    """
    table = [
        min(255, value + add) if value < thresh else value
        for value in range(256)
    ]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB" and len(table) == 256:
        table = table * 3
    return img.point(table)
def posterize(img, bits_to_keep, **_unused):
    """Return `ImageOps.posterize(img, bits_to_keep)`.

    When `bits_to_keep` is 8 or more the image is returned unchanged without
    calling into PIL.
    """
    keeps_everything = bits_to_keep >= 8
    if keeps_everything:
        return img
    return ImageOps.posterize(img, bits_to_keep)
def contrast(img, factor, **_unused):
    """Return `ImageEnhance.Contrast(img).enhance(factor)`; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
def color(img, factor, **_unused):
    """Return `ImageEnhance.Color(img).enhance(factor)`; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
def brightness(img, factor, **_unused):
    """Return `ImageEnhance.Brightness(img).enhance(factor)`; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
def sharpness(img, factor, **_unused):
    """Return `ImageEnhance.Sharpness(img).enhance(factor)`; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
def _randomly_negate(v):
    """Return `-v` when a fresh `random.random()` draw exceeds 0.5, otherwise `v`."""
    if random.random() > 0.5:
        return -v
    return v
def _rotate_level_to_arg(level, _hparams):
    """Scale `level` to a randomly signed angle and return it as a 1-tuple."""
    # range [-30, 30]
    fraction = level / _MAX_LEVEL
    degrees = _randomly_negate(fraction * 30.)
    return (degrees, )
def _enhance_level_to_arg(level, _hparams):
    """Scale `level` linearly to an enhancement factor and return it as a 1-tuple."""
    # range [0.1, 1.9]
    fraction = level / _MAX_LEVEL
    factor = fraction * 1.8 + 0.1
    return (factor, )
def _enhance_increasing_level_to_arg(level, _hparams):
    """Return a 1-tuple factor whose distance from 1.0 grows with `level`.

    The offset direction (above or below 1.0) is chosen at random.
    """
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9]
    offset = (level / _MAX_LEVEL) * .9
    factor = 1.0 + _randomly_negate(offset)
    return (factor, )
def _shear_level_to_arg(level, _hparams):
    """Scale `level` to a randomly signed shear factor and return it as a 1-tuple."""
    # range [-0.3, 0.3]
    fraction = level / _MAX_LEVEL
    shear = _randomly_negate(fraction * 0.3)
    return (shear, )
def _translate_abs_level_to_arg(level, hparams):
    """Scale `level` to a randomly signed pixel offset bounded by hparams['translate_const']."""
    max_pixels = float(hparams['translate_const'])
    fraction = level / _MAX_LEVEL
    pixels = _randomly_negate(fraction * max_pixels)
    return (pixels, )
def _translate_rel_level_to_arg(level, hparams):
    """Scale `level` to a randomly signed relative offset.

    The bound comes from hparams['translate_pct'] and defaults to 0.45.
    """
    # default range [-0.45, 0.45]
    max_pct = hparams.get('translate_pct', 0.45)
    fraction = level / _MAX_LEVEL
    pct = _randomly_negate(fraction * max_pct)
    return (pct, )
def _posterize_level_to_arg(level, _hparams):
    """Map `level` to an integer bit count and return it as a 1-tuple."""
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    bits = int((level / _MAX_LEVEL) * 4)
    return (bits, )
def _posterize_increasing_level_to_arg(level, hparams):
    """Return 4 minus the `_posterize_level_to_arg` bit count, as a 1-tuple."""
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    (base_bits, ) = _posterize_level_to_arg(level, hparams)[:1]
    return (4 - base_bits, )
def _posterize_original_level_to_arg(level, _hparams):
    """Map `level` to an integer bit count offset by 4 and return it as a 1-tuple."""
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    bits = int((level / _MAX_LEVEL) * 4) + 4
    return (bits, )
def _solarize_level_to_arg(level, _hparams):
    """Map `level` to an integer solarize threshold and return it as a 1-tuple."""
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    threshold = int((level / _MAX_LEVEL) * 256)
    return (threshold, )
def _solarize_increasing_level_to_arg(level, _hparams):
    """Return 256 minus the `_solarize_level_to_arg` threshold, as a 1-tuple."""
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    (base_threshold, ) = _solarize_level_to_arg(level, _hparams)[:1]
    return (256 - base_threshold, )
def _solarize_add_level_to_arg(level, _hparams):
    """Map `level` to the integer amount added by `solarize_add`, as a 1-tuple."""
    # range [0, 110]
    amount = int((level / _MAX_LEVEL) * 110)
    return (amount, )
# Op name -> function converting a magnitude level into the op's positional
# arguments. None marks ops that take no level-derived argument.
LEVEL_TO_ARG = dict(
    AutoContrast=None,
    Equalize=None,
    Invert=None,
    Rotate=_rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    Posterize=_posterize_level_to_arg,
    PosterizeIncreasing=_posterize_increasing_level_to_arg,
    PosterizeOriginal=_posterize_original_level_to_arg,
    Solarize=_solarize_level_to_arg,
    SolarizeIncreasing=_solarize_increasing_level_to_arg,
    SolarizeAdd=_solarize_add_level_to_arg,
    Color=_enhance_level_to_arg,
    ColorIncreasing=_enhance_increasing_level_to_arg,
    Contrast=_enhance_level_to_arg,
    ContrastIncreasing=_enhance_increasing_level_to_arg,
    Brightness=_enhance_level_to_arg,
    BrightnessIncreasing=_enhance_increasing_level_to_arg,
    Sharpness=_enhance_level_to_arg,
    SharpnessIncreasing=_enhance_increasing_level_to_arg,
    ShearX=_shear_level_to_arg,
    ShearY=_shear_level_to_arg,
    TranslateX=_translate_abs_level_to_arg,
    TranslateY=_translate_abs_level_to_arg,
    TranslateXRel=_translate_rel_level_to_arg,
    TranslateYRel=_translate_rel_level_to_arg,
)
# Op name -> image function. The '*Increasing' / '*Original' variants share
# the image function of their base op and differ only in LEVEL_TO_ARG.
NAME_TO_OP = dict(
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    PosterizeIncreasing=posterize,
    PosterizeOriginal=posterize,
    Solarize=solarize,
    SolarizeIncreasing=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    ColorIncreasing=color,
    Contrast=contrast,
    ContrastIncreasing=contrast,
    Brightness=brightness,
    BrightnessIncreasing=brightness,
    Sharpness=sharpness,
    SharpnessIncreasing=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
class AugmentOp(object):
    """One named augmentation, applied with a probability at a given magnitude."""

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        """Look up the op and its level mapper by `name` and store settings.

        `hparams` falls back to `_HPARAMS_DEFAULT` when empty or None; a copy
        is kept on the instance.
        """
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        # Keyword arguments handed to every op call.
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        self.magnitude_std = self.hparams.get('magnitude_std', 0)

    def __call__(self, img):
        """Apply the op to `img`, or return `img` as-is when the draw says skip."""
        if self.prob < 1.0 and random.random() > self.prob:
            return img

        level = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            level = random.gauss(level, self.magnitude_std)
        level = min(_MAX_LEVEL, max(0, level))  # clip to valid range

        if self.level_fn is not None:
            positional = self.level_fn(level, self.hparams)
        else:
            positional = tuple()
        return self.aug_fn(img, *positional, **self.kwargs)
def auto_augment_policy_v0(hparams):
    """Build the 'v0' policy as a list of sub-policies, each a list of AugmentOp.

    Every entry is (op name, probability, magnitude).
    """
    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
    spec = (
        (('Equalize', 0.8, 1), ('ShearY', 0.8, 4)),
        (('Color', 0.4, 9), ('Equalize', 0.6, 3)),
        (('Color', 0.4, 1), ('Rotate', 0.6, 8)),
        (('Solarize', 0.8, 3), ('Equalize', 0.4, 7)),
        (('Solarize', 0.4, 2), ('Solarize', 0.6, 2)),
        (('Color', 0.2, 0), ('Equalize', 0.8, 8)),
        (('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)),
        (('ShearX', 0.2, 9), ('Rotate', 0.6, 8)),
        (('Color', 0.6, 1), ('Equalize', 1.0, 2)),
        (('Invert', 0.4, 9), ('Rotate', 0.6, 0)),
        (('Equalize', 1.0, 9), ('ShearY', 0.6, 3)),
        (('Color', 0.4, 7), ('Equalize', 0.6, 0)),
        (('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)),
        (('Solarize', 0.6, 8), ('Color', 0.6, 9)),
        (('Solarize', 0.2, 4), ('Rotate', 0.8, 9)),
        (('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)),
        (('ShearX', 0.0, 0), ('Solarize', 0.8, 4)),
        (('ShearY', 0.8, 0), ('Color', 0.6, 4)),
        (('Color', 1.0, 0), ('Rotate', 0.6, 2)),
        (('Equalize', 0.8, 4), ('Equalize', 0.0, 8)),
        (('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)),
        (('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)),
        # This results in black image with Tpu posterize
        (('Posterize', 0.8, 2), ('Solarize', 0.6, 10)),
        (('Solarize', 0.6, 8), ('Equalize', 0.6, 1)),
        (('Color', 0.8, 6), ('Rotate', 0.4, 5)),
    )
    built = []
    for sub_policy in spec:
        built.append([
            AugmentOp(
                name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy_v0r(hparams):
    """Build the 'v0r' policy as a list of sub-policies, each a list of AugmentOp.

    Every entry is (op name, probability, magnitude).
    """
    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
    # in Google research implementation (number of bits discarded increases with magnitude)
    spec = (
        (('Equalize', 0.8, 1), ('ShearY', 0.8, 4)),
        (('Color', 0.4, 9), ('Equalize', 0.6, 3)),
        (('Color', 0.4, 1), ('Rotate', 0.6, 8)),
        (('Solarize', 0.8, 3), ('Equalize', 0.4, 7)),
        (('Solarize', 0.4, 2), ('Solarize', 0.6, 2)),
        (('Color', 0.2, 0), ('Equalize', 0.8, 8)),
        (('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)),
        (('ShearX', 0.2, 9), ('Rotate', 0.6, 8)),
        (('Color', 0.6, 1), ('Equalize', 1.0, 2)),
        (('Invert', 0.4, 9), ('Rotate', 0.6, 0)),
        (('Equalize', 1.0, 9), ('ShearY', 0.6, 3)),
        (('Color', 0.4, 7), ('Equalize', 0.6, 0)),
        (('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)),
        (('Solarize', 0.6, 8), ('Color', 0.6, 9)),
        (('Solarize', 0.2, 4), ('Rotate', 0.8, 9)),
        (('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)),
        (('ShearX', 0.0, 0), ('Solarize', 0.8, 4)),
        (('ShearY', 0.8, 0), ('Color', 0.6, 4)),
        (('Color', 1.0, 0), ('Rotate', 0.6, 2)),
        (('Equalize', 0.8, 4), ('Equalize', 0.0, 8)),
        (('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)),
        (('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)),
        (('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)),
        (('Solarize', 0.6, 8), ('Equalize', 0.6, 1)),
        (('Color', 0.8, 6), ('Rotate', 0.4, 5)),
    )
    built = []
    for sub_policy in spec:
        built.append([
            AugmentOp(
                name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy_original(hparams):
    """Build the 'original' policy as a list of sub-policies, each a list of AugmentOp.

    Every entry is (op name, probability, magnitude).
    """
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    spec = (
        (('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)),
        (('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)),
        (('Equalize', 0.8, 8), ('Equalize', 0.6, 3)),
        (('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)),
        (('Equalize', 0.4, 7), ('Solarize', 0.2, 4)),
        (('Equalize', 0.4, 4), ('Rotate', 0.8, 8)),
        (('Solarize', 0.6, 3), ('Equalize', 0.6, 7)),
        (('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)),
        (('Rotate', 0.2, 3), ('Solarize', 0.6, 8)),
        (('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)),
        (('Rotate', 0.8, 8), ('Color', 0.4, 0)),
        (('Rotate', 0.4, 9), ('Equalize', 0.6, 2)),
        (('Equalize', 0.0, 7), ('Equalize', 0.8, 8)),
        (('Invert', 0.6, 4), ('Equalize', 1.0, 8)),
        (('Color', 0.6, 4), ('Contrast', 1.0, 8)),
        (('Rotate', 0.8, 8), ('Color', 1.0, 2)),
        (('Color', 0.8, 8), ('Solarize', 0.8, 7)),
        (('Sharpness', 0.4, 7), ('Invert', 0.6, 8)),
        (('ShearX', 0.6, 5), ('Equalize', 1.0, 9)),
        (('Color', 0.4, 0), ('Equalize', 0.6, 3)),
        (('Equalize', 0.4, 7), ('Solarize', 0.2, 4)),
        (('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)),
        (('Invert', 0.6, 4), ('Equalize', 1.0, 8)),
        (('Color', 0.6, 4), ('Contrast', 1.0, 8)),
        (('Equalize', 0.8, 8), ('Equalize', 0.6, 3)),
    )
    built = []
    for sub_policy in spec:
        built.append([
            AugmentOp(
                name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy_originalr(hparams):
    """Build the 'originalr' policy as a list of sub-policies, each a list of AugmentOp.

    Every entry is (op name, probability, magnitude).
    """
    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
    spec = (
        (('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)),
        (('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)),
        (('Equalize', 0.8, 8), ('Equalize', 0.6, 3)),
        (('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)),
        (('Equalize', 0.4, 7), ('Solarize', 0.2, 4)),
        (('Equalize', 0.4, 4), ('Rotate', 0.8, 8)),
        (('Solarize', 0.6, 3), ('Equalize', 0.6, 7)),
        (('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)),
        (('Rotate', 0.2, 3), ('Solarize', 0.6, 8)),
        (('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)),
        (('Rotate', 0.8, 8), ('Color', 0.4, 0)),
        (('Rotate', 0.4, 9), ('Equalize', 0.6, 2)),
        (('Equalize', 0.0, 7), ('Equalize', 0.8, 8)),
        (('Invert', 0.6, 4), ('Equalize', 1.0, 8)),
        (('Color', 0.6, 4), ('Contrast', 1.0, 8)),
        (('Rotate', 0.8, 8), ('Color', 1.0, 2)),
        (('Color', 0.8, 8), ('Solarize', 0.8, 7)),
        (('Sharpness', 0.4, 7), ('Invert', 0.6, 8)),
        (('ShearX', 0.6, 5), ('Equalize', 1.0, 9)),
        (('Color', 0.4, 0), ('Equalize', 0.6, 3)),
        (('Equalize', 0.4, 7), ('Solarize', 0.2, 4)),
        (('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)),
        (('Invert', 0.6, 4), ('Equalize', 1.0, 8)),
        (('Color', 0.6, 4), ('Contrast', 1.0, 8)),
        (('Equalize', 0.8, 8), ('Equalize', 0.6, 3)),
    )
    built = []
    for sub_policy in spec:
        built.append([
            AugmentOp(
                name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy(name='v0', hparams=None):
    """Return the AutoAugment policy registered under `name`.

    :param name: One of 'original', 'originalr', 'v0', 'v0r'.
    :param hparams: Hyper-parameter dict; `_HPARAMS_DEFAULT` is used when
        empty or None.
    :return: A list of sub-policies, each a list of AugmentOp.
    :raises AssertionError: if `name` is not a known policy.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'originalr':
        return auto_augment_policy_originalr(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    if name == 'v0r':
        return auto_augment_policy_v0r(hparams)
    # Raised explicitly instead of `assert False`, so the check survives
    # `python -O` (which would otherwise let this function return None).
    # The exception type is unchanged for existing callers.
    raise AssertionError('Unknown AA policy (%s)' % name)
class AutoAugment(object):
    """Callable that applies one randomly chosen sub-policy to an image."""

    def __init__(self, policy):
        # Sequence of sub-policies; each sub-policy is an iterable of callables.
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy with `random.choice` and run its ops in order."""
        chosen = random.choice(self.policy)
        result = img
        for step in chosen:
            result = step(result)
        return result
def auto_augment_transform(config_str, hparams):
    """
    Create an AutoAugment transform.

    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
    The remaining sections, not order specific, determine
        'mstd' -  float std deviation of magnitude noise applied
    Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
    Sections that contain no digit are skipped.
    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme. The dict is updated in place:
    'magnitude_std' is set from 'mstd' unless the key is already present.
    :return: A callable Transform Op
    :raises AssertionError: if a section carries an unknown key.
    """
    sections = config_str.split('-')
    policy_name = sections[0]
    for section in sections[1:]:
        # Split "<key><value>" at the first digit, e.g. 'mstd0.5' -> 'mstd', '0.5'.
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            # Raised explicitly rather than via `assert False`, so an unknown
            # section is still rejected under `python -O`.
            raise AssertionError('Unknown AutoAugment config section')
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)
# Default op pool for RandAugment.
_RAND_TRANSFORMS = [
    'AutoContrast', 'Equalize', 'Invert', 'Rotate',
    'Posterize', 'Solarize', 'SolarizeAdd',
    'Color', 'Contrast', 'Brightness', 'Sharpness',
    'ShearX', 'ShearY',
    'TranslateXRel', 'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]

# Op pool selected by the 'inc' config section of rand_augment_transform.
# Entries line up position-by-position with _RAND_TRANSFORMS.
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast', 'Equalize', 'Invert', 'Rotate',
    'PosterizeIncreasing', 'SolarizeIncreasing', 'SolarizeAdd',
    'ColorIncreasing', 'ContrastIncreasing', 'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX', 'ShearY',
    'TranslateXRel', 'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]
# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = dict(
    Rotate=0.3,
    ShearX=0.2,
    ShearY=0.2,
    TranslateXRel=0.1,
    TranslateYRel=0.1,
    Color=.025,
    Sharpness=0.025,
    AutoContrast=0.025,
    Solarize=.005,
    SolarizeAdd=.005,
    Contrast=.005,
    Brightness=.005,
    Equalize=.005,
    Posterize=0,
    Invert=0,
)


def _select_rand_weights(weight_idx=0, transforms=None):
    """Return the weights of `transforms`, normalised to sum to 1.

    :param weight_idx: Index of the weight table; only 0 exists.
    :param transforms: Op names to look up; `_RAND_TRANSFORMS` when empty or None.
    :return: numpy array of probabilities in the order of `transforms`.
    """
    if not transforms:
        transforms = _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    table = _RAND_CHOICE_WEIGHTS_0
    raw = [table[name] for name in transforms]
    total = np.sum(raw)
    return np.asarray(raw) / total
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    """Build one AugmentOp (prob 0.5) per name in `transforms`.

    `hparams` defaults to `_HPARAMS_DEFAULT` and `transforms` to
    `_RAND_TRANSFORMS` when empty or None.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _RAND_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(
            AugmentOp(
                op_name, prob=0.5, magnitude=magnitude, hparams=hparams))
    return ops
class RandAugment(object):
    """Callable that applies `num_layers` randomly drawn ops to an image."""

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        # Optional per-op probabilities passed to np.random.choice.
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Draw ops with `np.random.choice` and apply them in order."""
        # no replacement when using weighted choice
        with_replacement = self.choice_weights is None
        selected = np.random.choice(
            self.ops,
            self.num_layers,
            replace=with_replacement,
            p=self.choice_weights)
        result = img
        for step in selected:
            result = step(result)
        return result
def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform.

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order specific, determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' -  float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
    Sections that contain no digit are skipped.
    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme. The dict is updated in place:
    'magnitude_std' is set from 'mstd' unless the key is already present.
    :return: A callable Transform Op
    :raises AssertionError: if the first section is not 'rand' or a section carries an unknown key.
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split('-')
    if config[0] != 'rand':
        # Explicit raise so the check is not stripped by `python -O`.
        raise AssertionError(
            "RandAugment config must start with 'rand' (got %r)" % config[0])
    for section in config[1:]:
        # Split "<key><value>" at the first digit, e.g. 'mstd0.5' -> 'mstd', '0.5'.
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'inc':
            # `val` is a string; bool('0') is True, so parse it as an integer
            # first. Otherwise 'inc0' would enable the increasing transforms.
            if bool(int(val)):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            raise AssertionError('Unknown RandAugment config section')
    ra_ops = rand_augment_ops(
        magnitude=magnitude, hparams=hparams, transforms=transforms)
    # Weights are looked up for the default transform names; the two transform
    # lists are position-aligned, so the same weights apply to either.
    choice_weights = None if weight_idx is None else _select_rand_weights(
        weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
# Default op pool for AugMix.
_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    # The next four are not in paper.
    'ColorIncreasing', 'ContrastIncreasing', 'BrightnessIncreasing',
    'SharpnessIncreasing',
    'Equalize', 'Rotate',
    'PosterizeIncreasing', 'SolarizeIncreasing',
    'ShearX', 'ShearY',
    'TranslateXRel', 'TranslateYRel',
]
def augmix_ops(magnitude=10, hparams=None, transforms=None):
    """Build one always-applied AugmentOp (prob 1.0) per name in `transforms`.

    `hparams` defaults to `_HPARAMS_DEFAULT` and `transforms` to
    `_AUGMIX_TRANSFORMS` when empty or None.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _AUGMIX_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(
            AugmentOp(
                op_name, prob=1.0, magnitude=magnitude, hparams=hparams))
    return ops
class AugMixAugment(object):
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781
    """

    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        """
        :param ops: Sequence of callables (image -> image) to sample chains from.
        :param alpha: Parameter of the Dirichlet and Beta draws in __call__.
        :param width: Number of augmentation chains mixed per image.
        :param depth: Ops per chain; values <= 0 mean a random depth in [1, 3].
        :param blended: Use the per-chain PIL blend path instead of the NumPy accumulator.
        """
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _calc_blended_weights(self, ws, m):
        """Convert mixing weights into per-chain blend factors (float32 array).

        Walking the scaled weights back to front, each factor is the weight
        divided by the share not yet claimed by later chains.
        """
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        """Blend each augmented chain into the running image with `Image.blend`."""
        # This is my first crack and implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        """Accumulate the weighted chains in a float array, then blend with `img`."""
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        # PIL reports size as (width, height) while np.asarray(img) is laid out
        # as (height, width, bands). The accumulator must use the array layout,
        # otherwise the `+=` below fails for non-square images.
        width, height = img.size
        img_shape = height, width, len(img.getbands())
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        """Draw chain weights (Dirichlet) and the final blend factor (Beta), then mix."""
        mixing_weights = np.float32(
            np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed
def augment_and_mix_transform(config_str, hparams):
    """ Create AugMix transform

    :param config_str: String defining configuration of the augmentation mix. Consists of multiple sections separated by
    dashes ('-'). The first section must be 'augmix'. The remaining sections, not order specific, determine
        'm' - integer magnitude (severity) of augmentation mix (default: 3)
        'w' - integer width of augmentation chain (default: 3)
        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
        'a' - float alpha for the Dirichlet / Beta draws (default: 1.0)
        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
        'mstd' -  float std deviation of magnitude noise applied (default: 0)
    Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
    Sections that contain no digit are skipped.
    :param hparams: Other hparams (kwargs) for the Augmentation transforms. The dict is updated in place:
    'magnitude_std' is set from 'mstd' unless the key is already present.
    :return: A callable Transform Op
    :raises AssertionError: if the first section is not 'augmix' or a section carries an unknown key.
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    config = config_str.split('-')
    if config[0] != 'augmix':
        # Explicit raise so the check is not stripped by `python -O`.
        raise AssertionError(
            "AugMix config must start with 'augmix' (got %r)" % config[0])
    for section in config[1:]:
        # Split "<key><value>" at the first digit, e.g. 'mstd0.5' -> 'mstd', '0.5'.
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # `val` is a string; bool('0') is True, so parse it as an integer
            # first. Otherwise 'b0' would switch blended mode on.
            blended = bool(int(val))
        else:
            raise AssertionError('Unknown AugMix config section')
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(
        ops, alpha=alpha, width=width, depth=depth, blended=blended)
class RawTimmAutoAugment(object):
    """TimmAutoAugment API for PaddleClas."""

    def __init__(self,
                 config_str="rand-m9-mstd0.5-inc1",
                 interpolation="bicubic",
                 img_size=224,
                 mean=IMAGENET_DEFAULT_MEAN):
        """Build the augmentation selected by the prefix of `config_str`.

        :param config_str: Config starting with 'rand', 'augmix' or 'auto'.
        :param interpolation: Filter name; falsy or 'random' keeps random interpolation.
        :param img_size: Int, or tuple/list whose minimum sets the translate bound.
        :param mean: Per-channel mean in [0, 1], rescaled to 0-255 for the fill colour.
        :raises Exception: if `config_str` has none of the supported prefixes.
        """
        shortest_side = min(img_size) if isinstance(
            img_size, (tuple, list)) else img_size

        fill = tuple([min(255, round(255 * channel)) for channel in mean])
        aa_params = {
            'translate_const': int(shortest_side * 0.45),
            'img_mean': fill,
        }
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)

        if config_str.startswith('rand'):
            self.augment_func = rand_augment_transform(config_str, aa_params)
        elif config_str.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            self.augment_func = augment_and_mix_transform(config_str,
                                                          aa_params)
        elif config_str.startswith('auto'):
            # NOTE(review): config_str is forwarded unchanged, and
            # auto_augment_transform treats its first '-' section as the
            # policy name ('v0', 'v0r', 'original', 'originalr'). None of
            # those start with 'auto', so this branch looks unreachable
            # with a valid policy - confirm the intended config format.
            self.augment_func = auto_augment_transform(config_str, aa_params)
        else:
            raise Exception(
                "ConfigError: The TimmAutoAugment Op only support RandAugment, AutoAugment, AugMix, and the config_str only starts with \"rand\", \"augmix\", \"auto\"."
            )

    def __call__(self, img):
        """Apply the configured augmentation to `img`."""
        return self.augment_func(img)
|