resnet_vc.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from paddle import ParamAttr
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  23. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  24. from paddle.nn.initializer import Uniform
  25. import math
  26. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  27. MODEL_URLS = {
  28. "ResNet50_vc":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vc_pretrained.pdparams",
  30. }
  31. __all__ = list(MODEL_URLS.keys())
  32. class ConvBNLayer(nn.Layer):
  33. def __init__(self,
  34. num_channels,
  35. num_filters,
  36. filter_size,
  37. stride=1,
  38. groups=1,
  39. act=None,
  40. name=None):
  41. super(ConvBNLayer, self).__init__()
  42. self._conv = Conv2D(
  43. in_channels=num_channels,
  44. out_channels=num_filters,
  45. kernel_size=filter_size,
  46. stride=stride,
  47. padding=(filter_size - 1) // 2,
  48. groups=groups,
  49. weight_attr=ParamAttr(name=name + "_weights"),
  50. bias_attr=False)
  51. if name == "conv1":
  52. bn_name = "bn_" + name
  53. else:
  54. bn_name = "bn" + name[3:]
  55. self._batch_norm = BatchNorm(
  56. num_filters,
  57. act=act,
  58. param_attr=ParamAttr(name=bn_name + '_scale'),
  59. bias_attr=ParamAttr(bn_name + '_offset'),
  60. moving_mean_name=bn_name + '_mean',
  61. moving_variance_name=bn_name + '_variance')
  62. def forward(self, inputs):
  63. y = self._conv(inputs)
  64. y = self._batch_norm(y)
  65. return y
  66. class BottleneckBlock(nn.Layer):
  67. def __init__(self,
  68. num_channels,
  69. num_filters,
  70. stride,
  71. shortcut=True,
  72. name=None):
  73. super(BottleneckBlock, self).__init__()
  74. self.conv0 = ConvBNLayer(
  75. num_channels=num_channels,
  76. num_filters=num_filters,
  77. filter_size=1,
  78. act='relu',
  79. name=name + "_branch2a")
  80. self.conv1 = ConvBNLayer(
  81. num_channels=num_filters,
  82. num_filters=num_filters,
  83. filter_size=3,
  84. stride=stride,
  85. act='relu',
  86. name=name + "_branch2b")
  87. self.conv2 = ConvBNLayer(
  88. num_channels=num_filters,
  89. num_filters=num_filters * 4,
  90. filter_size=1,
  91. act=None,
  92. name=name + "_branch2c")
  93. if not shortcut:
  94. self.short = ConvBNLayer(
  95. num_channels=num_channels,
  96. num_filters=num_filters * 4,
  97. filter_size=1,
  98. stride=stride,
  99. name=name + "_branch1")
  100. self.shortcut = shortcut
  101. self._num_channels_out = num_filters * 4
  102. def forward(self, inputs):
  103. y = self.conv0(inputs)
  104. conv1 = self.conv1(y)
  105. conv2 = self.conv2(conv1)
  106. if self.shortcut:
  107. short = inputs
  108. else:
  109. short = self.short(inputs)
  110. y = paddle.add(x=short, y=conv2)
  111. y = F.relu(y)
  112. return y
  113. class BasicBlock(nn.Layer):
  114. def __init__(self,
  115. num_channels,
  116. num_filters,
  117. stride,
  118. shortcut=True,
  119. name=None):
  120. super(BasicBlock, self).__init__()
  121. self.stride = stride
  122. self.conv0 = ConvBNLayer(
  123. num_channels=num_channels,
  124. num_filters=num_filters,
  125. filter_size=3,
  126. stride=stride,
  127. act='relu',
  128. name=name + "_branch2a")
  129. self.conv1 = ConvBNLayer(
  130. num_channels=num_filters,
  131. num_filters=num_filters,
  132. filter_size=3,
  133. act=None,
  134. name=name + "_branch2b")
  135. if not shortcut:
  136. self.short = ConvBNLayer(
  137. num_channels=num_channels,
  138. num_filters=num_filters,
  139. filter_size=1,
  140. stride=stride,
  141. name=name + "_branch1")
  142. self.shortcut = shortcut
  143. def forward(self, inputs):
  144. y = self.conv0(inputs)
  145. conv1 = self.conv1(y)
  146. if self.shortcut:
  147. short = inputs
  148. else:
  149. short = self.short(inputs)
  150. y = paddle.add(x=short, y=conv1)
  151. y = F.relu(y)
  152. return y
  153. class ResNet_vc(nn.Layer):
  154. def __init__(self, layers=50, class_num=1000):
  155. super(ResNet_vc, self).__init__()
  156. self.layers = layers
  157. supported_layers = [18, 34, 50, 101, 152]
  158. assert layers in supported_layers, \
  159. "supported layers are {} but input layer is {}".format(
  160. supported_layers, layers)
  161. if layers == 18:
  162. depth = [2, 2, 2, 2]
  163. elif layers == 34 or layers == 50:
  164. depth = [3, 4, 6, 3]
  165. elif layers == 101:
  166. depth = [3, 4, 23, 3]
  167. elif layers == 152:
  168. depth = [3, 8, 36, 3]
  169. num_channels = [64, 256, 512,
  170. 1024] if layers >= 50 else [64, 64, 128, 256]
  171. num_filters = [64, 128, 256, 512]
  172. self.conv1_1 = ConvBNLayer(
  173. num_channels=3,
  174. num_filters=32,
  175. filter_size=3,
  176. stride=2,
  177. act='relu',
  178. name="conv1_1")
  179. self.conv1_2 = ConvBNLayer(
  180. num_channels=32,
  181. num_filters=32,
  182. filter_size=3,
  183. stride=1,
  184. act='relu',
  185. name="conv1_2")
  186. self.conv1_3 = ConvBNLayer(
  187. num_channels=32,
  188. num_filters=64,
  189. filter_size=3,
  190. stride=1,
  191. act='relu',
  192. name="conv1_3")
  193. self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
  194. self.block_list = []
  195. if layers >= 50:
  196. for block in range(len(depth)):
  197. shortcut = False
  198. for i in range(depth[block]):
  199. if layers in [101, 152] and block == 2:
  200. if i == 0:
  201. conv_name = "res" + str(block + 2) + "a"
  202. else:
  203. conv_name = "res" + str(block + 2) + "b" + str(i)
  204. else:
  205. conv_name = "res" + str(block + 2) + chr(97 + i)
  206. bottleneck_block = self.add_sublayer(
  207. 'bb_%d_%d' % (block, i),
  208. BottleneckBlock(
  209. num_channels=num_channels[block]
  210. if i == 0 else num_filters[block] * 4,
  211. num_filters=num_filters[block],
  212. stride=2 if i == 0 and block != 0 else 1,
  213. shortcut=shortcut,
  214. name=conv_name))
  215. self.block_list.append(bottleneck_block)
  216. shortcut = True
  217. else:
  218. for block in range(len(depth)):
  219. shortcut = False
  220. for i in range(depth[block]):
  221. conv_name = "res" + str(block + 2) + chr(97 + i)
  222. basic_block = self.add_sublayer(
  223. 'bb_%d_%d' % (block, i),
  224. BasicBlock(
  225. num_channels=num_channels[block]
  226. if i == 0 else num_filters[block],
  227. num_filters=num_filters[block],
  228. stride=2 if i == 0 and block != 0 else 1,
  229. shortcut=shortcut,
  230. name=conv_name))
  231. self.block_list.append(basic_block)
  232. shortcut = True
  233. self.pool2d_avg = AdaptiveAvgPool2D(1)
  234. self.pool2d_avg_channels = num_channels[-1] * 2
  235. stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
  236. self.out = Linear(
  237. self.pool2d_avg_channels,
  238. class_num,
  239. weight_attr=ParamAttr(
  240. initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
  241. bias_attr=ParamAttr(name="fc_0.b_0"))
  242. def forward(self, inputs):
  243. y = self.conv1_1(inputs)
  244. y = self.conv1_2(y)
  245. y = self.conv1_3(y)
  246. y = self.pool2d_max(y)
  247. for block in self.block_list:
  248. y = block(y)
  249. y = self.pool2d_avg(y)
  250. y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
  251. y = self.out(y)
  252. return y
  253. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  254. if pretrained is False:
  255. pass
  256. elif pretrained is True:
  257. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  258. elif isinstance(pretrained, str):
  259. load_dygraph_pretrain(model, pretrained)
  260. else:
  261. raise RuntimeError(
  262. "pretrained type is not available. Please use `string` or `boolean` type."
  263. )
  264. def ResNet50_vc(pretrained=False, use_ssld=False, **kwargs):
  265. model = ResNet_vc(layers=50, **kwargs)
  266. _load_pretrained(
  267. pretrained, model, MODEL_URLS["ResNet50_vc"], use_ssld=use_ssld)
  268. return model