mixnet.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. MixNet for ImageNet-1K, implemented in Paddle.
  16. Original paper: 'MixConv: Mixed Depthwise Convolutional Kernels,'
  17. https://arxiv.org/abs/1907.09595.
  18. """
  19. import os
  20. from inspect import isfunction
  21. from functools import reduce
  22. import paddle
  23. import paddle.nn as nn
  24. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  25. MODEL_URLS = {
  26. "MixNet_S":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_S_pretrained.pdparams",
  28. "MixNet_M":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_M_pretrained.pdparams",
  30. "MixNet_L":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_L_pretrained.pdparams"
  32. }
  33. __all__ = list(MODEL_URLS.keys())
  34. class Identity(nn.Layer):
  35. """
  36. Identity block.
  37. """
  38. def __init__(self):
  39. super(Identity, self).__init__()
  40. def forward(self, x):
  41. return x
  42. def round_channels(channels, divisor=8):
  43. """
  44. Round weighted channel number (make divisible operation).
  45. Parameters:
  46. ----------
  47. channels : int or float
  48. Original number of channels.
  49. divisor : int, default 8
  50. Alignment value.
  51. Returns:
  52. -------
  53. int
  54. Weighted number of channels.
  55. """
  56. rounded_channels = max(
  57. int(channels + divisor / 2.0) // divisor * divisor, divisor)
  58. if float(rounded_channels) < 0.9 * channels:
  59. rounded_channels += divisor
  60. return rounded_channels
  61. def get_activation_layer(activation):
  62. """
  63. Create activation layer from string/function.
  64. Parameters:
  65. ----------
  66. activation : function, or str, or nn.Module
  67. Activation function or name of activation function.
  68. Returns:
  69. -------
  70. nn.Module
  71. Activation layer.
  72. """
  73. assert activation is not None
  74. if isfunction(activation):
  75. return activation()
  76. elif isinstance(activation, str):
  77. if activation == "relu":
  78. return nn.ReLU()
  79. elif activation == "relu6":
  80. return nn.ReLU6()
  81. elif activation == "swish":
  82. return nn.Swish()
  83. elif activation == "hswish":
  84. return nn.Hardswish()
  85. elif activation == "sigmoid":
  86. return nn.Sigmoid()
  87. elif activation == "hsigmoid":
  88. return nn.Hardsigmoid()
  89. elif activation == "identity":
  90. return Identity()
  91. else:
  92. raise NotImplementedError()
  93. else:
  94. assert isinstance(activation, nn.Layer)
  95. return activation
  96. class ConvBlock(nn.Layer):
  97. """
  98. Standard convolution block with Batch normalization and activation.
  99. Parameters:
  100. ----------
  101. in_channels : int
  102. Number of input channels.
  103. out_channels : int
  104. Number of output channels.
  105. kernel_size : int or tuple/list of 2 int
  106. Convolution window size.
  107. stride : int or tuple/list of 2 int
  108. Strides of the convolution.
  109. padding : int, or tuple/list of 2 int, or tuple/list of 4 int
  110. Padding value for convolution layer.
  111. dilation : int or tuple/list of 2 int, default 1
  112. Dilation value for convolution layer.
  113. groups : int, default 1
  114. Number of groups.
  115. bias : bool, default False
  116. Whether the layer uses a bias vector.
  117. use_bn : bool, default True
  118. Whether to use BatchNorm layer.
  119. bn_eps : float, default 1e-5
  120. Small float added to variance in Batch norm.
  121. activation : function or str or None, default nn.ReLU()
  122. Activation function or name of activation function.
  123. """
  124. def __init__(self,
  125. in_channels,
  126. out_channels,
  127. kernel_size,
  128. stride,
  129. padding,
  130. dilation=1,
  131. groups=1,
  132. bias=False,
  133. use_bn=True,
  134. bn_eps=1e-5,
  135. activation=nn.ReLU()):
  136. super(ConvBlock, self).__init__()
  137. self.activate = (activation is not None)
  138. self.use_bn = use_bn
  139. self.use_pad = (isinstance(padding, (list, tuple)) and
  140. (len(padding) == 4))
  141. if self.use_pad:
  142. self.pad = padding
  143. self.conv = nn.Conv2D(
  144. in_channels=in_channels,
  145. out_channels=out_channels,
  146. kernel_size=kernel_size,
  147. stride=stride,
  148. padding=padding,
  149. dilation=dilation,
  150. groups=groups,
  151. bias_attr=bias,
  152. weight_attr=None)
  153. if self.use_bn:
  154. self.bn = nn.BatchNorm2D(num_features=out_channels, epsilon=bn_eps)
  155. if self.activate:
  156. self.activ = get_activation_layer(activation)
  157. def forward(self, x):
  158. x = self.conv(x)
  159. if self.use_bn:
  160. x = self.bn(x)
  161. if self.activate:
  162. x = self.activ(x)
  163. return x
  164. class SEBlock(nn.Layer):
  165. def __init__(self,
  166. channels,
  167. reduction=16,
  168. mid_channels=None,
  169. round_mid=False,
  170. use_conv=True,
  171. mid_activation=nn.ReLU(),
  172. out_activation=nn.Sigmoid()):
  173. super(SEBlock, self).__init__()
  174. self.use_conv = use_conv
  175. if mid_channels is None:
  176. mid_channels = channels // reduction if not round_mid else round_channels(
  177. float(channels) / reduction)
  178. self.pool = nn.AdaptiveAvgPool2D(output_size=1)
  179. if use_conv:
  180. self.conv1 = nn.Conv2D(
  181. in_channels=channels,
  182. out_channels=mid_channels,
  183. kernel_size=1,
  184. stride=1,
  185. groups=1,
  186. bias_attr=True,
  187. weight_attr=None)
  188. else:
  189. self.fc1 = nn.Linear(
  190. in_features=channels, out_features=mid_channels)
  191. self.activ = get_activation_layer(mid_activation)
  192. if use_conv:
  193. self.conv2 = nn.Conv2D(
  194. in_channels=mid_channels,
  195. out_channels=channels,
  196. kernel_size=1,
  197. stride=1,
  198. groups=1,
  199. bias_attr=True,
  200. weight_attr=None)
  201. else:
  202. self.fc2 = nn.Linear(
  203. in_features=mid_channels, out_features=channels)
  204. self.sigmoid = get_activation_layer(out_activation)
  205. def forward(self, x):
  206. w = self.pool(x)
  207. if not self.use_conv:
  208. w = w.reshape(shape=[w.shape[0], -1])
  209. w = self.conv1(w) if self.use_conv else self.fc1(w)
  210. w = self.activ(w)
  211. w = self.conv2(w) if self.use_conv else self.fc2(w)
  212. w = self.sigmoid(w)
  213. if not self.use_conv:
  214. w = w.unsqueeze(2).unsqueeze(3)
  215. x = x * w
  216. return x
  217. class MixConv(nn.Layer):
  218. """
  219. Mixed convolution layer from 'MixConv: Mixed Depthwise Convolutional Kernels,'
  220. https://arxiv.org/abs/1907.09595.
  221. Parameters:
  222. ----------
  223. in_channels : int
  224. Number of input channels.
  225. out_channels : int
  226. Number of output channels.
  227. kernel_size : int or tuple/list of int, or tuple/list of tuple/list of 2 int
  228. Convolution window size.
  229. stride : int or tuple/list of 2 int
  230. Strides of the convolution.
  231. padding : int or tuple/list of int, or tuple/list of tuple/list of 2 int
  232. Padding value for convolution layer.
  233. dilation : int or tuple/list of 2 int, default 1
  234. Dilation value for convolution layer.
  235. groups : int, default 1
  236. Number of groups.
  237. bias : bool, default False
  238. Whether the layer uses a bias vector.
  239. axis : int, default 1
  240. The axis on which to concatenate the outputs.
  241. """
  242. def __init__(self,
  243. in_channels,
  244. out_channels,
  245. kernel_size,
  246. stride,
  247. padding,
  248. dilation=1,
  249. groups=1,
  250. bias=False,
  251. axis=1):
  252. super(MixConv, self).__init__()
  253. kernel_size = kernel_size if isinstance(kernel_size,
  254. list) else [kernel_size]
  255. padding = padding if isinstance(padding, list) else [padding]
  256. kernel_count = len(kernel_size)
  257. self.splitted_in_channels = self.split_channels(in_channels,
  258. kernel_count)
  259. splitted_out_channels = self.split_channels(out_channels, kernel_count)
  260. for i, kernel_size_i in enumerate(kernel_size):
  261. in_channels_i = self.splitted_in_channels[i]
  262. out_channels_i = splitted_out_channels[i]
  263. padding_i = padding[i]
  264. _ = self.add_sublayer(
  265. name=str(i),
  266. sublayer=nn.Conv2D(
  267. in_channels=in_channels_i,
  268. out_channels=out_channels_i,
  269. kernel_size=kernel_size_i,
  270. stride=stride,
  271. padding=padding_i,
  272. dilation=dilation,
  273. groups=(out_channels_i
  274. if out_channels == groups else groups),
  275. bias_attr=bias,
  276. weight_attr=None))
  277. self.axis = axis
  278. def forward(self, x):
  279. xx = paddle.split(x, self.splitted_in_channels, axis=self.axis)
  280. xx = paddle.split(x, self.splitted_in_channels, axis=self.axis)
  281. out = [
  282. conv_i(x_i) for x_i, conv_i in zip(xx, self._sub_layers.values())
  283. ]
  284. x = paddle.concat(tuple(out), axis=self.axis)
  285. return x
  286. @staticmethod
  287. def split_channels(channels, kernel_count):
  288. splitted_channels = [channels // kernel_count] * kernel_count
  289. splitted_channels[0] += channels - sum(splitted_channels)
  290. return splitted_channels
  291. class MixConvBlock(nn.Layer):
  292. """
  293. Mixed convolution block with Batch normalization and activation.
  294. Parameters:
  295. ----------
  296. in_channels : int
  297. Number of input channels.
  298. out_channels : int
  299. Number of output channels.
  300. kernel_size : int or tuple/list of int, or tuple/list of tuple/list of 2 int
  301. Convolution window size.
  302. stride : int or tuple/list of 2 int
  303. Strides of the convolution.
  304. padding : int or tuple/list of int, or tuple/list of tuple/list of 2 int
  305. Padding value for convolution layer.
  306. dilation : int or tuple/list of 2 int, default 1
  307. Dilation value for convolution layer.
  308. groups : int, default 1
  309. Number of groups.
  310. bias : bool, default False
  311. Whether the layer uses a bias vector.
  312. use_bn : bool, default True
  313. Whether to use BatchNorm layer.
  314. bn_eps : float, default 1e-5
  315. Small float added to variance in Batch norm.
  316. activation : function or str or None, default nn.ReLU()
  317. Activation function or name of activation function.
  318. activate : bool, default True
  319. Whether activate the convolution block.
  320. """
  321. def __init__(self,
  322. in_channels,
  323. out_channels,
  324. kernel_size,
  325. stride,
  326. padding,
  327. dilation=1,
  328. groups=1,
  329. bias=False,
  330. use_bn=True,
  331. bn_eps=1e-5,
  332. activation=nn.ReLU()):
  333. super(MixConvBlock, self).__init__()
  334. self.activate = (activation is not None)
  335. self.use_bn = use_bn
  336. self.conv = MixConv(
  337. in_channels=in_channels,
  338. out_channels=out_channels,
  339. kernel_size=kernel_size,
  340. stride=stride,
  341. padding=padding,
  342. dilation=dilation,
  343. groups=groups,
  344. bias=bias)
  345. if self.use_bn:
  346. self.bn = nn.BatchNorm2D(num_features=out_channels, epsilon=bn_eps)
  347. if self.activate:
  348. self.activ = get_activation_layer(activation)
  349. def forward(self, x):
  350. x = self.conv(x)
  351. if self.use_bn:
  352. x = self.bn(x)
  353. if self.activate:
  354. x = self.activ(x)
  355. return x
  356. def mixconv1x1_block(in_channels,
  357. out_channels,
  358. kernel_count,
  359. stride=1,
  360. groups=1,
  361. bias=False,
  362. use_bn=True,
  363. bn_eps=1e-5,
  364. activation=nn.ReLU()):
  365. """
  366. 1x1 version of the mixed convolution block.
  367. Parameters:
  368. ----------
  369. in_channels : int
  370. Number of input channels.
  371. out_channels : int
  372. Number of output channels.
  373. kernel_count : int
  374. Kernel count.
  375. stride : int or tuple/list of 2 int, default 1
  376. Strides of the convolution.
  377. groups : int, default 1
  378. Number of groups.
  379. bias : bool, default False
  380. Whether the layer uses a bias vector.
  381. use_bn : bool, default True
  382. Whether to use BatchNorm layer.
  383. bn_eps : float, default 1e-5
  384. Small float added to variance in Batch norm.
  385. activation : function or str, or None, default nn.ReLU()
  386. Activation function or name of activation function.
  387. """
  388. return MixConvBlock(
  389. in_channels=in_channels,
  390. out_channels=out_channels,
  391. kernel_size=([1] * kernel_count),
  392. stride=stride,
  393. padding=([0] * kernel_count),
  394. groups=groups,
  395. bias=bias,
  396. use_bn=use_bn,
  397. bn_eps=bn_eps,
  398. activation=activation)
  399. class MixUnit(nn.Layer):
  400. """
  401. MixNet unit.
  402. Parameters:
  403. ----------
  404. in_channels : int
  405. Number of input channels.
  406. out_channels : int
  407. Number of output channels. exp_channels : int
  408. Number of middle (expanded) channels.
  409. stride : int or tuple/list of 2 int
  410. Strides of the second convolution layer.
  411. exp_kernel_count : int
  412. Expansion convolution kernel count for each unit.
  413. conv1_kernel_count : int
  414. Conv1 kernel count for each unit.
  415. conv2_kernel_count : int
  416. Conv2 kernel count for each unit.
  417. exp_factor : int
  418. Expansion factor for each unit.
  419. se_factor : int
  420. SE reduction factor for each unit.
  421. activation : str
  422. Activation function or name of activation function.
  423. """
  424. def __init__(self, in_channels, out_channels, stride, exp_kernel_count,
  425. conv1_kernel_count, conv2_kernel_count, exp_factor, se_factor,
  426. activation):
  427. super(MixUnit, self).__init__()
  428. assert exp_factor >= 1
  429. assert se_factor >= 0
  430. self.residual = (in_channels == out_channels) and (stride == 1)
  431. self.use_se = se_factor > 0
  432. mid_channels = exp_factor * in_channels
  433. self.use_exp_conv = exp_factor > 1
  434. if self.use_exp_conv:
  435. if exp_kernel_count == 1:
  436. self.exp_conv = ConvBlock(
  437. in_channels=in_channels,
  438. out_channels=mid_channels,
  439. kernel_size=1,
  440. stride=1,
  441. padding=0,
  442. groups=1,
  443. bias=False,
  444. use_bn=True,
  445. bn_eps=1e-5,
  446. activation=activation)
  447. else:
  448. self.exp_conv = mixconv1x1_block(
  449. in_channels=in_channels,
  450. out_channels=mid_channels,
  451. kernel_count=exp_kernel_count,
  452. activation=activation)
  453. if conv1_kernel_count == 1:
  454. self.conv1 = ConvBlock(
  455. in_channels=mid_channels,
  456. out_channels=mid_channels,
  457. kernel_size=3,
  458. stride=stride,
  459. padding=1,
  460. dilation=1,
  461. groups=mid_channels,
  462. bias=False,
  463. use_bn=True,
  464. bn_eps=1e-5,
  465. activation=activation)
  466. else:
  467. self.conv1 = MixConvBlock(
  468. in_channels=mid_channels,
  469. out_channels=mid_channels,
  470. kernel_size=[3 + 2 * i for i in range(conv1_kernel_count)],
  471. stride=stride,
  472. padding=[1 + i for i in range(conv1_kernel_count)],
  473. groups=mid_channels,
  474. activation=activation)
  475. if self.use_se:
  476. self.se = SEBlock(
  477. channels=mid_channels,
  478. reduction=(exp_factor * se_factor),
  479. round_mid=False,
  480. mid_activation=activation)
  481. if conv2_kernel_count == 1:
  482. self.conv2 = ConvBlock(
  483. in_channels=mid_channels,
  484. out_channels=out_channels,
  485. activation=None,
  486. kernel_size=1,
  487. stride=1,
  488. padding=0,
  489. groups=1,
  490. bias=False,
  491. use_bn=True,
  492. bn_eps=1e-5)
  493. else:
  494. self.conv2 = mixconv1x1_block(
  495. in_channels=mid_channels,
  496. out_channels=out_channels,
  497. kernel_count=conv2_kernel_count,
  498. activation=None)
  499. def forward(self, x):
  500. if self.residual:
  501. identity = x
  502. if self.use_exp_conv:
  503. x = self.exp_conv(x)
  504. x = self.conv1(x)
  505. if self.use_se:
  506. x = self.se(x)
  507. x = self.conv2(x)
  508. if self.residual:
  509. x = x + identity
  510. return x
  511. class MixInitBlock(nn.Layer):
  512. """
  513. MixNet specific initial block.
  514. Parameters:
  515. ----------
  516. in_channels : int
  517. Number of input channels.
  518. out_channels : int
  519. Number of output channels.
  520. """
  521. def __init__(self, in_channels, out_channels):
  522. super(MixInitBlock, self).__init__()
  523. self.conv1 = ConvBlock(
  524. in_channels=in_channels,
  525. out_channels=out_channels,
  526. stride=2,
  527. kernel_size=3,
  528. padding=1)
  529. self.conv2 = MixUnit(
  530. in_channels=out_channels,
  531. out_channels=out_channels,
  532. stride=1,
  533. exp_kernel_count=1,
  534. conv1_kernel_count=1,
  535. conv2_kernel_count=1,
  536. exp_factor=1,
  537. se_factor=0,
  538. activation="relu")
  539. def forward(self, x):
  540. x = self.conv1(x)
  541. x = self.conv2(x)
  542. return x
  543. class MixNet(nn.Layer):
  544. """
  545. MixNet model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
  546. https://arxiv.org/abs/1907.09595.
  547. Parameters:
  548. ----------
  549. channels : list of list of int
  550. Number of output channels for each unit.
  551. init_block_channels : int
  552. Number of output channels for the initial unit.
  553. final_block_channels : int
  554. Number of output channels for the final block of the feature extractor.
  555. exp_kernel_counts : list of list of int
  556. Expansion convolution kernel count for each unit.
  557. conv1_kernel_counts : list of list of int
  558. Conv1 kernel count for each unit.
  559. conv2_kernel_counts : list of list of int
  560. Conv2 kernel count for each unit.
  561. exp_factors : list of list of int
  562. Expansion factor for each unit.
  563. se_factors : list of list of int
  564. SE reduction factor for each unit.
  565. in_channels : int, default 3
  566. Number of input channels.
  567. in_size : tuple of two ints, default (224, 224)
  568. Spatial size of the expected input image.
  569. class_num : int, default 1000
  570. Number of classification classes.
  571. """
  572. def __init__(self,
  573. channels,
  574. init_block_channels,
  575. final_block_channels,
  576. exp_kernel_counts,
  577. conv1_kernel_counts,
  578. conv2_kernel_counts,
  579. exp_factors,
  580. se_factors,
  581. in_channels=3,
  582. in_size=(224, 224),
  583. class_num=1000):
  584. super(MixNet, self).__init__()
  585. self.in_size = in_size
  586. self.class_num = class_num
  587. self.features = nn.Sequential()
  588. self.features.add_sublayer(
  589. "init_block",
  590. MixInitBlock(
  591. in_channels=in_channels, out_channels=init_block_channels))
  592. in_channels = init_block_channels
  593. for i, channels_per_stage in enumerate(channels):
  594. stage = nn.Sequential()
  595. for j, out_channels in enumerate(channels_per_stage):
  596. stride = 2 if ((j == 0) and (i != 3)) or (
  597. (j == len(channels_per_stage) // 2) and (i == 3)) else 1
  598. exp_kernel_count = exp_kernel_counts[i][j]
  599. conv1_kernel_count = conv1_kernel_counts[i][j]
  600. conv2_kernel_count = conv2_kernel_counts[i][j]
  601. exp_factor = exp_factors[i][j]
  602. se_factor = se_factors[i][j]
  603. activation = "relu" if i == 0 else "swish"
  604. stage.add_sublayer(
  605. "unit{}".format(j + 1),
  606. MixUnit(
  607. in_channels=in_channels,
  608. out_channels=out_channels,
  609. stride=stride,
  610. exp_kernel_count=exp_kernel_count,
  611. conv1_kernel_count=conv1_kernel_count,
  612. conv2_kernel_count=conv2_kernel_count,
  613. exp_factor=exp_factor,
  614. se_factor=se_factor,
  615. activation=activation))
  616. in_channels = out_channels
  617. self.features.add_sublayer("stage{}".format(i + 1), stage)
  618. self.features.add_sublayer(
  619. "final_block",
  620. ConvBlock(
  621. in_channels=in_channels,
  622. out_channels=final_block_channels,
  623. kernel_size=1,
  624. stride=1,
  625. padding=0,
  626. groups=1,
  627. bias=False,
  628. use_bn=True,
  629. bn_eps=1e-5,
  630. activation=nn.ReLU()))
  631. in_channels = final_block_channels
  632. self.features.add_sublayer(
  633. "final_pool", nn.AvgPool2D(
  634. kernel_size=7, stride=1))
  635. self.output = nn.Linear(
  636. in_features=in_channels, out_features=class_num)
  637. def forward(self, x):
  638. x = self.features(x)
  639. reshape_dim = reduce(lambda x, y: x * y, x.shape[1:])
  640. x = x.reshape(shape=[x.shape[0], reshape_dim])
  641. x = self.output(x)
  642. return x
  643. def get_mixnet(version, width_scale, model_name=None, **kwargs):
  644. """
  645. Create MixNet model with specific parameters.
  646. Parameters:
  647. ----------
  648. version : str
  649. Version of MobileNetV3 ('s' or 'm').
  650. width_scale : float
  651. Scale factor for width of layers.
  652. model_name : str or None, default None
  653. Model name.
  654. """
  655. if version == "s":
  656. init_block_channels = 16
  657. channels = [[24, 24], [40, 40, 40, 40], [80, 80, 80],
  658. [120, 120, 120, 200, 200, 200]]
  659. exp_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 1, 1],
  660. [2, 2, 2, 1, 1, 1]]
  661. conv1_kernel_counts = [[1, 1], [3, 2, 2, 2], [3, 2, 2],
  662. [3, 4, 4, 5, 4, 4]]
  663. conv2_kernel_counts = [[2, 2], [1, 2, 2, 2], [2, 2, 2],
  664. [2, 2, 2, 1, 2, 2]]
  665. exp_factors = [[6, 3], [6, 6, 6, 6], [6, 6, 6], [6, 3, 3, 6, 6, 6]]
  666. se_factors = [[0, 0], [2, 2, 2, 2], [4, 4, 4], [2, 2, 2, 2, 2, 2]]
  667. elif version == "m":
  668. init_block_channels = 24
  669. channels = [[32, 32], [40, 40, 40, 40], [80, 80, 80, 80],
  670. [120, 120, 120, 120, 200, 200, 200, 200]]
  671. exp_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
  672. [1, 2, 2, 2, 1, 1, 1, 1]]
  673. conv1_kernel_counts = [[3, 1], [4, 2, 2, 2], [3, 4, 4, 4],
  674. [1, 4, 4, 4, 4, 4, 4, 4]]
  675. conv2_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
  676. [1, 2, 2, 2, 1, 2, 2, 2]]
  677. exp_factors = [[6, 3], [6, 6, 6, 6], [6, 6, 6, 6],
  678. [6, 3, 3, 3, 6, 6, 6, 6]]
  679. se_factors = [[0, 0], [2, 2, 2, 2], [4, 4, 4, 4],
  680. [2, 2, 2, 2, 2, 2, 2, 2]]
  681. else:
  682. raise ValueError("Unsupported MixNet version {}".format(version))
  683. final_block_channels = 1536
  684. if width_scale != 1.0:
  685. channels = [[round_channels(cij * width_scale) for cij in ci]
  686. for ci in channels]
  687. init_block_channels = round_channels(init_block_channels * width_scale)
  688. net = MixNet(
  689. channels=channels,
  690. init_block_channels=init_block_channels,
  691. final_block_channels=final_block_channels,
  692. exp_kernel_counts=exp_kernel_counts,
  693. conv1_kernel_counts=conv1_kernel_counts,
  694. conv2_kernel_counts=conv2_kernel_counts,
  695. exp_factors=exp_factors,
  696. se_factors=se_factors,
  697. **kwargs)
  698. return net
  699. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  700. if pretrained is False:
  701. pass
  702. elif pretrained is True:
  703. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  704. elif isinstance(pretrained, str):
  705. load_dygraph_pretrain(model, pretrained)
  706. else:
  707. raise RuntimeError(
  708. "pretrained type is not available. Please use `string` or `boolean` type."
  709. )
  710. def MixNet_S(pretrained=False, use_ssld=False, **kwargs):
  711. """
  712. MixNet-S model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
  713. https://arxiv.org/abs/1907.09595.
  714. """
  715. model = get_mixnet(
  716. version="s", width_scale=1.0, model_name="MixNet_S", **kwargs)
  717. _load_pretrained(
  718. pretrained, model, MODEL_URLS["MixNet_S"], use_ssld=use_ssld)
  719. return model
  720. def MixNet_M(pretrained=False, use_ssld=False, **kwargs):
  721. """
  722. MixNet-M model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
  723. https://arxiv.org/abs/1907.09595.
  724. """
  725. model = get_mixnet(
  726. version="m", width_scale=1.0, model_name="MixNet_M", **kwargs)
  727. _load_pretrained(
  728. pretrained, model, MODEL_URLS["MixNet_M"], use_ssld=use_ssld)
  729. return model
  730. def MixNet_L(pretrained=False, use_ssld=False, **kwargs):
  731. """
  732. MixNet-S model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
  733. https://arxiv.org/abs/1907.09595.
  734. """
  735. model = get_mixnet(
  736. version="m", width_scale=1.3, model_name="MixNet_L", **kwargs)
  737. _load_pretrained(
  738. pretrained, model, MODEL_URLS["MixNet_L"], use_ssld=use_ssld)
  739. return model