mobilenet_v3.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import, division, print_function
  15. import paddle
  16. import paddle.nn as nn
  17. from paddle import ParamAttr
  18. from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
  19. from paddle.regularizer import L2Decay
  20. from paddlex.ppcls.arch.backbone.base.theseus_layer import TheseusLayer
  21. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  # Download locations of the ImageNet-pretrained weights, one per exported
  # variant. Every URL shares the same bucket/prefix and differs only in the
  # variant name; the tuple order below fixes the dict's insertion order.
  MODEL_URLS = {
      arch: ("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/"
             "legendary_models/{}_pretrained.pdparams".format(arch))
      for arch in (
          "MobileNetV3_small_x0_35",
          "MobileNetV3_small_x0_5",
          "MobileNetV3_small_x0_75",
          "MobileNetV3_small_x1_0",
          "MobileNetV3_small_x1_25",
          "MobileNetV3_large_x0_35",
          "MobileNetV3_large_x0_5",
          "MobileNetV3_large_x0_75",
          "MobileNetV3_large_x1_0",
          "MobileNetV3_large_x1_25",
      )
  }
  # `from <module> import *` indexes `__all__` as a sequence of strings. A
  # `dict_keys` view is iterable but not indexable, so the star-import would
  # raise TypeError; materialize the names (in MODEL_URLS order) as a list.
  __all__ = list(MODEL_URLS.keys())
  # Residual-block settings for the two MobileNetV3 families:
  # NET_CONFIG["large"] feeds the MobileNetV3_large_* factories and
  # NET_CONFIG["small"] the MobileNetV3_small_* factories.
  # Each row configures one ResidualUnit as [k, exp, c, se, act, s]:
  #   k   - kernel size of the depthwise convolution
  #   exp - middle (expanded) channel count inside the block
  #   c   - output channel count of the block
  #   se  - whether the block applies an SE module
  #   act - activation name ("relu" or "hardswish")
  #   s   - stride of the depthwise convolution
  # Channel counts are unscaled; MobileNetV3 multiplies them by `scale` and
  # rounds the result with _make_divisible.
  NET_CONFIG = {
      "large": [
          # k  exp   c    se     act          s
          [3,  16,  16, False, "relu",      1],
          [3,  64,  24, False, "relu",      2],
          [3,  72,  24, False, "relu",      1],
          [5,  72,  40, True,  "relu",      2],
          [5, 120,  40, True,  "relu",      1],
          [5, 120,  40, True,  "relu",      1],
          [3, 240,  80, False, "hardswish", 2],
          [3, 200,  80, False, "hardswish", 1],
          [3, 184,  80, False, "hardswish", 1],
          [3, 184,  80, False, "hardswish", 1],
          [3, 480, 112, True,  "hardswish", 1],
          [3, 672, 112, True,  "hardswish", 1],
          [5, 672, 160, True,  "hardswish", 2],
          [5, 960, 160, True,  "hardswish", 1],
          [5, 960, 160, True,  "hardswish", 1],
      ],
      "small": [
          # k  exp   c    se     act          s
          [3,  16,  16, True,  "relu",      2],
          [3,  72,  24, False, "relu",      2],
          [3,  88,  24, False, "relu",      1],
          [5,  96,  40, True,  "hardswish", 2],
          [5, 240,  40, True,  "hardswish", 1],
          [5, 240,  40, True,  "hardswish", 1],
          [5, 120,  48, True,  "hardswish", 1],
          [5, 144,  48, True,  "hardswish", 1],
          [5, 288,  96, True,  "hardswish", 2],
          [5, 576,  96, True,  "hardswish", 1],
          [5, 576,  96, True,  "hardswish", 1],
      ],
  }
  # Unscaled channel widths of the convolutions outside the residual blocks.
  # The stem and penultimate widths are multiplied by the model's `scale`
  # (via _make_divisible); LAST_CONV is used as-is by the final 1x1 conv.
  STEM_CONV_NUMBER = 16  # first (stem) convolution, both families
  LAST_SECOND_CONV_SMALL = 576  # penultimate 1x1 convolution, "small" family
  LAST_SECOND_CONV_LARGE = 960  # penultimate 1x1 convolution, "large" family
  LAST_CONV = 1280  # last 1x1 convolution before the classifier, both families
  95. def _make_divisible(v, divisor=8, min_value=None):
  96. if min_value is None:
  97. min_value = divisor
  98. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  99. if new_v < 0.9 * v:
  100. new_v += divisor
  101. return new_v
  102. def _create_act(act):
  103. if act == "hardswish":
  104. return nn.Hardswish()
  105. elif act == "relu":
  106. return nn.ReLU()
  107. elif act is None:
  108. return None
  109. else:
  110. raise RuntimeError(
  111. "The activation function is not supported: {}".format(act))
  class MobileNetV3(TheseusLayer):
      """
      MobileNetV3 classification network.
      Args:
          config: list. Residual-block settings, one [k, exp, c, se, act, s]
              row per ResidualUnit (see NET_CONFIG).
          scale: float=1.0. Width multiplier applied to the stem, the residual
              blocks and the penultimate convolution.
          class_num: int=1000. Number of outputs of the final Linear layer.
          inplanes: int=16. Unscaled output channels of the stem convolution.
          class_squeeze: int=960. Unscaled output channels of the penultimate
              1x1 convolution.
          class_expand: int=1280. Output channels of the last 1x1 convolution
              (not multiplied by `scale`).
          dropout_prob: float=0.2. Dropout probability before the classifier.
          return_patterns: optional. When not None it is passed to
              `self.update_res` and `self._return_dict_hook` is registered as
              a forward post-hook (both provided by TheseusLayer; presumably
              used to expose intermediate outputs - confirm in TheseusLayer).
      Returns:
          model: nn.Layer. The MobileNetV3 network described by the args.
      """

      def __init__(self,
                   config,
                   scale=1.0,
                   class_num=1000,
                   inplanes=STEM_CONV_NUMBER,
                   class_squeeze=LAST_SECOND_CONV_LARGE,
                   class_expand=LAST_CONV,
                   dropout_prob=0.2,
                   return_patterns=None):
          super().__init__()
          self.cfg = config
          self.scale = scale
          self.inplanes = inplanes
          self.class_squeeze = class_squeeze
          self.class_expand = class_expand
          self.class_num = class_num

          stem_channels = _make_divisible(self.inplanes * self.scale)
          self.conv = ConvBNLayer(
              in_c=3,
              out_c=stem_channels,
              filter_size=3,
              stride=2,
              padding=1,
              num_groups=1,
              if_act=True,
              act="hardswish")

          # Chain the residual units: the first consumes the stem width and
          # every later unit consumes the scaled output width of its
          # predecessor.
          units = []
          block_in = stem_channels
          for k, exp, c, se, act, s in self.cfg:
              block_out = _make_divisible(self.scale * c)
              units.append(
                  ResidualUnit(
                      in_c=block_in,
                      mid_c=_make_divisible(self.scale * exp),
                      out_c=block_out,
                      filter_size=k,
                      stride=s,
                      use_se=se,
                      act=act))
              block_in = block_out
          self.blocks = nn.Sequential(*units)

          squeeze_channels = _make_divisible(self.scale * self.class_squeeze)
          self.last_second_conv = ConvBNLayer(
              in_c=_make_divisible(self.cfg[-1][2] * self.scale),
              out_c=squeeze_channels,
              filter_size=1,
              stride=1,
              padding=0,
              num_groups=1,
              if_act=True,
              act="hardswish")
          self.avg_pool = AdaptiveAvgPool2D(1)
          self.last_conv = Conv2D(
              in_channels=squeeze_channels,
              out_channels=self.class_expand,
              kernel_size=1,
              stride=1,
              padding=0,
              bias_attr=False)
          self.hardswish = nn.Hardswish()
          self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
          self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
          self.fc = Linear(self.class_expand, class_num)

          if return_patterns is not None:
              self.update_res(return_patterns)
              self.register_forward_post_hook(self._return_dict_hook)

      def forward(self, x):
          """Stem -> residual blocks -> pooled 1x1 head -> dropout -> Linear."""
          feat = self.conv(x)
          feat = self.blocks(feat)
          feat = self.last_second_conv(feat)
          feat = self.avg_pool(feat)
          feat = self.last_conv(feat)
          feat = self.hardswish(feat)
          feat = self.dropout(feat)
          feat = self.flatten(feat)
          return self.fc(feat)
  class ConvBNLayer(TheseusLayer):
      """
      Bias-free Conv2D followed by BatchNorm and an optional activation.
      Args:
          in_c: int. Input channels.
          out_c: int. Output channels.
          filter_size: int. Convolution kernel size.
          stride: int. Convolution stride.
          padding: int. Convolution padding.
          num_groups: int=1. Number of convolution groups.
          if_act: bool=True. Whether to apply the activation after BatchNorm.
          act: str or None. Activation name understood by `_create_act`
              ("hardswish", "relu" or None); other names raise RuntimeError
              there. With act=None no activation is applied, even when
              if_act is True.
      """

      def __init__(self,
                   in_c,
                   out_c,
                   filter_size,
                   stride,
                   padding,
                   num_groups=1,
                   if_act=True,
                   act=None):
          super().__init__()
          self.conv = Conv2D(
              in_channels=in_c,
              out_channels=out_c,
              kernel_size=filter_size,
              stride=stride,
              padding=padding,
              groups=num_groups,
              bias_attr=False)
          # BatchNorm scale/bias are excluded from weight decay (L2Decay(0.0)).
          self.bn = BatchNorm(
              num_channels=out_c,
              act=None,
              param_attr=ParamAttr(regularizer=L2Decay(0.0)),
              bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
          self.if_act = if_act
          self.act = _create_act(act)

      def forward(self, x):
          x = self.conv(x)
          x = self.bn(x)
          # The default arguments (if_act=True, act=None) produce no activation
          # layer; calling `self.act` would then invoke None and raise.
          if self.if_act and self.act is not None:
              x = self.act(x)
          return x
  class ResidualUnit(TheseusLayer):
      """
      MobileNetV3 inverted-residual block.

      1x1 expand conv (in_c -> mid_c) -> filter_size x filter_size depthwise
      conv (groups=mid_c, carries the stride) -> optional SEModule -> 1x1
      linear conv without activation (mid_c -> out_c). The input is added back
      when stride == 1 and in_c == out_c.
      Args:
          in_c: int. Input channels.
          mid_c: int. Expanded channels used inside the block.
          out_c: int. Output channels.
          filter_size: int. Kernel size of the depthwise convolution.
          stride: int. Stride of the depthwise convolution.
          use_se: bool. Whether to insert an SEModule after the depthwise conv.
          act: str or None. Activation for the expand and depthwise convs.
      """

      def __init__(self,
                   in_c,
                   mid_c,
                   out_c,
                   filter_size,
                   stride,
                   use_se,
                   act=None):
          super().__init__()
          # A shortcut is only shape-compatible when neither the resolution
          # (stride) nor the channel count changes.
          self.if_shortcut = stride == 1 and in_c == out_c
          self.if_se = use_se
          pad = int((filter_size - 1) // 2)

          self.expand_conv = ConvBNLayer(
              in_c=in_c,
              out_c=mid_c,
              filter_size=1,
              stride=1,
              padding=0,
              if_act=True,
              act=act)
          self.bottleneck_conv = ConvBNLayer(
              in_c=mid_c,
              out_c=mid_c,
              filter_size=filter_size,
              stride=stride,
              padding=pad,
              num_groups=mid_c,
              if_act=True,
              act=act)
          if self.if_se:
              self.mid_se = SEModule(mid_c)
          self.linear_conv = ConvBNLayer(
              in_c=mid_c,
              out_c=out_c,
              filter_size=1,
              stride=1,
              padding=0,
              if_act=False,
              act=None)

      def forward(self, x):
          shortcut = x
          out = self.expand_conv(x)
          out = self.bottleneck_conv(out)
          if self.if_se:
              out = self.mid_se(out)
          out = self.linear_conv(out)
          if self.if_shortcut:
              out = paddle.add(shortcut, out)
          return out
  class Hardsigmoid(TheseusLayer):
      """
      Layer wrapper around `nn.functional.hardsigmoid` with configurable
      `slope` and `offset`.

      The built-in `nn.Hardsigmoid` layer does not let callers pass `slope`
      and `offset` through to `nn.functional.hardsigmoid`, so this wrapper
      stores them and forwards them explicitly.
      """

      def __init__(self, slope=0.2, offset=0.5):
          super().__init__()
          self.slope, self.offset = slope, offset

      def forward(self, x):
          return nn.functional.hardsigmoid(
              x, slope=self.slope, offset=self.offset)
  class SEModule(TheseusLayer):
      """
      Squeeze-and-excitation gate.

      Computes channel weights as AdaptiveAvgPool2D(1) -> 1x1 conv
      (channel -> channel // reduction) -> ReLU -> 1x1 conv (back to
      `channel`) -> Hardsigmoid(slope=0.2, offset=0.5), then multiplies the
      input by those weights.
      Args:
          channel: int. Channels of the input feature map.
          reduction: int=4. Channel reduction factor of the bottleneck.
      """

      def __init__(self, channel, reduction=4):
          super().__init__()
          reduced = channel // reduction
          self.avg_pool = AdaptiveAvgPool2D(1)
          self.conv1 = Conv2D(
              in_channels=channel,
              out_channels=reduced,
              kernel_size=1,
              stride=1,
              padding=0)
          self.relu = nn.ReLU()
          self.conv2 = Conv2D(
              in_channels=reduced,
              out_channels=channel,
              kernel_size=1,
              stride=1,
              padding=0)
          self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)

      def forward(self, x):
          weights = self.avg_pool(x)
          weights = self.conv1(weights)
          weights = self.relu(weights)
          weights = self.conv2(weights)
          weights = self.hardsigmoid(weights)
          return paddle.multiply(x=x, y=weights)
  313. def _load_pretrained(pretrained, model, model_url, use_ssld):
  314. if pretrained is False:
  315. pass
  316. elif pretrained is True:
  317. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  318. elif isinstance(pretrained, str):
  319. load_dygraph_pretrain(model, pretrained)
  320. else:
  321. raise RuntimeError(
  322. "pretrained type is not available. Please use `string` or `boolean` type."
  323. )
  def MobileNetV3_small_x0_35(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_small_x0_35: the "small" block table with width multiplier
      0.35 and the "small" penultimate width (LAST_SECOND_CONV_SMALL).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_small_x0_35` model.
      """
      arch_name = "MobileNetV3_small_x0_35"
      net = MobileNetV3(
          config=NET_CONFIG["small"],
          scale=0.35,
          class_squeeze=LAST_SECOND_CONV_SMALL,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_small_x0_5(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_small_x0_5: the "small" block table with width multiplier
      0.5 and the "small" penultimate width (LAST_SECOND_CONV_SMALL).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_small_x0_5` model.
      """
      arch_name = "MobileNetV3_small_x0_5"
      net = MobileNetV3(
          config=NET_CONFIG["small"],
          scale=0.5,
          class_squeeze=LAST_SECOND_CONV_SMALL,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_small_x0_75(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_small_x0_75: the "small" block table with width multiplier
      0.75 and the "small" penultimate width (LAST_SECOND_CONV_SMALL).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_small_x0_75` model.
      """
      model = MobileNetV3(
          config=NET_CONFIG["small"],
          scale=0.75,
          class_squeeze=LAST_SECOND_CONV_SMALL,
          **kwargs)
      _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_75"],
                       use_ssld)
      return model
  def MobileNetV3_small_x1_0(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_small_x1_0: the "small" block table at full width (scale
      1.0) with the "small" penultimate width (LAST_SECOND_CONV_SMALL).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_small_x1_0` model.
      """
      arch_name = "MobileNetV3_small_x1_0"
      net = MobileNetV3(
          config=NET_CONFIG["small"],
          scale=1.0,
          class_squeeze=LAST_SECOND_CONV_SMALL,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_small_x1_25(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_small_x1_25: the "small" block table with width multiplier
      1.25 and the "small" penultimate width (LAST_SECOND_CONV_SMALL).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_small_x1_25` model.
      """
      arch_name = "MobileNetV3_small_x1_25"
      net = MobileNetV3(
          config=NET_CONFIG["small"],
          scale=1.25,
          class_squeeze=LAST_SECOND_CONV_SMALL,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_large_x0_35(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_large_x0_35: the "large" block table with width multiplier
      0.35 and the "large" penultimate width (LAST_SECOND_CONV_LARGE).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_large_x0_35` model.
      """
      arch_name = "MobileNetV3_large_x0_35"
      net = MobileNetV3(
          config=NET_CONFIG["large"],
          scale=0.35,
          class_squeeze=LAST_SECOND_CONV_LARGE,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_large_x0_5(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_large_x0_5: the "large" block table with width multiplier
      0.5 and the "large" penultimate width (LAST_SECOND_CONV_LARGE).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_large_x0_5` model.
      """
      arch_name = "MobileNetV3_large_x0_5"
      net = MobileNetV3(
          config=NET_CONFIG["large"],
          scale=0.5,
          class_squeeze=LAST_SECOND_CONV_LARGE,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_large_x0_75(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_large_x0_75: the "large" block table with width multiplier
      0.75 and the "large" penultimate width (LAST_SECOND_CONV_LARGE).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_large_x0_75` model.
      """
      arch_name = "MobileNetV3_large_x0_75"
      net = MobileNetV3(
          config=NET_CONFIG["large"],
          scale=0.75,
          class_squeeze=LAST_SECOND_CONV_LARGE,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_large_x1_0(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_large_x1_0: the "large" block table at full width (scale
      1.0) with the "large" penultimate width (LAST_SECOND_CONV_LARGE).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_large_x1_0` model.
      """
      arch_name = "MobileNetV3_large_x1_0"
      net = MobileNetV3(
          config=NET_CONFIG["large"],
          scale=1.0,
          class_squeeze=LAST_SECOND_CONV_LARGE,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net
  def MobileNetV3_large_x1_25(pretrained=False, use_ssld=False, **kwargs):
      """
      MobileNetV3_large_x1_25: the "large" block table with width multiplier
      1.25 and the "large" penultimate width (LAST_SECOND_CONV_LARGE).
      Args:
          pretrained: bool=False or str. If `True`, load the pretrained
              parameters listed in MODEL_URLS; if `False`, keep the fresh
              initialization. If str, it is the path of the pretrained model.
          use_ssld: bool=False. Whether to use the distillation pretrained
              model when pretrained=True.
          **kwargs: Forwarded to `MobileNetV3` (e.g. class_num, dropout_prob,
              return_patterns); must not repeat config/scale/class_squeeze.
      Returns:
          model: nn.Layer. The `MobileNetV3_large_x1_25` model.
      """
      arch_name = "MobileNetV3_large_x1_25"
      net = MobileNetV3(
          config=NET_CONFIG["large"],
          scale=1.25,
          class_squeeze=LAST_SECOND_CONV_LARGE,
          **kwargs)
      _load_pretrained(pretrained, net, MODEL_URLS[arch_name], use_ssld)
      return net