ghostnet.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/ghostnet_pytorch
  15. import math
  16. import paddle
  17. from paddle import ParamAttr
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle.nn import Conv2D, BatchNorm, AdaptiveAvgPool2D, Linear
  21. from paddle.regularizer import L2Decay
  22. from paddle.nn.initializer import Uniform, KaimingNormal
  23. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  24. MODEL_URLS = {
  25. "GhostNet_x0_5":
  26. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x0_5_pretrained.pdparams",
  27. "GhostNet_x1_0":
  28. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_0_pretrained.pdparams",
  29. "GhostNet_x1_3":
  30. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_pretrained.pdparams",
  31. }
  32. __all__ = list(MODEL_URLS.keys())
  33. class ConvBNLayer(nn.Layer):
  34. def __init__(self,
  35. in_channels,
  36. out_channels,
  37. kernel_size,
  38. stride=1,
  39. groups=1,
  40. act="relu",
  41. name=None):
  42. super(ConvBNLayer, self).__init__()
  43. self._conv = Conv2D(
  44. in_channels=in_channels,
  45. out_channels=out_channels,
  46. kernel_size=kernel_size,
  47. stride=stride,
  48. padding=(kernel_size - 1) // 2,
  49. groups=groups,
  50. weight_attr=ParamAttr(
  51. initializer=KaimingNormal(), name=name + "_weights"),
  52. bias_attr=False)
  53. bn_name = name + "_bn"
  54. self._batch_norm = BatchNorm(
  55. num_channels=out_channels,
  56. act=act,
  57. param_attr=ParamAttr(
  58. name=bn_name + "_scale", regularizer=L2Decay(0.0)),
  59. bias_attr=ParamAttr(
  60. name=bn_name + "_offset", regularizer=L2Decay(0.0)),
  61. moving_mean_name=bn_name + "_mean",
  62. moving_variance_name=bn_name + "_variance")
  63. def forward(self, inputs):
  64. y = self._conv(inputs)
  65. y = self._batch_norm(y)
  66. return y
  67. class SEBlock(nn.Layer):
  68. def __init__(self, num_channels, reduction_ratio=4, name=None):
  69. super(SEBlock, self).__init__()
  70. self.pool2d_gap = AdaptiveAvgPool2D(1)
  71. self._num_channels = num_channels
  72. stdv = 1.0 / math.sqrt(num_channels * 1.0)
  73. med_ch = num_channels // reduction_ratio
  74. self.squeeze = Linear(
  75. num_channels,
  76. med_ch,
  77. weight_attr=ParamAttr(
  78. initializer=Uniform(-stdv, stdv), name=name + "_1_weights"),
  79. bias_attr=ParamAttr(name=name + "_1_offset"))
  80. stdv = 1.0 / math.sqrt(med_ch * 1.0)
  81. self.excitation = Linear(
  82. med_ch,
  83. num_channels,
  84. weight_attr=ParamAttr(
  85. initializer=Uniform(-stdv, stdv), name=name + "_2_weights"),
  86. bias_attr=ParamAttr(name=name + "_2_offset"))
  87. def forward(self, inputs):
  88. pool = self.pool2d_gap(inputs)
  89. pool = paddle.squeeze(pool, axis=[2, 3])
  90. squeeze = self.squeeze(pool)
  91. squeeze = F.relu(squeeze)
  92. excitation = self.excitation(squeeze)
  93. excitation = paddle.clip(x=excitation, min=0, max=1)
  94. excitation = paddle.unsqueeze(excitation, axis=[2, 3])
  95. out = paddle.multiply(inputs, excitation)
  96. return out
  97. class GhostModule(nn.Layer):
  98. def __init__(self,
  99. in_channels,
  100. output_channels,
  101. kernel_size=1,
  102. ratio=2,
  103. dw_size=3,
  104. stride=1,
  105. relu=True,
  106. name=None):
  107. super(GhostModule, self).__init__()
  108. init_channels = int(math.ceil(output_channels / ratio))
  109. new_channels = int(init_channels * (ratio - 1))
  110. self.primary_conv = ConvBNLayer(
  111. in_channels=in_channels,
  112. out_channels=init_channels,
  113. kernel_size=kernel_size,
  114. stride=stride,
  115. groups=1,
  116. act="relu" if relu else None,
  117. name=name + "_primary_conv")
  118. self.cheap_operation = ConvBNLayer(
  119. in_channels=init_channels,
  120. out_channels=new_channels,
  121. kernel_size=dw_size,
  122. stride=1,
  123. groups=init_channels,
  124. act="relu" if relu else None,
  125. name=name + "_cheap_operation")
  126. def forward(self, inputs):
  127. x = self.primary_conv(inputs)
  128. y = self.cheap_operation(x)
  129. out = paddle.concat([x, y], axis=1)
  130. return out
  131. class GhostBottleneck(nn.Layer):
  132. def __init__(self,
  133. in_channels,
  134. hidden_dim,
  135. output_channels,
  136. kernel_size,
  137. stride,
  138. use_se,
  139. name=None):
  140. super(GhostBottleneck, self).__init__()
  141. self._stride = stride
  142. self._use_se = use_se
  143. self._num_channels = in_channels
  144. self._output_channels = output_channels
  145. self.ghost_module_1 = GhostModule(
  146. in_channels=in_channels,
  147. output_channels=hidden_dim,
  148. kernel_size=1,
  149. stride=1,
  150. relu=True,
  151. name=name + "_ghost_module_1")
  152. if stride == 2:
  153. self.depthwise_conv = ConvBNLayer(
  154. in_channels=hidden_dim,
  155. out_channels=hidden_dim,
  156. kernel_size=kernel_size,
  157. stride=stride,
  158. groups=hidden_dim,
  159. act=None,
  160. name=name +
  161. "_depthwise_depthwise" # looks strange due to an old typo, will be fixed later.
  162. )
  163. if use_se:
  164. self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se")
  165. self.ghost_module_2 = GhostModule(
  166. in_channels=hidden_dim,
  167. output_channels=output_channels,
  168. kernel_size=1,
  169. relu=False,
  170. name=name + "_ghost_module_2")
  171. if stride != 1 or in_channels != output_channels:
  172. self.shortcut_depthwise = ConvBNLayer(
  173. in_channels=in_channels,
  174. out_channels=in_channels,
  175. kernel_size=kernel_size,
  176. stride=stride,
  177. groups=in_channels,
  178. act=None,
  179. name=name +
  180. "_shortcut_depthwise_depthwise" # looks strange due to an old typo, will be fixed later.
  181. )
  182. self.shortcut_conv = ConvBNLayer(
  183. in_channels=in_channels,
  184. out_channels=output_channels,
  185. kernel_size=1,
  186. stride=1,
  187. groups=1,
  188. act=None,
  189. name=name + "_shortcut_conv")
  190. def forward(self, inputs):
  191. x = self.ghost_module_1(inputs)
  192. if self._stride == 2:
  193. x = self.depthwise_conv(x)
  194. if self._use_se:
  195. x = self.se_block(x)
  196. x = self.ghost_module_2(x)
  197. if self._stride == 1 and self._num_channels == self._output_channels:
  198. shortcut = inputs
  199. else:
  200. shortcut = self.shortcut_depthwise(inputs)
  201. shortcut = self.shortcut_conv(shortcut)
  202. return paddle.add(x=x, y=shortcut)
  203. class GhostNet(nn.Layer):
  204. def __init__(self, scale, class_num=1000):
  205. super(GhostNet, self).__init__()
  206. self.cfgs = [
  207. # k, t, c, SE, s
  208. [3, 16, 16, 0, 1],
  209. [3, 48, 24, 0, 2],
  210. [3, 72, 24, 0, 1],
  211. [5, 72, 40, 1, 2],
  212. [5, 120, 40, 1, 1],
  213. [3, 240, 80, 0, 2],
  214. [3, 200, 80, 0, 1],
  215. [3, 184, 80, 0, 1],
  216. [3, 184, 80, 0, 1],
  217. [3, 480, 112, 1, 1],
  218. [3, 672, 112, 1, 1],
  219. [5, 672, 160, 1, 2],
  220. [5, 960, 160, 0, 1],
  221. [5, 960, 160, 1, 1],
  222. [5, 960, 160, 0, 1],
  223. [5, 960, 160, 1, 1]
  224. ]
  225. self.scale = scale
  226. output_channels = int(self._make_divisible(16 * self.scale, 4))
  227. self.conv1 = ConvBNLayer(
  228. in_channels=3,
  229. out_channels=output_channels,
  230. kernel_size=3,
  231. stride=2,
  232. groups=1,
  233. act="relu",
  234. name="conv1")
  235. # build inverted residual blocks
  236. idx = 0
  237. self.ghost_bottleneck_list = []
  238. for k, exp_size, c, use_se, s in self.cfgs:
  239. in_channels = output_channels
  240. output_channels = int(self._make_divisible(c * self.scale, 4))
  241. hidden_dim = int(self._make_divisible(exp_size * self.scale, 4))
  242. ghost_bottleneck = self.add_sublayer(
  243. name="_ghostbottleneck_" + str(idx),
  244. sublayer=GhostBottleneck(
  245. in_channels=in_channels,
  246. hidden_dim=hidden_dim,
  247. output_channels=output_channels,
  248. kernel_size=k,
  249. stride=s,
  250. use_se=use_se,
  251. name="_ghostbottleneck_" + str(idx)))
  252. self.ghost_bottleneck_list.append(ghost_bottleneck)
  253. idx += 1
  254. # build last several layers
  255. in_channels = output_channels
  256. output_channels = int(self._make_divisible(exp_size * self.scale, 4))
  257. self.conv_last = ConvBNLayer(
  258. in_channels=in_channels,
  259. out_channels=output_channels,
  260. kernel_size=1,
  261. stride=1,
  262. groups=1,
  263. act="relu",
  264. name="conv_last")
  265. self.pool2d_gap = AdaptiveAvgPool2D(1)
  266. in_channels = output_channels
  267. self._fc0_output_channels = 1280
  268. self.fc_0 = ConvBNLayer(
  269. in_channels=in_channels,
  270. out_channels=self._fc0_output_channels,
  271. kernel_size=1,
  272. stride=1,
  273. act="relu",
  274. name="fc_0")
  275. self.dropout = nn.Dropout(p=0.2)
  276. stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
  277. self.fc_1 = Linear(
  278. self._fc0_output_channels,
  279. class_num,
  280. weight_attr=ParamAttr(
  281. name="fc_1_weights", initializer=Uniform(-stdv, stdv)),
  282. bias_attr=ParamAttr(name="fc_1_offset"))
  283. def forward(self, inputs):
  284. x = self.conv1(inputs)
  285. for ghost_bottleneck in self.ghost_bottleneck_list:
  286. x = ghost_bottleneck(x)
  287. x = self.conv_last(x)
  288. x = self.pool2d_gap(x)
  289. x = self.fc_0(x)
  290. x = self.dropout(x)
  291. x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
  292. x = self.fc_1(x)
  293. return x
  294. def _make_divisible(self, v, divisor, min_value=None):
  295. """
  296. This function is taken from the original tf repo.
  297. It ensures that all layers have a channel number that is divisible by 8
  298. It can be seen here:
  299. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  300. """
  301. if min_value is None:
  302. min_value = divisor
  303. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  304. # Make sure that round down does not go down by more than 10%.
  305. if new_v < 0.9 * v:
  306. new_v += divisor
  307. return new_v
  308. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  309. if pretrained is False:
  310. pass
  311. elif pretrained is True:
  312. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  313. elif isinstance(pretrained, str):
  314. load_dygraph_pretrain(model, pretrained)
  315. else:
  316. raise RuntimeError(
  317. "pretrained type is not available. Please use `string` or `boolean` type."
  318. )
  319. def GhostNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
  320. model = GhostNet(scale=0.5, **kwargs)
  321. _load_pretrained(
  322. pretrained, model, MODEL_URLS["GhostNet_x0_5"], use_ssld=use_ssld)
  323. return model
  324. def GhostNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
  325. model = GhostNet(scale=1.0, **kwargs)
  326. _load_pretrained(
  327. pretrained, model, MODEL_URLS["GhostNet_x1_0"], use_ssld=use_ssld)
  328. return model
  329. def GhostNet_x1_3(pretrained=False, use_ssld=False, **kwargs):
  330. model = GhostNet(scale=1.3, **kwargs)
  331. _load_pretrained(
  332. pretrained, model, MODEL_URLS["GhostNet_x1_3"], use_ssld=use_ssld)
  333. return model