repvgg.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/DingXiaoH/RepVGG
  15. import paddle.nn as nn
  16. import paddle
  17. import numpy as np
  18. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  19. MODEL_URLS = {
  20. "RepVGG_A0":
  21. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_A0_pretrained.pdparams",
  22. "RepVGG_A1":
  23. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_A1_pretrained.pdparams",
  24. "RepVGG_A2":
  25. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_A2_pretrained.pdparams",
  26. "RepVGG_B0":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B0_pretrained.pdparams",
  28. "RepVGG_B1":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B1_pretrained.pdparams",
  30. "RepVGG_B2":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2_pretrained.pdparams",
  32. "RepVGG_B3":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3_pretrained.pdparams",
  34. "RepVGG_B1g2":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B1g2_pretrained.pdparams",
  36. "RepVGG_B1g4":
  37. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B1g4_pretrained.pdparams",
  38. "RepVGG_B2g2":
  39. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2g2_pretrained.pdparams",
  40. "RepVGG_B2g4":
  41. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2g4_pretrained.pdparams",
  42. "RepVGG_B3g2":
  43. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3g2_pretrained.pdparams",
  44. "RepVGG_B3g4":
  45. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3g4_pretrained.pdparams",
  46. }
  47. __all__ = list(MODEL_URLS.keys())
  48. optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
  49. g2_map = {l: 2 for l in optional_groupwise_layers}
  50. g4_map = {l: 4 for l in optional_groupwise_layers}
  51. class ConvBN(nn.Layer):
  52. def __init__(self,
  53. in_channels,
  54. out_channels,
  55. kernel_size,
  56. stride,
  57. padding,
  58. groups=1):
  59. super(ConvBN, self).__init__()
  60. self.conv = nn.Conv2D(
  61. in_channels=in_channels,
  62. out_channels=out_channels,
  63. kernel_size=kernel_size,
  64. stride=stride,
  65. padding=padding,
  66. groups=groups,
  67. bias_attr=False)
  68. self.bn = nn.BatchNorm2D(num_features=out_channels)
  69. def forward(self, x):
  70. y = self.conv(x)
  71. y = self.bn(y)
  72. return y
  73. class RepVGGBlock(nn.Layer):
  74. def __init__(self,
  75. in_channels,
  76. out_channels,
  77. kernel_size,
  78. stride=1,
  79. padding=0,
  80. dilation=1,
  81. groups=1,
  82. padding_mode='zeros'):
  83. super(RepVGGBlock, self).__init__()
  84. self.in_channels = in_channels
  85. self.out_channels = out_channels
  86. self.kernel_size = kernel_size
  87. self.stride = stride
  88. self.padding = padding
  89. self.dilation = dilation
  90. self.groups = groups
  91. self.padding_mode = padding_mode
  92. assert kernel_size == 3
  93. assert padding == 1
  94. padding_11 = padding - kernel_size // 2
  95. self.nonlinearity = nn.ReLU()
  96. self.rbr_identity = nn.BatchNorm2D(
  97. num_features=in_channels
  98. ) if out_channels == in_channels and stride == 1 else None
  99. self.rbr_dense = ConvBN(
  100. in_channels=in_channels,
  101. out_channels=out_channels,
  102. kernel_size=kernel_size,
  103. stride=stride,
  104. padding=padding,
  105. groups=groups)
  106. self.rbr_1x1 = ConvBN(
  107. in_channels=in_channels,
  108. out_channels=out_channels,
  109. kernel_size=1,
  110. stride=stride,
  111. padding=padding_11,
  112. groups=groups)
  113. def forward(self, inputs):
  114. if not self.training:
  115. return self.nonlinearity(self.rbr_reparam(inputs))
  116. if self.rbr_identity is None:
  117. id_out = 0
  118. else:
  119. id_out = self.rbr_identity(inputs)
  120. return self.nonlinearity(
  121. self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
  122. def eval(self):
  123. if not hasattr(self, 'rbr_reparam'):
  124. self.rbr_reparam = nn.Conv2D(
  125. in_channels=self.in_channels,
  126. out_channels=self.out_channels,
  127. kernel_size=self.kernel_size,
  128. stride=self.stride,
  129. padding=self.padding,
  130. dilation=self.dilation,
  131. groups=self.groups,
  132. padding_mode=self.padding_mode)
  133. self.training = False
  134. kernel, bias = self.get_equivalent_kernel_bias()
  135. self.rbr_reparam.weight.set_value(kernel)
  136. self.rbr_reparam.bias.set_value(bias)
  137. for layer in self.sublayers():
  138. layer.eval()
  139. def get_equivalent_kernel_bias(self):
  140. kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
  141. kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
  142. kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
  143. return kernel3x3 + self._pad_1x1_to_3x3_tensor(
  144. kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
  145. def _pad_1x1_to_3x3_tensor(self, kernel1x1):
  146. if kernel1x1 is None:
  147. return 0
  148. else:
  149. return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  150. def _fuse_bn_tensor(self, branch):
  151. if branch is None:
  152. return 0, 0
  153. if isinstance(branch, ConvBN):
  154. kernel = branch.conv.weight
  155. running_mean = branch.bn._mean
  156. running_var = branch.bn._variance
  157. gamma = branch.bn.weight
  158. beta = branch.bn.bias
  159. eps = branch.bn._epsilon
  160. else:
  161. assert isinstance(branch, nn.BatchNorm2D)
  162. if not hasattr(self, 'id_tensor'):
  163. input_dim = self.in_channels // self.groups
  164. kernel_value = np.zeros(
  165. (self.in_channels, input_dim, 3, 3), dtype=np.float32)
  166. for i in range(self.in_channels):
  167. kernel_value[i, i % input_dim, 1, 1] = 1
  168. self.id_tensor = paddle.to_tensor(kernel_value)
  169. kernel = self.id_tensor
  170. running_mean = branch._mean
  171. running_var = branch._variance
  172. gamma = branch.weight
  173. beta = branch.bias
  174. eps = branch._epsilon
  175. std = (running_var + eps).sqrt()
  176. t = (gamma / std).reshape((-1, 1, 1, 1))
  177. return kernel * t, beta - running_mean * gamma / std
  178. class RepVGG(nn.Layer):
  179. def __init__(self,
  180. num_blocks,
  181. width_multiplier=None,
  182. override_groups_map=None,
  183. class_num=1000):
  184. super(RepVGG, self).__init__()
  185. assert len(width_multiplier) == 4
  186. self.override_groups_map = override_groups_map or dict()
  187. assert 0 not in self.override_groups_map
  188. self.in_planes = min(64, int(64 * width_multiplier[0]))
  189. self.stage0 = RepVGGBlock(
  190. in_channels=3,
  191. out_channels=self.in_planes,
  192. kernel_size=3,
  193. stride=2,
  194. padding=1)
  195. self.cur_layer_idx = 1
  196. self.stage1 = self._make_stage(
  197. int(64 * width_multiplier[0]), num_blocks[0], stride=2)
  198. self.stage2 = self._make_stage(
  199. int(128 * width_multiplier[1]), num_blocks[1], stride=2)
  200. self.stage3 = self._make_stage(
  201. int(256 * width_multiplier[2]), num_blocks[2], stride=2)
  202. self.stage4 = self._make_stage(
  203. int(512 * width_multiplier[3]), num_blocks[3], stride=2)
  204. self.gap = nn.AdaptiveAvgPool2D(output_size=1)
  205. self.linear = nn.Linear(int(512 * width_multiplier[3]), class_num)
  206. def _make_stage(self, planes, num_blocks, stride):
  207. strides = [stride] + [1] * (num_blocks - 1)
  208. blocks = []
  209. for stride in strides:
  210. cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
  211. blocks.append(
  212. RepVGGBlock(
  213. in_channels=self.in_planes,
  214. out_channels=planes,
  215. kernel_size=3,
  216. stride=stride,
  217. padding=1,
  218. groups=cur_groups))
  219. self.in_planes = planes
  220. self.cur_layer_idx += 1
  221. return nn.Sequential(*blocks)
  222. def eval(self):
  223. self.training = False
  224. for layer in self.sublayers():
  225. layer.training = False
  226. layer.eval()
  227. def forward(self, x):
  228. out = self.stage0(x)
  229. out = self.stage1(out)
  230. out = self.stage2(out)
  231. out = self.stage3(out)
  232. out = self.stage4(out)
  233. out = self.gap(out)
  234. out = paddle.flatten(out, start_axis=1)
  235. out = self.linear(out)
  236. return out
  237. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  238. if pretrained is False:
  239. pass
  240. elif pretrained is True:
  241. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  242. elif isinstance(pretrained, str):
  243. load_dygraph_pretrain(model, pretrained)
  244. else:
  245. raise RuntimeError(
  246. "pretrained type is not available. Please use `string` or `boolean` type."
  247. )
  248. def RepVGG_A0(pretrained=False, use_ssld=False, **kwargs):
  249. model = RepVGG(
  250. num_blocks=[2, 4, 14, 1],
  251. width_multiplier=[0.75, 0.75, 0.75, 2.5],
  252. override_groups_map=None,
  253. **kwargs)
  254. _load_pretrained(
  255. pretrained, model, MODEL_URLS["RepVGG_A0"], use_ssld=use_ssld)
  256. return model
  257. def RepVGG_A1(pretrained=False, use_ssld=False, **kwargs):
  258. model = RepVGG(
  259. num_blocks=[2, 4, 14, 1],
  260. width_multiplier=[1, 1, 1, 2.5],
  261. override_groups_map=None,
  262. **kwargs)
  263. _load_pretrained(
  264. pretrained, model, MODEL_URLS["RepVGG_A1"], use_ssld=use_ssld)
  265. return model
  266. def RepVGG_A2(pretrained=False, use_ssld=False, **kwargs):
  267. model = RepVGG(
  268. num_blocks=[2, 4, 14, 1],
  269. width_multiplier=[1.5, 1.5, 1.5, 2.75],
  270. override_groups_map=None,
  271. **kwargs)
  272. _load_pretrained(
  273. pretrained, model, MODEL_URLS["RepVGG_A2"], use_ssld=use_ssld)
  274. return model
  275. def RepVGG_B0(pretrained=False, use_ssld=False, **kwargs):
  276. model = RepVGG(
  277. num_blocks=[4, 6, 16, 1],
  278. width_multiplier=[1, 1, 1, 2.5],
  279. override_groups_map=None,
  280. **kwargs)
  281. _load_pretrained(
  282. pretrained, model, MODEL_URLS["RepVGG_B0"], use_ssld=use_ssld)
  283. return model
  284. def RepVGG_B1(pretrained=False, use_ssld=False, **kwargs):
  285. model = RepVGG(
  286. num_blocks=[4, 6, 16, 1],
  287. width_multiplier=[2, 2, 2, 4],
  288. override_groups_map=None,
  289. **kwargs)
  290. _load_pretrained(
  291. pretrained, model, MODEL_URLS["RepVGG_B1"], use_ssld=use_ssld)
  292. return model
  293. def RepVGG_B1g2(pretrained=False, use_ssld=False, **kwargs):
  294. model = RepVGG(
  295. num_blocks=[4, 6, 16, 1],
  296. width_multiplier=[2, 2, 2, 4],
  297. override_groups_map=g2_map,
  298. **kwargs)
  299. _load_pretrained(
  300. pretrained, model, MODEL_URLS["RepVGG_B1g2"], use_ssld=use_ssld)
  301. return model
  302. def RepVGG_B1g4(pretrained=False, use_ssld=False, **kwargs):
  303. model = RepVGG(
  304. num_blocks=[4, 6, 16, 1],
  305. width_multiplier=[2, 2, 2, 4],
  306. override_groups_map=g4_map,
  307. **kwargs)
  308. _load_pretrained(
  309. pretrained, model, MODEL_URLS["RepVGG_B1g4"], use_ssld=use_ssld)
  310. return model
  311. def RepVGG_B2(pretrained=False, use_ssld=False, **kwargs):
  312. model = RepVGG(
  313. num_blocks=[4, 6, 16, 1],
  314. width_multiplier=[2.5, 2.5, 2.5, 5],
  315. override_groups_map=None,
  316. **kwargs)
  317. _load_pretrained(
  318. pretrained, model, MODEL_URLS["RepVGG_B2"], use_ssld=use_ssld)
  319. return model
  320. def RepVGG_B2g2(pretrained=False, use_ssld=False, **kwargs):
  321. model = RepVGG(
  322. num_blocks=[4, 6, 16, 1],
  323. width_multiplier=[2.5, 2.5, 2.5, 5],
  324. override_groups_map=g2_map,
  325. **kwargs)
  326. _load_pretrained(
  327. pretrained, model, MODEL_URLS["RepVGG_B2g2"], use_ssld=use_ssld)
  328. return model
  329. def RepVGG_B2g4(pretrained=False, use_ssld=False, **kwargs):
  330. model = RepVGG(
  331. num_blocks=[4, 6, 16, 1],
  332. width_multiplier=[2.5, 2.5, 2.5, 5],
  333. override_groups_map=g4_map,
  334. **kwargs)
  335. _load_pretrained(
  336. pretrained, model, MODEL_URLS["RepVGG_B2g4"], use_ssld=use_ssld)
  337. return model
  338. def RepVGG_B3(pretrained=False, use_ssld=False, **kwargs):
  339. model = RepVGG(
  340. num_blocks=[4, 6, 16, 1],
  341. width_multiplier=[3, 3, 3, 5],
  342. override_groups_map=None,
  343. **kwargs)
  344. _load_pretrained(
  345. pretrained, model, MODEL_URLS["RepVGG_B3"], use_ssld=use_ssld)
  346. return model
  347. def RepVGG_B3g2(pretrained=False, use_ssld=False, **kwargs):
  348. model = RepVGG(
  349. num_blocks=[4, 6, 16, 1],
  350. width_multiplier=[3, 3, 3, 5],
  351. override_groups_map=g2_map,
  352. **kwargs)
  353. _load_pretrained(
  354. pretrained, model, MODEL_URLS["RepVGG_B3g2"], use_ssld=use_ssld)
  355. return model
  356. def RepVGG_B3g4(pretrained=False, use_ssld=False, **kwargs):
  357. model = RepVGG(
  358. num_blocks=[4, 6, 16, 1],
  359. width_multiplier=[3, 3, 3, 5],
  360. override_groups_map=g4_map,
  361. **kwargs)
  362. _load_pretrained(
  363. pretrained, model, MODEL_URLS["RepVGG_B3g4"], use_ssld=use_ssld)
  364. return model