| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Code was based on https://github.com/DingXiaoH/RepVGG
- import paddle.nn as nn
- import paddle
- import numpy as np
- from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
# Download locations of the published ImageNet weights, keyed by factory
# name. Insertion order matters: it defines the order of ``__all__``.
MODEL_URLS = {
    # RepVGG-A family
    "RepVGG_A0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_A0_pretrained.pdparams",
    "RepVGG_A1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_A1_pretrained.pdparams",
    "RepVGG_A2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_A2_pretrained.pdparams",
    # RepVGG-B family (dense convolutions)
    "RepVGG_B0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B0_pretrained.pdparams",
    "RepVGG_B1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B1_pretrained.pdparams",
    "RepVGG_B2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2_pretrained.pdparams",
    "RepVGG_B3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3_pretrained.pdparams",
    # RepVGG-B family with group-wise convolutions (g2 / g4)
    "RepVGG_B1g2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B1g2_pretrained.pdparams",
    "RepVGG_B1g4": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B1g4_pretrained.pdparams",
    "RepVGG_B2g2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2g2_pretrained.pdparams",
    "RepVGG_B2g4": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2g4_pretrained.pdparams",
    "RepVGG_B3g2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3g2_pretrained.pdparams",
    "RepVGG_B3g4": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3g4_pretrained.pdparams",
}
# Public API: one factory function per entry of MODEL_URLS, in the same order.
__all__ = list(MODEL_URLS)
# Block indices (RepVGG counts blocks across stages 1-4 starting at 1) that
# the "g2"/"g4" variants build with grouped convolutions: every even index
# from 2 through 26. Each map is {block_index: groups}.
optional_groupwise_layers = list(range(2, 27, 2))
g2_map = dict.fromkeys(optional_groupwise_layers, 2)
g4_map = dict.fromkeys(optional_groupwise_layers, 4)
class ConvBN(nn.Layer):
    """Bias-free ``Conv2D`` followed by ``BatchNorm2D``.

    The convolution carries no bias of its own (``bias_attr=False``); the
    batch norm's affine shift plays that role, which is what allows
    ``RepVGGBlock`` to fold conv+BN into a single biased convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1):
        super().__init__()
        self.conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        self.bn = nn.BatchNorm2D(num_features=out_channels)

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))
class RepVGGBlock(nn.Layer):
    """Structurally re-parameterizable 3x3 block used by RepVGG.

    In training mode the output is
    ``ReLU(rbr_dense(x) + rbr_1x1(x) + rbr_identity(x))``, where
    ``rbr_dense`` is a 3x3 conv+BN, ``rbr_1x1`` is a 1x1 conv+BN and
    ``rbr_identity`` is a bare BN that only exists when
    ``in_channels == out_channels`` and ``stride == 1``.

    ``eval()`` folds the three branches into one biased 3x3 convolution,
    ``rbr_reparam``, which is then used for inference.

    Only ``kernel_size == 3`` with ``padding == 1`` is supported (asserted).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros'):
        super(RepVGGBlock, self).__init__()
        # Geometry is stored so eval() can build rbr_reparam with the same shape.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # NOTE(review): `dilation` and `padding_mode` are only applied to
        # rbr_reparam (built in eval()), not to the training branches below.
        # Train/eval outputs only agree for dilation == 1 and
        # padding_mode == 'zeros'.
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        assert kernel_size == 3
        assert padding == 1
        # Padding for the 1x1 branch so that its output has the same spatial
        # size as the 3x3 branch (0 given the asserts above).
        padding_11 = padding - kernel_size // 2
        self.nonlinearity = nn.ReLU()
        # An identity shortcut is only shape-compatible when channels and
        # resolution are unchanged.
        self.rbr_identity = nn.BatchNorm2D(
            num_features=in_channels
        ) if out_channels == in_channels and stride == 1 else None
        self.rbr_dense = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups)
        self.rbr_1x1 = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding_11,
            groups=groups)

    def forward(self, inputs):
        """Apply the block to ``inputs``.

        Uses the fused ``rbr_reparam`` conv in inference mode once ``eval()``
        has built it. Otherwise it sums the three branches.
        """
        # `training` can be False without this block's eval() ever having
        # run (e.g. a parent layer only flipped the flag), in which case
        # rbr_reparam does not exist. Fall back to the multi-branch path,
        # which with BNs in inference mode computes the same function (up to
        # floating-point rounding) instead of raising AttributeError.
        if not self.training and hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))
        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)
        return self.nonlinearity(
            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def eval(self):
        """Switch to inference mode and refresh the fused conv ``rbr_reparam``.

        The fused weights are a snapshot of the current branch parameters and
        BN statistics, copied in with ``set_value``. If training continues
        afterwards, call ``eval()`` again to refresh them.
        """
        if not hasattr(self, 'rbr_reparam'):
            # Created lazily on the first eval() call.
            self.rbr_reparam = nn.Conv2D(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                padding_mode=self.padding_mode)
        self.training = False
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam.weight.set_value(kernel)
        self.rbr_reparam.bias.set_value(bias)
        for layer in self.sublayers():
            layer.eval()

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one 3x3 conv equal to the branch sum.

        The 1x1 kernel is zero-padded to 3x3. A missing identity branch
        contributes ``0`` to both kernel and bias.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel by one pixel on each spatial side (``None`` -> 0)."""
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's batch norm into a conv ``(kernel, bias)`` pair.

        For ``y = gamma * (conv(x) - mean) / std + beta`` with
        ``std = sqrt(var + eps)``, the equivalent conv has kernel
        ``W * gamma / std`` (per output channel) and bias
        ``beta - mean * gamma / std``.

        Args:
            branch: a ``ConvBN``, a bare ``BatchNorm2D`` (identity branch),
                or ``None``.

        Returns:
            ``(kernel, bias)``, or ``(0, 0)`` when ``branch`` is ``None``.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvBN):
            kernel = branch.conv.weight
            running_mean = branch.bn._mean
            running_var = branch.bn._variance
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn._epsilon
        else:
            assert isinstance(branch, nn.BatchNorm2D)
            if not hasattr(self, 'id_tensor'):
                # Build (and cache) a 3x3 kernel that implements the identity.
                # For output channel i, a 1 sits at the centre tap of input
                # channel i within its group (i % input_dim).
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = paddle.to_tensor(kernel_value)
            kernel = self.id_tensor
            running_mean = branch._mean
            running_var = branch._variance
            gamma = branch.weight
            beta = branch.bias
            eps = branch._epsilon
        std = (running_var + eps).sqrt()
        # Per-output-channel scale, broadcast over (in, kh, kw).
        t = (gamma / std).reshape((-1, 1, 1, 1))
        return kernel * t, beta - running_mean * gamma / std
class RepVGG(nn.Layer):
    """RepVGG image-classification network.

    Layout: a stride-2 stem block (``stage0``), four stride-2 stages of
    ``RepVGGBlock`` s, global average pooling and a linear classifier.

    Args:
        num_blocks: number of blocks in stages 1-4.
        width_multiplier: four multipliers applied to the base widths
            64/128/256/512 of stages 1-4.
        override_groups_map: optional ``{block_index: groups}``. Block indices
            count across stages 1-4 starting at 1; index 0 (the stem) may not
            be overridden.
        class_num: output size of the classifier.
    """

    def __init__(self,
                 num_blocks,
                 width_multiplier=None,
                 override_groups_map=None,
                 class_num=1000):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
        self.override_groups_map = override_groups_map or {}
        assert 0 not in self.override_groups_map
        stage_widths = [
            int(base * mult)
            for base, mult in zip((64, 128, 256, 512), width_multiplier)
        ]
        # The stem is capped at 64 channels regardless of the multiplier.
        self.in_planes = min(64, stage_widths[0])
        self.stage0 = RepVGGBlock(
            in_channels=3,
            out_channels=self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1)
        # Running block index used to look up override_groups_map.
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(stage_widths[0], num_blocks[0], stride=2)
        self.stage2 = self._make_stage(stage_widths[1], num_blocks[1], stride=2)
        self.stage3 = self._make_stage(stage_widths[2], num_blocks[2], stride=2)
        self.stage4 = self._make_stage(stage_widths[3], num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2D(output_size=1)
        self.linear = nn.Linear(stage_widths[3], class_num)

    def _make_stage(self, planes, num_blocks, stride):
        """Build one stage. Only its first block uses ``stride``; the rest use 1.

        Advances ``self.in_planes`` and ``self.cur_layer_idx`` as blocks are
        created.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        stage_blocks = []
        for block_stride in block_strides:
            n_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            stage_blocks.append(
                RepVGGBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=3,
                    stride=block_stride,
                    padding=1,
                    groups=n_groups))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*stage_blocks)

    def eval(self):
        """Put this model and every sublayer in inference mode.

        Each sublayer's own ``eval()`` is called, so every ``RepVGGBlock``
        builds or refreshes its fused conv.
        """
        self.training = False
        for sublayer in self.sublayers():
            sublayer.training = False
            sublayer.eval()

    def forward(self, x):
        """Return classifier logits for input ``x``."""
        out = x
        for stage in (self.stage0, self.stage1, self.stage2, self.stage3,
                      self.stage4):
            out = stage(out)
        pooled = paddle.flatten(self.gap(out), start_axis=1)
        return self.linear(pooled)
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Optionally load weights into ``model``.

    Args:
        pretrained: ``False`` to skip loading; ``True`` to fetch ``model_url``
            via ``load_dygraph_pretrain_from_url``; a ``str`` to pass to
            ``load_dygraph_pretrain``.
        model: the network to load weights into.
        model_url: URL used when ``pretrained`` is ``True``.
        use_ssld: forwarded to ``load_dygraph_pretrain_from_url``.

    Raises:
        RuntimeError: if ``pretrained`` is not ``True``, ``False`` or a string.
    """
    if isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif pretrained is not False:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
def RepVGG_A0(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-A0: blocks [2, 4, 14, 1], widths x[0.75, 0.75, 0.75, 2.5].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_A0"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[0.75, 0.75, 0.75, 2.5],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_A0"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_A1(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-A1: blocks [2, 4, 14, 1], widths x[1, 1, 1, 2.5].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_A1"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[1, 1, 1, 2.5],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_A1"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_A2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-A2: blocks [2, 4, 14, 1], widths x[1.5, 1.5, 1.5, 2.75].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_A2"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[1.5, 1.5, 1.5, 2.75],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_A2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B0(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B0: blocks [4, 6, 16, 1], widths x[1, 1, 1, 2.5].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_B0"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[1, 1, 1, 2.5],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B0"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B1(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B1: blocks [4, 6, 16, 1], widths x[2, 2, 2, 4].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_B1"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B1"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B1g2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B1g2: RepVGG-B1 with 2 groups at the ``g2_map`` blocks.

    ``pretrained`` may be ``False``, ``True`` (load
    ``MODEL_URLS["RepVGG_B1g2"]`` with ``use_ssld`` forwarded) or a string
    passed to ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g.
    ``class_num``) go to ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        override_groups_map=g2_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B1g2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B1g4(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B1g4: RepVGG-B1 with 4 groups at the ``g4_map`` blocks.

    ``pretrained`` may be ``False``, ``True`` (load
    ``MODEL_URLS["RepVGG_B1g4"]`` with ``use_ssld`` forwarded) or a string
    passed to ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g.
    ``class_num``) go to ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        override_groups_map=g4_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B1g4"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B2: blocks [4, 6, 16, 1], widths x[2.5, 2.5, 2.5, 5].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_B2"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B2g2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B2g2: RepVGG-B2 with 2 groups at the ``g2_map`` blocks.

    ``pretrained`` may be ``False``, ``True`` (load
    ``MODEL_URLS["RepVGG_B2g2"]`` with ``use_ssld`` forwarded) or a string
    passed to ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g.
    ``class_num``) go to ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=g2_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B2g2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B2g4(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B2g4: RepVGG-B2 with 4 groups at the ``g4_map`` blocks.

    ``pretrained`` may be ``False``, ``True`` (load
    ``MODEL_URLS["RepVGG_B2g4"]`` with ``use_ssld`` forwarded) or a string
    passed to ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g.
    ``class_num``) go to ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=g4_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B2g4"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B3(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B3: blocks [4, 6, 16, 1], widths x[3, 3, 3, 5].

    ``pretrained`` may be ``False``, ``True`` (load ``MODEL_URLS["RepVGG_B3"]``
    with ``use_ssld`` forwarded) or a string passed to
    ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g. ``class_num``) go to
    ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B3"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B3g2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B3g2: RepVGG-B3 with 2 groups at the ``g2_map`` blocks.

    ``pretrained`` may be ``False``, ``True`` (load
    ``MODEL_URLS["RepVGG_B3g2"]`` with ``use_ssld`` forwarded) or a string
    passed to ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g.
    ``class_num``) go to ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        override_groups_map=g2_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B3g2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
def RepVGG_B3g4(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B3g4: RepVGG-B3 with 4 groups at the ``g4_map`` blocks.

    ``pretrained`` may be ``False``, ``True`` (load
    ``MODEL_URLS["RepVGG_B3g4"]`` with ``use_ssld`` forwarded) or a string
    passed to ``load_dygraph_pretrain``. Extra ``kwargs`` (e.g.
    ``class_num``) go to ``RepVGG``.
    """
    net = RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        override_groups_map=g4_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B3g4"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
|