shufflenet_v2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import ParamAttr, reshape, transpose, concat, split
  19. from paddle.nn import Layer, Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm, Linear
  20. from paddle.nn.initializer import KaimingNormal
  21. from paddle.nn.functional import swish
  22. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  23. MODEL_URLS = {
  24. "ShuffleNetV2_x0_25":
  25. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_25_pretrained.pdparams",
  26. "ShuffleNetV2_x0_33":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_33_pretrained.pdparams",
  28. "ShuffleNetV2_x0_5":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_5_pretrained.pdparams",
  30. "ShuffleNetV2_x1_0":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_0_pretrained.pdparams",
  32. "ShuffleNetV2_x1_5":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_5_pretrained.pdparams",
  34. "ShuffleNetV2_x2_0":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x2_0_pretrained.pdparams",
  36. "ShuffleNetV2_swish":
  37. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_swish_pretrained.pdparams"
  38. }
  39. __all__ = list(MODEL_URLS.keys())
  40. def channel_shuffle(x, groups):
  41. batch_size, num_channels, height, width = x.shape[0:4]
  42. channels_per_group = num_channels // groups
  43. # reshape
  44. x = reshape(
  45. x=x, shape=[batch_size, groups, channels_per_group, height, width])
  46. # transpose
  47. x = transpose(x=x, perm=[0, 2, 1, 3, 4])
  48. # flatten
  49. x = reshape(x=x, shape=[batch_size, num_channels, height, width])
  50. return x
  51. class ConvBNLayer(Layer):
  52. def __init__(
  53. self,
  54. in_channels,
  55. out_channels,
  56. kernel_size,
  57. stride,
  58. padding,
  59. groups=1,
  60. act=None,
  61. name=None, ):
  62. super(ConvBNLayer, self).__init__()
  63. self._conv = Conv2D(
  64. in_channels=in_channels,
  65. out_channels=out_channels,
  66. kernel_size=kernel_size,
  67. stride=stride,
  68. padding=padding,
  69. groups=groups,
  70. weight_attr=ParamAttr(
  71. initializer=KaimingNormal(), name=name + "_weights"),
  72. bias_attr=False)
  73. self._batch_norm = BatchNorm(
  74. out_channels,
  75. param_attr=ParamAttr(name=name + "_bn_scale"),
  76. bias_attr=ParamAttr(name=name + "_bn_offset"),
  77. act=act,
  78. moving_mean_name=name + "_bn_mean",
  79. moving_variance_name=name + "_bn_variance")
  80. def forward(self, inputs):
  81. y = self._conv(inputs)
  82. y = self._batch_norm(y)
  83. return y
  84. class InvertedResidual(Layer):
  85. def __init__(self,
  86. in_channels,
  87. out_channels,
  88. stride,
  89. act="relu",
  90. name=None):
  91. super(InvertedResidual, self).__init__()
  92. self._conv_pw = ConvBNLayer(
  93. in_channels=in_channels // 2,
  94. out_channels=out_channels // 2,
  95. kernel_size=1,
  96. stride=1,
  97. padding=0,
  98. groups=1,
  99. act=act,
  100. name='stage_' + name + '_conv1')
  101. self._conv_dw = ConvBNLayer(
  102. in_channels=out_channels // 2,
  103. out_channels=out_channels // 2,
  104. kernel_size=3,
  105. stride=stride,
  106. padding=1,
  107. groups=out_channels // 2,
  108. act=None,
  109. name='stage_' + name + '_conv2')
  110. self._conv_linear = ConvBNLayer(
  111. in_channels=out_channels // 2,
  112. out_channels=out_channels // 2,
  113. kernel_size=1,
  114. stride=1,
  115. padding=0,
  116. groups=1,
  117. act=act,
  118. name='stage_' + name + '_conv3')
  119. def forward(self, inputs):
  120. x1, x2 = split(
  121. inputs,
  122. num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
  123. axis=1)
  124. x2 = self._conv_pw(x2)
  125. x2 = self._conv_dw(x2)
  126. x2 = self._conv_linear(x2)
  127. out = concat([x1, x2], axis=1)
  128. return channel_shuffle(out, 2)
  129. class InvertedResidualDS(Layer):
  130. def __init__(self,
  131. in_channels,
  132. out_channels,
  133. stride,
  134. act="relu",
  135. name=None):
  136. super(InvertedResidualDS, self).__init__()
  137. # branch1
  138. self._conv_dw_1 = ConvBNLayer(
  139. in_channels=in_channels,
  140. out_channels=in_channels,
  141. kernel_size=3,
  142. stride=stride,
  143. padding=1,
  144. groups=in_channels,
  145. act=None,
  146. name='stage_' + name + '_conv4')
  147. self._conv_linear_1 = ConvBNLayer(
  148. in_channels=in_channels,
  149. out_channels=out_channels // 2,
  150. kernel_size=1,
  151. stride=1,
  152. padding=0,
  153. groups=1,
  154. act=act,
  155. name='stage_' + name + '_conv5')
  156. # branch2
  157. self._conv_pw_2 = ConvBNLayer(
  158. in_channels=in_channels,
  159. out_channels=out_channels // 2,
  160. kernel_size=1,
  161. stride=1,
  162. padding=0,
  163. groups=1,
  164. act=act,
  165. name='stage_' + name + '_conv1')
  166. self._conv_dw_2 = ConvBNLayer(
  167. in_channels=out_channels // 2,
  168. out_channels=out_channels // 2,
  169. kernel_size=3,
  170. stride=stride,
  171. padding=1,
  172. groups=out_channels // 2,
  173. act=None,
  174. name='stage_' + name + '_conv2')
  175. self._conv_linear_2 = ConvBNLayer(
  176. in_channels=out_channels // 2,
  177. out_channels=out_channels // 2,
  178. kernel_size=1,
  179. stride=1,
  180. padding=0,
  181. groups=1,
  182. act=act,
  183. name='stage_' + name + '_conv3')
  184. def forward(self, inputs):
  185. x1 = self._conv_dw_1(inputs)
  186. x1 = self._conv_linear_1(x1)
  187. x2 = self._conv_pw_2(inputs)
  188. x2 = self._conv_dw_2(x2)
  189. x2 = self._conv_linear_2(x2)
  190. out = concat([x1, x2], axis=1)
  191. return channel_shuffle(out, 2)
  192. class ShuffleNet(Layer):
  193. def __init__(self, class_num=1000, scale=1.0, act="relu"):
  194. super(ShuffleNet, self).__init__()
  195. self.scale = scale
  196. self.class_num = class_num
  197. stage_repeats = [4, 8, 4]
  198. if scale == 0.25:
  199. stage_out_channels = [-1, 24, 24, 48, 96, 512]
  200. elif scale == 0.33:
  201. stage_out_channels = [-1, 24, 32, 64, 128, 512]
  202. elif scale == 0.5:
  203. stage_out_channels = [-1, 24, 48, 96, 192, 1024]
  204. elif scale == 1.0:
  205. stage_out_channels = [-1, 24, 116, 232, 464, 1024]
  206. elif scale == 1.5:
  207. stage_out_channels = [-1, 24, 176, 352, 704, 1024]
  208. elif scale == 2.0:
  209. stage_out_channels = [-1, 24, 224, 488, 976, 2048]
  210. else:
  211. raise NotImplementedError("This scale size:[" + str(scale) +
  212. "] is not implemented!")
  213. # 1. conv1
  214. self._conv1 = ConvBNLayer(
  215. in_channels=3,
  216. out_channels=stage_out_channels[1],
  217. kernel_size=3,
  218. stride=2,
  219. padding=1,
  220. act=act,
  221. name='stage1_conv')
  222. self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
  223. # 2. bottleneck sequences
  224. self._block_list = []
  225. for stage_id, num_repeat in enumerate(stage_repeats):
  226. for i in range(num_repeat):
  227. if i == 0:
  228. block = self.add_sublayer(
  229. name=str(stage_id + 2) + '_' + str(i + 1),
  230. sublayer=InvertedResidualDS(
  231. in_channels=stage_out_channels[stage_id + 1],
  232. out_channels=stage_out_channels[stage_id + 2],
  233. stride=2,
  234. act=act,
  235. name=str(stage_id + 2) + '_' + str(i + 1)))
  236. else:
  237. block = self.add_sublayer(
  238. name=str(stage_id + 2) + '_' + str(i + 1),
  239. sublayer=InvertedResidual(
  240. in_channels=stage_out_channels[stage_id + 2],
  241. out_channels=stage_out_channels[stage_id + 2],
  242. stride=1,
  243. act=act,
  244. name=str(stage_id + 2) + '_' + str(i + 1)))
  245. self._block_list.append(block)
  246. # 3. last_conv
  247. self._last_conv = ConvBNLayer(
  248. in_channels=stage_out_channels[-2],
  249. out_channels=stage_out_channels[-1],
  250. kernel_size=1,
  251. stride=1,
  252. padding=0,
  253. act=act,
  254. name='conv5')
  255. # 4. pool
  256. self._pool2d_avg = AdaptiveAvgPool2D(1)
  257. self._out_c = stage_out_channels[-1]
  258. # 5. fc
  259. self._fc = Linear(
  260. stage_out_channels[-1],
  261. class_num,
  262. weight_attr=ParamAttr(name='fc6_weights'),
  263. bias_attr=ParamAttr(name='fc6_offset'))
  264. def forward(self, inputs):
  265. y = self._conv1(inputs)
  266. y = self._max_pool(y)
  267. for inv in self._block_list:
  268. y = inv(y)
  269. y = self._last_conv(y)
  270. y = self._pool2d_avg(y)
  271. y = paddle.flatten(y, start_axis=1, stop_axis=-1)
  272. y = self._fc(y)
  273. return y
  274. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  275. if pretrained is False:
  276. pass
  277. elif pretrained is True:
  278. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  279. elif isinstance(pretrained, str):
  280. load_dygraph_pretrain(model, pretrained)
  281. else:
  282. raise RuntimeError(
  283. "pretrained type is not available. Please use `string` or `boolean` type."
  284. )
  285. def ShuffleNetV2_x0_25(pretrained=False, use_ssld=False, **kwargs):
  286. model = ShuffleNet(scale=0.25, **kwargs)
  287. _load_pretrained(
  288. pretrained, model, MODEL_URLS["ShuffleNetV2_x0_25"], use_ssld=use_ssld)
  289. return model
  290. def ShuffleNetV2_x0_33(pretrained=False, use_ssld=False, **kwargs):
  291. model = ShuffleNet(scale=0.33, **kwargs)
  292. _load_pretrained(
  293. pretrained, model, MODEL_URLS["ShuffleNetV2_x0_33"], use_ssld=use_ssld)
  294. return model
  295. def ShuffleNetV2_x0_5(pretrained=False, use_ssld=False, **kwargs):
  296. model = ShuffleNet(scale=0.5, **kwargs)
  297. _load_pretrained(
  298. pretrained, model, MODEL_URLS["ShuffleNetV2_x0_5"], use_ssld=use_ssld)
  299. return model
  300. def ShuffleNetV2_x1_0(pretrained=False, use_ssld=False, **kwargs):
  301. model = ShuffleNet(scale=1.0, **kwargs)
  302. _load_pretrained(
  303. pretrained, model, MODEL_URLS["ShuffleNetV2_x1_0"], use_ssld=use_ssld)
  304. return model
  305. def ShuffleNetV2_x1_5(pretrained=False, use_ssld=False, **kwargs):
  306. model = ShuffleNet(scale=1.5, **kwargs)
  307. _load_pretrained(
  308. pretrained, model, MODEL_URLS["ShuffleNetV2_x1_5"], use_ssld=use_ssld)
  309. return model
  310. def ShuffleNetV2_x2_0(pretrained=False, use_ssld=False, **kwargs):
  311. model = ShuffleNet(scale=2.0, **kwargs)
  312. _load_pretrained(
  313. pretrained, model, MODEL_URLS["ShuffleNetV2_x2_0"], use_ssld=use_ssld)
  314. return model
  315. def ShuffleNetV2_swish(pretrained=False, use_ssld=False, **kwargs):
  316. model = ShuffleNet(scale=1.0, act="swish", **kwargs)
  317. _load_pretrained(
  318. pretrained, model, MODEL_URLS["ShuffleNetV2_swish"], use_ssld=use_ssld)
  319. return model