esnet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import, division, print_function
  15. import math
  16. import paddle
  17. from paddle import ParamAttr, reshape, transpose, concat, split
  18. import paddle.nn as nn
  19. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  20. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D
  21. from paddle.nn.initializer import KaimingNormal
  22. from paddle.regularizer import L2Decay
  23. from paddlex.ppcls.arch.backbone.base.theseus_layer import TheseusLayer
  24. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  25. MODEL_URLS = {
  26. "ESNet_x0_25":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_25_pretrained.pdparams",
  28. "ESNet_x0_5":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_5_pretrained.pdparams",
  30. "ESNet_x0_75":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_75_pretrained.pdparams",
  32. "ESNet_x1_0":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x1_0_pretrained.pdparams",
  34. }
  35. __all__ = list(MODEL_URLS.keys())
  36. def channel_shuffle(x, groups):
  37. batch_size, num_channels, height, width = x.shape[0:4]
  38. channels_per_group = num_channels // groups
  39. x = reshape(
  40. x=x, shape=[batch_size, groups, channels_per_group, height, width])
  41. x = transpose(x=x, perm=[0, 2, 1, 3, 4])
  42. x = reshape(x=x, shape=[batch_size, num_channels, height, width])
  43. return x
  44. def make_divisible(v, divisor=8, min_value=None):
  45. if min_value is None:
  46. min_value = divisor
  47. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  48. if new_v < 0.9 * v:
  49. new_v += divisor
  50. return new_v
  51. class ConvBNLayer(TheseusLayer):
  52. def __init__(self,
  53. in_channels,
  54. out_channels,
  55. kernel_size,
  56. stride=1,
  57. groups=1,
  58. if_act=True):
  59. super().__init__()
  60. self.conv = Conv2D(
  61. in_channels=in_channels,
  62. out_channels=out_channels,
  63. kernel_size=kernel_size,
  64. stride=stride,
  65. padding=(kernel_size - 1) // 2,
  66. groups=groups,
  67. weight_attr=ParamAttr(initializer=KaimingNormal()),
  68. bias_attr=False)
  69. self.bn = BatchNorm(
  70. out_channels,
  71. param_attr=ParamAttr(regularizer=L2Decay(0.0)),
  72. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  73. self.if_act = if_act
  74. self.hardswish = nn.Hardswish()
  75. def forward(self, x):
  76. x = self.conv(x)
  77. x = self.bn(x)
  78. if self.if_act:
  79. x = self.hardswish(x)
  80. return x
  81. class SEModule(TheseusLayer):
  82. def __init__(self, channel, reduction=4):
  83. super().__init__()
  84. self.avg_pool = AdaptiveAvgPool2D(1)
  85. self.conv1 = Conv2D(
  86. in_channels=channel,
  87. out_channels=channel // reduction,
  88. kernel_size=1,
  89. stride=1,
  90. padding=0)
  91. self.relu = nn.ReLU()
  92. self.conv2 = Conv2D(
  93. in_channels=channel // reduction,
  94. out_channels=channel,
  95. kernel_size=1,
  96. stride=1,
  97. padding=0)
  98. self.hardsigmoid = nn.Hardsigmoid()
  99. def forward(self, x):
  100. identity = x
  101. x = self.avg_pool(x)
  102. x = self.conv1(x)
  103. x = self.relu(x)
  104. x = self.conv2(x)
  105. x = self.hardsigmoid(x)
  106. x = paddle.multiply(x=identity, y=x)
  107. return x
  108. class ESBlock1(TheseusLayer):
  109. def __init__(self, in_channels, out_channels):
  110. super().__init__()
  111. self.pw_1_1 = ConvBNLayer(
  112. in_channels=in_channels // 2,
  113. out_channels=out_channels // 2,
  114. kernel_size=1,
  115. stride=1)
  116. self.dw_1 = ConvBNLayer(
  117. in_channels=out_channels // 2,
  118. out_channels=out_channels // 2,
  119. kernel_size=3,
  120. stride=1,
  121. groups=out_channels // 2,
  122. if_act=False)
  123. self.se = SEModule(out_channels)
  124. self.pw_1_2 = ConvBNLayer(
  125. in_channels=out_channels,
  126. out_channels=out_channels // 2,
  127. kernel_size=1,
  128. stride=1)
  129. def forward(self, x):
  130. x1, x2 = split(
  131. x, num_or_sections=[x.shape[1] // 2, x.shape[1] // 2], axis=1)
  132. x2 = self.pw_1_1(x2)
  133. x3 = self.dw_1(x2)
  134. x3 = concat([x2, x3], axis=1)
  135. x3 = self.se(x3)
  136. x3 = self.pw_1_2(x3)
  137. x = concat([x1, x3], axis=1)
  138. return channel_shuffle(x, 2)
  139. class ESBlock2(TheseusLayer):
  140. def __init__(self, in_channels, out_channels):
  141. super().__init__()
  142. # branch1
  143. self.dw_1 = ConvBNLayer(
  144. in_channels=in_channels,
  145. out_channels=in_channels,
  146. kernel_size=3,
  147. stride=2,
  148. groups=in_channels,
  149. if_act=False)
  150. self.pw_1 = ConvBNLayer(
  151. in_channels=in_channels,
  152. out_channels=out_channels // 2,
  153. kernel_size=1,
  154. stride=1)
  155. # branch2
  156. self.pw_2_1 = ConvBNLayer(
  157. in_channels=in_channels,
  158. out_channels=out_channels // 2,
  159. kernel_size=1)
  160. self.dw_2 = ConvBNLayer(
  161. in_channels=out_channels // 2,
  162. out_channels=out_channels // 2,
  163. kernel_size=3,
  164. stride=2,
  165. groups=out_channels // 2,
  166. if_act=False)
  167. self.se = SEModule(out_channels // 2)
  168. self.pw_2_2 = ConvBNLayer(
  169. in_channels=out_channels // 2,
  170. out_channels=out_channels // 2,
  171. kernel_size=1)
  172. self.concat_dw = ConvBNLayer(
  173. in_channels=out_channels,
  174. out_channels=out_channels,
  175. kernel_size=3,
  176. groups=out_channels)
  177. self.concat_pw = ConvBNLayer(
  178. in_channels=out_channels, out_channels=out_channels, kernel_size=1)
  179. def forward(self, x):
  180. x1 = self.dw_1(x)
  181. x1 = self.pw_1(x1)
  182. x2 = self.pw_2_1(x)
  183. x2 = self.dw_2(x2)
  184. x2 = self.se(x2)
  185. x2 = self.pw_2_2(x2)
  186. x = concat([x1, x2], axis=1)
  187. x = self.concat_dw(x)
  188. x = self.concat_pw(x)
  189. return x
  190. class ESNet(TheseusLayer):
  191. def __init__(self,
  192. class_num=1000,
  193. scale=1.0,
  194. dropout_prob=0.2,
  195. class_expand=1280):
  196. super().__init__()
  197. self.scale = scale
  198. self.class_num = class_num
  199. self.class_expand = class_expand
  200. stage_repeats = [3, 7, 3]
  201. stage_out_channels = [
  202. -1, 24, make_divisible(116 * scale), make_divisible(232 * scale),
  203. make_divisible(464 * scale), 1024
  204. ]
  205. self.conv1 = ConvBNLayer(
  206. in_channels=3,
  207. out_channels=stage_out_channels[1],
  208. kernel_size=3,
  209. stride=2)
  210. self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
  211. block_list = []
  212. for stage_id, num_repeat in enumerate(stage_repeats):
  213. for i in range(num_repeat):
  214. if i == 0:
  215. block = ESBlock2(
  216. in_channels=stage_out_channels[stage_id + 1],
  217. out_channels=stage_out_channels[stage_id + 2])
  218. else:
  219. block = ESBlock1(
  220. in_channels=stage_out_channels[stage_id + 2],
  221. out_channels=stage_out_channels[stage_id + 2])
  222. block_list.append(block)
  223. self.blocks = nn.Sequential(*block_list)
  224. self.conv2 = ConvBNLayer(
  225. in_channels=stage_out_channels[-2],
  226. out_channels=stage_out_channels[-1],
  227. kernel_size=1)
  228. self.avg_pool = AdaptiveAvgPool2D(1)
  229. self.last_conv = Conv2D(
  230. in_channels=stage_out_channels[-1],
  231. out_channels=self.class_expand,
  232. kernel_size=1,
  233. stride=1,
  234. padding=0,
  235. bias_attr=False)
  236. self.hardswish = nn.Hardswish()
  237. self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
  238. self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
  239. self.fc = Linear(self.class_expand, self.class_num)
  240. def forward(self, x):
  241. x = self.conv1(x)
  242. x = self.max_pool(x)
  243. x = self.blocks(x)
  244. x = self.conv2(x)
  245. x = self.avg_pool(x)
  246. x = self.last_conv(x)
  247. x = self.hardswish(x)
  248. x = self.dropout(x)
  249. x = self.flatten(x)
  250. x = self.fc(x)
  251. return x
  252. def _load_pretrained(pretrained, model, model_url, use_ssld):
  253. if pretrained is False:
  254. pass
  255. elif pretrained is True:
  256. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  257. elif isinstance(pretrained, str):
  258. load_dygraph_pretrain(model, pretrained)
  259. else:
  260. raise RuntimeError(
  261. "pretrained type is not available. Please use `string` or `boolean` type."
  262. )
  263. def ESNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
  264. """
  265. ESNet_x0_25
  266. Args:
  267. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  268. If str, means the path of the pretrained model.
  269. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  270. Returns:
  271. model: nn.Layer. Specific `ESNet_x0_25` model depends on args.
  272. """
  273. model = ESNet(scale=0.25, **kwargs)
  274. _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_25"], use_ssld)
  275. return model
  276. def ESNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
  277. """
  278. ESNet_x0_5
  279. Args:
  280. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  281. If str, means the path of the pretrained model.
  282. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  283. Returns:
  284. model: nn.Layer. Specific `ESNet_x0_5` model depends on args.
  285. """
  286. model = ESNet(scale=0.5, **kwargs)
  287. _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_5"], use_ssld)
  288. return model
  289. def ESNet_x0_75(pretrained=False, use_ssld=False, **kwargs):
  290. """
  291. ESNet_x0_75
  292. Args:
  293. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  294. If str, means the path of the pretrained model.
  295. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  296. Returns:
  297. model: nn.Layer. Specific `ESNet_x0_75` model depends on args.
  298. """
  299. model = ESNet(scale=0.75, **kwargs)
  300. _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_75"], use_ssld)
  301. return model
  302. def ESNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
  303. """
  304. ESNet_x1_0
  305. Args:
  306. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  307. If str, means the path of the pretrained model.
  308. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  309. Returns:
  310. model: nn.Layer. Specific `ESNet_x1_0` model depends on args.
  311. """
  312. model = ESNet(scale=1.0, **kwargs)
  313. _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x1_0"], use_ssld)
  314. return model