ghostnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. from paddle import ParamAttr
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from paddle.nn import Conv2D, BatchNorm, AdaptiveAvgPool2D, Linear
  20. from paddle.regularizer import L2Decay
  21. from paddle.nn.initializer import Uniform, KaimingNormal
  22. from paddlex.ppdet.core.workspace import register, serializable
  23. from numbers import Integral
  24. from ..shape_spec import ShapeSpec
  25. from .mobilenet_v3 import make_divisible, ConvBNLayer
  26. __all__ = ['GhostNet']
  27. class ExtraBlockDW(nn.Layer):
  28. def __init__(self,
  29. in_c,
  30. ch_1,
  31. ch_2,
  32. stride,
  33. lr_mult,
  34. conv_decay=0.,
  35. norm_type='bn',
  36. norm_decay=0.,
  37. freeze_norm=False,
  38. name=None):
  39. super(ExtraBlockDW, self).__init__()
  40. self.pointwise_conv = ConvBNLayer(
  41. in_c=in_c,
  42. out_c=ch_1,
  43. filter_size=1,
  44. stride=1,
  45. padding=0,
  46. act='relu6',
  47. lr_mult=lr_mult,
  48. conv_decay=conv_decay,
  49. norm_type=norm_type,
  50. norm_decay=norm_decay,
  51. freeze_norm=freeze_norm,
  52. name=name + "_extra1")
  53. self.depthwise_conv = ConvBNLayer(
  54. in_c=ch_1,
  55. out_c=ch_2,
  56. filter_size=3,
  57. stride=stride,
  58. padding=1, #
  59. num_groups=int(ch_1),
  60. act='relu6',
  61. lr_mult=lr_mult,
  62. conv_decay=conv_decay,
  63. norm_type=norm_type,
  64. norm_decay=norm_decay,
  65. freeze_norm=freeze_norm,
  66. name=name + "_extra2_dw")
  67. self.normal_conv = ConvBNLayer(
  68. in_c=ch_2,
  69. out_c=ch_2,
  70. filter_size=1,
  71. stride=1,
  72. padding=0,
  73. act='relu6',
  74. lr_mult=lr_mult,
  75. conv_decay=conv_decay,
  76. norm_type=norm_type,
  77. norm_decay=norm_decay,
  78. freeze_norm=freeze_norm,
  79. name=name + "_extra2_sep")
  80. def forward(self, inputs):
  81. x = self.pointwise_conv(inputs)
  82. x = self.depthwise_conv(x)
  83. x = self.normal_conv(x)
  84. return x
  85. class SEBlock(nn.Layer):
  86. def __init__(self, num_channels, lr_mult, reduction_ratio=4, name=None):
  87. super(SEBlock, self).__init__()
  88. self.pool2d_gap = AdaptiveAvgPool2D(1)
  89. self._num_channels = num_channels
  90. stdv = 1.0 / math.sqrt(num_channels * 1.0)
  91. med_ch = num_channels // reduction_ratio
  92. self.squeeze = Linear(
  93. num_channels,
  94. med_ch,
  95. weight_attr=ParamAttr(
  96. learning_rate=lr_mult,
  97. initializer=Uniform(-stdv, stdv),
  98. name=name + "_1_weights"),
  99. bias_attr=ParamAttr(
  100. learning_rate=lr_mult, name=name + "_1_offset"))
  101. stdv = 1.0 / math.sqrt(med_ch * 1.0)
  102. self.excitation = Linear(
  103. med_ch,
  104. num_channels,
  105. weight_attr=ParamAttr(
  106. learning_rate=lr_mult,
  107. initializer=Uniform(-stdv, stdv),
  108. name=name + "_2_weights"),
  109. bias_attr=ParamAttr(
  110. learning_rate=lr_mult, name=name + "_2_offset"))
  111. def forward(self, inputs):
  112. pool = self.pool2d_gap(inputs)
  113. pool = paddle.squeeze(pool, axis=[2, 3])
  114. squeeze = self.squeeze(pool)
  115. squeeze = F.relu(squeeze)
  116. excitation = self.excitation(squeeze)
  117. excitation = paddle.clip(x=excitation, min=0, max=1)
  118. excitation = paddle.unsqueeze(excitation, axis=[2, 3])
  119. out = paddle.multiply(inputs, excitation)
  120. return out
  121. class GhostModule(nn.Layer):
  122. def __init__(self,
  123. in_channels,
  124. output_channels,
  125. kernel_size=1,
  126. ratio=2,
  127. dw_size=3,
  128. stride=1,
  129. relu=True,
  130. lr_mult=1.,
  131. conv_decay=0.,
  132. norm_type='bn',
  133. norm_decay=0.,
  134. freeze_norm=False,
  135. name=None):
  136. super(GhostModule, self).__init__()
  137. init_channels = int(math.ceil(output_channels / ratio))
  138. new_channels = int(init_channels * (ratio - 1))
  139. self.primary_conv = ConvBNLayer(
  140. in_c=in_channels,
  141. out_c=init_channels,
  142. filter_size=kernel_size,
  143. stride=stride,
  144. padding=int((kernel_size - 1) // 2),
  145. num_groups=1,
  146. act="relu" if relu else None,
  147. lr_mult=lr_mult,
  148. conv_decay=conv_decay,
  149. norm_type=norm_type,
  150. norm_decay=norm_decay,
  151. freeze_norm=freeze_norm,
  152. name=name + "_primary_conv")
  153. self.cheap_operation = ConvBNLayer(
  154. in_c=init_channels,
  155. out_c=new_channels,
  156. filter_size=dw_size,
  157. stride=1,
  158. padding=int((dw_size - 1) // 2),
  159. num_groups=init_channels,
  160. act="relu" if relu else None,
  161. lr_mult=lr_mult,
  162. conv_decay=conv_decay,
  163. norm_type=norm_type,
  164. norm_decay=norm_decay,
  165. freeze_norm=freeze_norm,
  166. name=name + "_cheap_operation")
  167. def forward(self, inputs):
  168. x = self.primary_conv(inputs)
  169. y = self.cheap_operation(x)
  170. out = paddle.concat([x, y], axis=1)
  171. return out
  172. class GhostBottleneck(nn.Layer):
  173. def __init__(self,
  174. in_channels,
  175. hidden_dim,
  176. output_channels,
  177. kernel_size,
  178. stride,
  179. use_se,
  180. lr_mult,
  181. conv_decay=0.,
  182. norm_type='bn',
  183. norm_decay=0.,
  184. freeze_norm=False,
  185. return_list=False,
  186. name=None):
  187. super(GhostBottleneck, self).__init__()
  188. self._stride = stride
  189. self._use_se = use_se
  190. self._num_channels = in_channels
  191. self._output_channels = output_channels
  192. self.return_list = return_list
  193. self.ghost_module_1 = GhostModule(
  194. in_channels=in_channels,
  195. output_channels=hidden_dim,
  196. kernel_size=1,
  197. stride=1,
  198. relu=True,
  199. lr_mult=lr_mult,
  200. conv_decay=conv_decay,
  201. norm_type=norm_type,
  202. norm_decay=norm_decay,
  203. freeze_norm=freeze_norm,
  204. name=name + "_ghost_module_1")
  205. if stride == 2:
  206. self.depthwise_conv = ConvBNLayer(
  207. in_c=hidden_dim,
  208. out_c=hidden_dim,
  209. filter_size=kernel_size,
  210. stride=stride,
  211. padding=int((kernel_size - 1) // 2),
  212. num_groups=hidden_dim,
  213. act=None,
  214. lr_mult=lr_mult,
  215. conv_decay=conv_decay,
  216. norm_type=norm_type,
  217. norm_decay=norm_decay,
  218. freeze_norm=freeze_norm,
  219. name=name +
  220. "_depthwise_depthwise" # looks strange due to an old typo, will be fixed later.
  221. )
  222. if use_se:
  223. self.se_block = SEBlock(hidden_dim, lr_mult, name=name + "_se")
  224. self.ghost_module_2 = GhostModule(
  225. in_channels=hidden_dim,
  226. output_channels=output_channels,
  227. kernel_size=1,
  228. relu=False,
  229. lr_mult=lr_mult,
  230. conv_decay=conv_decay,
  231. norm_type=norm_type,
  232. norm_decay=norm_decay,
  233. freeze_norm=freeze_norm,
  234. name=name + "_ghost_module_2")
  235. if stride != 1 or in_channels != output_channels:
  236. self.shortcut_depthwise = ConvBNLayer(
  237. in_c=in_channels,
  238. out_c=in_channels,
  239. filter_size=kernel_size,
  240. stride=stride,
  241. padding=int((kernel_size - 1) // 2),
  242. num_groups=in_channels,
  243. act=None,
  244. lr_mult=lr_mult,
  245. conv_decay=conv_decay,
  246. norm_type=norm_type,
  247. norm_decay=norm_decay,
  248. freeze_norm=freeze_norm,
  249. name=name +
  250. "_shortcut_depthwise_depthwise" # looks strange due to an old typo, will be fixed later.
  251. )
  252. self.shortcut_conv = ConvBNLayer(
  253. in_c=in_channels,
  254. out_c=output_channels,
  255. filter_size=1,
  256. stride=1,
  257. padding=0,
  258. num_groups=1,
  259. act=None,
  260. lr_mult=lr_mult,
  261. conv_decay=conv_decay,
  262. norm_type=norm_type,
  263. norm_decay=norm_decay,
  264. freeze_norm=freeze_norm,
  265. name=name + "_shortcut_conv")
  266. def forward(self, inputs):
  267. y = self.ghost_module_1(inputs)
  268. x = y
  269. if self._stride == 2:
  270. x = self.depthwise_conv(x)
  271. if self._use_se:
  272. x = self.se_block(x)
  273. x = self.ghost_module_2(x)
  274. if self._stride == 1 and self._num_channels == self._output_channels:
  275. shortcut = inputs
  276. else:
  277. shortcut = self.shortcut_depthwise(inputs)
  278. shortcut = self.shortcut_conv(shortcut)
  279. x = paddle.add(x=x, y=shortcut)
  280. if self.return_list:
  281. return [y, x]
  282. else:
  283. return x
  284. @register
  285. @serializable
  286. class GhostNet(nn.Layer):
  287. __shared__ = ['norm_type']
  288. def __init__(self,
  289. scale=1.3,
  290. feature_maps=[6, 12, 15],
  291. with_extra_blocks=False,
  292. extra_block_filters=[[256, 512], [128, 256], [128, 256],
  293. [64, 128]],
  294. lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
  295. conv_decay=0.,
  296. norm_type='bn',
  297. norm_decay=0.0,
  298. freeze_norm=False):
  299. super(GhostNet, self).__init__()
  300. if isinstance(feature_maps, Integral):
  301. feature_maps = [feature_maps]
  302. if norm_type == 'sync_bn' and freeze_norm:
  303. raise ValueError(
  304. "The norm_type should not be sync_bn when freeze_norm is True")
  305. self.feature_maps = feature_maps
  306. self.with_extra_blocks = with_extra_blocks
  307. self.extra_block_filters = extra_block_filters
  308. inplanes = 16
  309. self.cfgs = [
  310. # k, t, c, SE, s
  311. [3, 16, 16, 0, 1],
  312. [3, 48, 24, 0, 2],
  313. [3, 72, 24, 0, 1],
  314. [5, 72, 40, 1, 2],
  315. [5, 120, 40, 1, 1],
  316. [3, 240, 80, 0, 2],
  317. [3, 200, 80, 0, 1],
  318. [3, 184, 80, 0, 1],
  319. [3, 184, 80, 0, 1],
  320. [3, 480, 112, 1, 1],
  321. [3, 672, 112, 1, 1],
  322. [5, 672, 160, 1, 2], # SSDLite output
  323. [5, 960, 160, 0, 1],
  324. [5, 960, 160, 1, 1],
  325. [5, 960, 160, 0, 1],
  326. [5, 960, 160, 1, 1]
  327. ]
  328. self.scale = scale
  329. conv1_out_ch = int(make_divisible(inplanes * self.scale, 4))
  330. self.conv1 = ConvBNLayer(
  331. in_c=3,
  332. out_c=conv1_out_ch,
  333. filter_size=3,
  334. stride=2,
  335. padding=1,
  336. num_groups=1,
  337. act="relu",
  338. lr_mult=1.,
  339. conv_decay=conv_decay,
  340. norm_type=norm_type,
  341. norm_decay=norm_decay,
  342. freeze_norm=freeze_norm,
  343. name="conv1")
  344. # build inverted residual blocks
  345. self._out_channels = []
  346. self.ghost_bottleneck_list = []
  347. idx = 0
  348. inplanes = conv1_out_ch
  349. for k, exp_size, c, use_se, s in self.cfgs:
  350. lr_idx = min(idx // 3, len(lr_mult_list) - 1)
  351. lr_mult = lr_mult_list[lr_idx]
  352. # for SSD/SSDLite, first head input is after ResidualUnit expand_conv
  353. return_list = self.with_extra_blocks and idx + 2 in self.feature_maps
  354. ghost_bottleneck = self.add_sublayer(
  355. "_ghostbottleneck_" + str(idx),
  356. sublayer=GhostBottleneck(
  357. in_channels=inplanes,
  358. hidden_dim=int(make_divisible(exp_size * self.scale, 4)),
  359. output_channels=int(make_divisible(c * self.scale, 4)),
  360. kernel_size=k,
  361. stride=s,
  362. use_se=use_se,
  363. lr_mult=lr_mult,
  364. conv_decay=conv_decay,
  365. norm_type=norm_type,
  366. norm_decay=norm_decay,
  367. freeze_norm=freeze_norm,
  368. return_list=return_list,
  369. name="_ghostbottleneck_" + str(idx)))
  370. self.ghost_bottleneck_list.append(ghost_bottleneck)
  371. inplanes = int(make_divisible(c * self.scale, 4))
  372. idx += 1
  373. self._update_out_channels(
  374. int(make_divisible(exp_size * self.scale, 4))
  375. if return_list else inplanes, idx + 1, feature_maps)
  376. if self.with_extra_blocks:
  377. self.extra_block_list = []
  378. extra_out_c = int(make_divisible(self.scale * self.cfgs[-1][1], 4))
  379. lr_idx = min(idx // 3, len(lr_mult_list) - 1)
  380. lr_mult = lr_mult_list[lr_idx]
  381. conv_extra = self.add_sublayer(
  382. "conv" + str(idx + 2),
  383. sublayer=ConvBNLayer(
  384. in_c=inplanes,
  385. out_c=extra_out_c,
  386. filter_size=1,
  387. stride=1,
  388. padding=0,
  389. num_groups=1,
  390. act="relu6",
  391. lr_mult=lr_mult,
  392. conv_decay=conv_decay,
  393. norm_type=norm_type,
  394. norm_decay=norm_decay,
  395. freeze_norm=freeze_norm,
  396. name="conv" + str(idx + 2)))
  397. self.extra_block_list.append(conv_extra)
  398. idx += 1
  399. self._update_out_channels(extra_out_c, idx + 1, feature_maps)
  400. for j, block_filter in enumerate(self.extra_block_filters):
  401. in_c = extra_out_c if j == 0 else self.extra_block_filters[
  402. j - 1][1]
  403. conv_extra = self.add_sublayer(
  404. "conv" + str(idx + 2),
  405. sublayer=ExtraBlockDW(
  406. in_c,
  407. block_filter[0],
  408. block_filter[1],
  409. stride=2,
  410. lr_mult=lr_mult,
  411. conv_decay=conv_decay,
  412. norm_type=norm_type,
  413. norm_decay=norm_decay,
  414. freeze_norm=freeze_norm,
  415. name='conv' + str(idx + 2)))
  416. self.extra_block_list.append(conv_extra)
  417. idx += 1
  418. self._update_out_channels(block_filter[1], idx + 1,
  419. feature_maps)
  420. def _update_out_channels(self, channel, feature_idx, feature_maps):
  421. if feature_idx in feature_maps:
  422. self._out_channels.append(channel)
  423. def forward(self, inputs):
  424. x = self.conv1(inputs['image'])
  425. outs = []
  426. for idx, ghost_bottleneck in enumerate(self.ghost_bottleneck_list):
  427. x = ghost_bottleneck(x)
  428. if idx + 2 in self.feature_maps:
  429. if isinstance(x, list):
  430. outs.append(x[0])
  431. x = x[1]
  432. else:
  433. outs.append(x)
  434. if not self.with_extra_blocks:
  435. return outs
  436. for i, block in enumerate(self.extra_block_list):
  437. idx = i + len(self.ghost_bottleneck_list)
  438. x = block(x)
  439. if idx + 2 in self.feature_maps:
  440. outs.append(x)
  441. return outs
  442. @property
  443. def out_shape(self):
  444. return [ShapeSpec(channels=c) for c in self._out_channels]