se_resnext_vd.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from paddle import ParamAttr
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  23. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  24. from paddle.nn.initializer import Uniform
  25. import math
  26. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  # Download locations of the pretrained weights, keyed by the name of the
  # model constructor defined in this module.
  MODEL_URLS = {
      "SE_ResNeXt50_vd_32x4d":
      "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SE_ResNeXt50_vd_32x4d_pretrained.pdparams",
      "SENet154_vd":
      "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SENet154_vd_pretrained.pdparams",
  }

  # The public API is exactly the set of model names listed in MODEL_URLS.
  __all__ = list(MODEL_URLS)
  class ConvBNLayer(nn.Layer):
      """Convolution followed by batch normalization.

      When ``is_vd_mode`` is true, ``forward`` first passes the input through
      a 2x2, stride-2 average pool (``ceil_mode=True``) and only then applies
      the convolution and batch norm.

      Args:
          num_channels: Input channel count handed to ``Conv2D``.
          num_filters: Output channel count of the convolution; also the
              size of the batch-norm layer.
          filter_size: Convolution kernel size. The padding is derived from
              it as ``(filter_size - 1) // 2``.
          stride: Convolution stride.
          groups: Convolution group count.
          is_vd_mode: Whether ``forward`` average-pools the input first.
          act: Activation name forwarded to ``BatchNorm`` (``None`` for no
              activation).
          name: Prefix used to build explicit parameter names
              (``<name>_weights``, ``<name>_bn_scale``, ...). When ``None``,
              no explicit names are passed, so construction no longer fails
              with ``TypeError`` on ``None + str``.
      """

      def __init__(self,
                   num_channels,
                   num_filters,
                   filter_size,
                   stride=1,
                   groups=1,
                   is_vd_mode=False,
                   act=None,
                   name=None):
          super(ConvBNLayer, self).__init__()

          def _named(suffix):
              # Only build an explicit parameter name when a prefix was given;
              # the default ``name=None`` must not be concatenated with a str.
              return None if name is None else name + suffix

          self.is_vd_mode = is_vd_mode
          # Created unconditionally (as before) so the set of registered
          # sublayers does not depend on ``is_vd_mode``.
          self._pool2d_avg = AvgPool2D(
              kernel_size=2, stride=2, padding=0, ceil_mode=True)
          self._conv = Conv2D(
              in_channels=num_channels,
              out_channels=num_filters,
              kernel_size=filter_size,
              stride=stride,
              padding=(filter_size - 1) // 2,
              groups=groups,
              weight_attr=ParamAttr(name=_named("_weights")),
              bias_attr=False)
          self._batch_norm = BatchNorm(
              num_filters,
              act=act,
              param_attr=ParamAttr(name=_named('_bn_scale')),
              bias_attr=ParamAttr(name=_named('_bn_offset')),
              moving_mean_name=_named('_bn_mean'),
              moving_variance_name=_named('_bn_variance'))

      def forward(self, inputs):
          """Apply (optional average pool) -> conv -> batch norm to ``inputs``."""
          if self.is_vd_mode:
              inputs = self._pool2d_avg(inputs)
          y = self._conv(inputs)
          y = self._batch_norm(y)
          return y
  class BottleneckBlock(nn.Layer):
      """Residual bottleneck: 1x1 conv -> grouped 3x3 conv -> 1x1 conv -> SE.

      The block's output width is ``num_filters * 2`` when ``cardinality`` is
      32 and ``num_filters`` otherwise. The result of the SE layer is added to
      a shortcut branch and passed through ReLU.

      Args:
          num_channels: Channel count of the block input.
          num_filters: Width of the first two convolutions.
          stride: Stride of the grouped 3x3 convolution.
          cardinality: Group count of the 3x3 convolution.
          reduction_ratio: Reduction ratio forwarded to ``SELayer``.
          shortcut: If true the input itself is used as the shortcut branch;
              otherwise a 1x1 ``ConvBNLayer`` projection is created and used.
          if_first: When true, the projection is built with
              ``is_vd_mode=False``; otherwise with ``is_vd_mode=True``.
          name: Suffix used to derive the sub-layer parameter names
              (``conv<name>_x1`` ... ``fc<name>``).
      """

      def __init__(self,
                   num_channels,
                   num_filters,
                   stride,
                   cardinality,
                   reduction_ratio,
                   shortcut=True,
                   if_first=False,
                   name=None):
          super(BottleneckBlock, self).__init__()

          conv_prefix = 'conv' + name
          # Width produced by the last 1x1 conv, the SE layer and the projection.
          out_channels = num_filters * 2 if cardinality == 32 else num_filters

          self.conv0 = ConvBNLayer(
              num_channels=num_channels,
              num_filters=num_filters,
              filter_size=1,
              act='relu',
              name=conv_prefix + '_x1')
          self.conv1 = ConvBNLayer(
              num_channels=num_filters,
              num_filters=num_filters,
              filter_size=3,
              groups=cardinality,
              stride=stride,
              act='relu',
              name=conv_prefix + '_x2')
          self.conv2 = ConvBNLayer(
              num_channels=num_filters,
              num_filters=out_channels,
              filter_size=1,
              act=None,
              name=conv_prefix + '_x3')
          self.scale = SELayer(
              num_channels=out_channels,
              num_filters=out_channels,
              reduction_ratio=reduction_ratio,
              name='fc' + name)

          if not shortcut:
              # Projection branch used when the input cannot be added as-is.
              self.short = ConvBNLayer(
                  num_channels=num_channels,
                  num_filters=out_channels,
                  filter_size=1,
                  stride=1,
                  is_vd_mode=not if_first,
                  name=conv_prefix + '_prj')

          self.shortcut = shortcut

      def forward(self, inputs):
          """Return ``relu(shortcut(inputs) + se(conv2(conv1(conv0(inputs)))))``."""
          residual = self.scale(self.conv2(self.conv1(self.conv0(inputs))))
          identity = inputs if self.shortcut else self.short(inputs)
          return F.relu(paddle.add(x=identity, y=residual))
  class SELayer(nn.Layer):
      """Channel gating layer: pool -> Linear -> ReLU -> Linear -> Sigmoid.

      ``forward`` pools the input to 1x1, runs it through the two linear
      layers, and multiplies the original input by the resulting sigmoid gate.

      Args:
          num_channels: Input width of the first (``squeeze``) linear layer.
          num_filters: Output width of the second (``excitation``) linear
              layer. NOTE(review): the multiply in ``forward`` presumably
              requires this to match the input's channel count - confirm.
          reduction_ratio: The hidden width is
              ``int(num_channels / reduction_ratio)``.
          name: Prefix used to build explicit parameter names
              (``<name>_sqz_weights``, ``<name>_exc_offset``, ...). When
              ``None``, no explicit names are passed, so construction no
              longer fails with ``TypeError`` on ``None + str``.
      """

      def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
          super(SELayer, self).__init__()

          def _named(suffix):
              # Only build an explicit parameter name when a prefix was given;
              # the default ``name=None`` must not be concatenated with a str.
              return None if name is None else name + suffix

          self.pool2d_gap = AdaptiveAvgPool2D(1)

          self._num_channels = num_channels

          med_ch = int(num_channels / reduction_ratio)
          # Uniform init bound is 1/sqrt(fan_in) of the layer being created.
          stdv = 1.0 / math.sqrt(num_channels * 1.0)
          self.squeeze = Linear(
              num_channels,
              med_ch,
              weight_attr=ParamAttr(
                  initializer=Uniform(-stdv, stdv),
                  name=_named("_sqz_weights")),
              bias_attr=ParamAttr(name=_named('_sqz_offset')))
          self.relu = nn.ReLU()
          stdv = 1.0 / math.sqrt(med_ch * 1.0)
          self.excitation = Linear(
              med_ch,
              num_filters,
              weight_attr=ParamAttr(
                  initializer=Uniform(-stdv, stdv),
                  name=_named("_exc_weights")),
              bias_attr=ParamAttr(name=_named('_exc_offset')))
          self.sigmoid = nn.Sigmoid()

      def forward(self, input):
          """Return ``input`` multiplied by its per-channel sigmoid gate.

          Assumes a 4-D input whose axes 2 and 3 are the pooled dimensions
          (they are squeezed after pooling and re-added before the multiply).
          """
          pool = self.pool2d_gap(input)
          # Drop the two 1-sized pooled axes so the Linear layers see 2-D data.
          pool = paddle.squeeze(pool, axis=[2, 3])
          squeeze = self.squeeze(pool)
          squeeze = self.relu(squeeze)
          excitation = self.excitation(squeeze)
          excitation = self.sigmoid(excitation)
          # Restore the two trailing axes so the gate broadcasts over them.
          excitation = paddle.unsqueeze(excitation, axis=[2, 3])
          out = paddle.multiply(input, excitation)
          return out
  class ResNeXt(nn.Layer):
      """SE-ResNeXt style classifier built from ``BottleneckBlock`` stages.

      Structure: three 3x3 ``ConvBNLayer`` stem convolutions, a max pool, four
      stages of bottleneck blocks, a global average pool and a ``Linear``
      classifier.

      Args:
          layers: Network depth; must be one of 50, 101 or 152.
          class_num: Output width of the final ``Linear`` layer.
          cardinality: Group count of the bottleneck 3x3 convs; 32 or 64.
      """

      def __init__(self, layers=50, class_num=1000, cardinality=32):
          super(ResNeXt, self).__init__()

          self.layers = layers
          self.cardinality = cardinality
          self.reduction_ratio = 16

          supported_layers = [50, 101, 152]
          assert layers in supported_layers, \
              "supported layers are {} but input layer is {}".format(
                  supported_layers, layers)
          supported_cardinality = [32, 64]
          assert cardinality in supported_cardinality, \
              "supported cardinality is {} but input cardinality is {}" \
              .format(supported_cardinality, cardinality)

          # Number of bottleneck blocks in each of the four stages.
          if layers == 50:
              depth = [3, 4, 6, 3]
          elif layers == 101:
              depth = [3, 4, 23, 3]
          elif layers == 152:
              depth = [3, 8, 36, 3]

          # Input width of the first block of each stage.
          num_channels = [128, 256, 512, 1024]
          # Bottleneck width of each stage.
          if cardinality == 32:
              num_filters = [128, 256, 512, 1024]
          else:
              num_filters = [256, 512, 1024, 2048]

          # Stem: three 3x3 convolutions followed by a max pool.
          self.conv1_1 = ConvBNLayer(
              num_channels=3,
              num_filters=64,
              filter_size=3,
              stride=2,
              act='relu',
              name="conv1_1")
          self.conv1_2 = ConvBNLayer(
              num_channels=64,
              num_filters=64,
              filter_size=3,
              stride=1,
              act='relu',
              name="conv1_2")
          self.conv1_3 = ConvBNLayer(
              num_channels=64,
              num_filters=128,
              filter_size=3,
              stride=1,
              act='relu',
              name="conv1_3")
          self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)

          self.block_list = []
          # Counter used only to build the parameter-name prefix of each stage.
          stage_tag = 1 if layers in (50, 101) else 3
          # Ratio between a block's output width and its bottleneck width.
          width_mult = int(64 // self.cardinality)
          for stage, block_count in enumerate(depth):
              stage_tag += 1
              for idx in range(block_count):
                  is_head = idx == 0
                  if is_head:
                      in_channels = num_channels[stage]
                  else:
                      in_channels = num_filters[stage] * width_mult
                  # Every stage except the first downsamples in its first block.
                  block_stride = 2 if is_head and stage != 0 else 1
                  bottleneck_block = self.add_sublayer(
                      'bb_%d_%d' % (stage, idx),
                      BottleneckBlock(
                          num_channels=in_channels,
                          num_filters=num_filters[stage],
                          stride=block_stride,
                          cardinality=self.cardinality,
                          reduction_ratio=self.reduction_ratio,
                          # Only the first block of a stage uses a projection.
                          shortcut=not is_head,
                          if_first=stage == 0,
                          name=str(stage_tag) + '_' + str(idx + 1)))
                  self.block_list.append(bottleneck_block)

          self.pool2d_avg = AdaptiveAvgPool2D(1)

          self.pool2d_avg_channels = num_channels[-1] * 2

          stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)

          self.out = Linear(
              self.pool2d_avg_channels,
              class_num,
              weight_attr=ParamAttr(
                  initializer=Uniform(-stdv, stdv), name="fc6_weights"),
              bias_attr=ParamAttr(name="fc6_offset"))

      def forward(self, inputs):
          """Run stem, bottleneck stages, global pooling and the classifier."""
          feat = self.pool2d_max(
              self.conv1_3(self.conv1_2(self.conv1_1(inputs))))
          for bottleneck in self.block_list:
              feat = bottleneck(feat)
          feat = self.pool2d_avg(feat)
          # Flatten the pooled features to (-1, pool2d_avg_channels).
          feat = paddle.reshape(feat, shape=[-1, self.pool2d_avg_channels])
          return self.out(feat)
  def _load_pretrained(pretrained, model, model_url, use_ssld=False):
      """Optionally load pretrained weights into ``model``.

      Args:
          pretrained: ``False`` to skip loading, ``True`` to load from
              ``model_url``, or a ``str`` that is handed to
              ``load_dygraph_pretrain`` as the weights location.
          model: The model instance to load weights into.
          model_url: URL used when ``pretrained`` is ``True``.
          use_ssld: Forwarded to ``load_dygraph_pretrain_from_url``.

      Raises:
          RuntimeError: If ``pretrained`` is neither a bool nor a str.
      """
      if pretrained is False:
          return
      if pretrained is True:
          load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
          return
      if isinstance(pretrained, str):
          load_dygraph_pretrain(model, pretrained)
          return
      raise RuntimeError(
          "pretrained type is not available. Please use `string` or `boolean` type."
      )
  def SE_ResNeXt50_vd_32x4d(pretrained=False, use_ssld=False, **kwargs):
      """Build ``ResNeXt(layers=50, cardinality=32)`` and optionally load weights.

      Args:
          pretrained: ``False``, ``True`` (load from the URL registered in
              ``MODEL_URLS``) or a ``str`` weights location; see
              ``_load_pretrained``.
          use_ssld: Forwarded to ``_load_pretrained``.
          **kwargs: Extra keyword arguments forwarded to ``ResNeXt``.

      Returns:
          The constructed ``ResNeXt`` instance.
      """
      net = ResNeXt(layers=50, cardinality=32, **kwargs)
      weights_url = MODEL_URLS["SE_ResNeXt50_vd_32x4d"]
      _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
      return net
  def SENet154_vd(pretrained=False, use_ssld=False, **kwargs):
      """Build ``ResNeXt(layers=152, cardinality=64)`` and optionally load weights.

      Args:
          pretrained: ``False``, ``True`` (load from the URL registered in
              ``MODEL_URLS``) or a ``str`` weights location; see
              ``_load_pretrained``.
          use_ssld: Forwarded to ``_load_pretrained``.
          **kwargs: Extra keyword arguments forwarded to ``ResNeXt``.

      Returns:
          The constructed ``ResNeXt`` instance.
      """
      net = ResNeXt(layers=152, cardinality=64, **kwargs)
      weights_url = MODEL_URLS["SENet154_vd"]
      _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
      return net