resnext.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from paddle import ParamAttr
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  23. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  24. from paddle.nn.initializer import Uniform
  25. import math
  26. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  27. MODEL_URLS = {
  28. "ResNeXt50_32x4d":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt50_32x4d_pretrained.pdparams",
  30. "ResNeXt50_64x4d":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt50_64x4d_pretrained.pdparams",
  32. "ResNeXt101_32x4d":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x4d_pretrained.pdparams",
  34. "ResNeXt101_64x4d":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_64x4d_pretrained.pdparams",
  36. "ResNeXt152_32x4d":
  37. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt152_32x4d_pretrained.pdparams",
  38. "ResNeXt152_64x4d":
  39. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt152_64x4d_pretrained.pdparams",
  40. }
  41. __all__ = list(MODEL_URLS.keys())
  42. class ConvBNLayer(nn.Layer):
  43. def __init__(self,
  44. num_channels,
  45. num_filters,
  46. filter_size,
  47. stride=1,
  48. groups=1,
  49. act=None,
  50. name=None,
  51. data_format="NCHW"):
  52. super(ConvBNLayer, self).__init__()
  53. self._conv = Conv2D(
  54. in_channels=num_channels,
  55. out_channels=num_filters,
  56. kernel_size=filter_size,
  57. stride=stride,
  58. padding=(filter_size - 1) // 2,
  59. groups=groups,
  60. weight_attr=ParamAttr(name=name + "_weights"),
  61. bias_attr=False,
  62. data_format=data_format)
  63. if name == "conv1":
  64. bn_name = "bn_" + name
  65. else:
  66. bn_name = "bn" + name[3:]
  67. self._batch_norm = BatchNorm(
  68. num_filters,
  69. act=act,
  70. param_attr=ParamAttr(name=bn_name + '_scale'),
  71. bias_attr=ParamAttr(bn_name + '_offset'),
  72. moving_mean_name=bn_name + '_mean',
  73. moving_variance_name=bn_name + '_variance',
  74. data_layout=data_format)
  75. def forward(self, inputs):
  76. y = self._conv(inputs)
  77. y = self._batch_norm(y)
  78. return y
  79. class BottleneckBlock(nn.Layer):
  80. def __init__(self,
  81. num_channels,
  82. num_filters,
  83. stride,
  84. cardinality,
  85. shortcut=True,
  86. name=None,
  87. data_format="NCHW"):
  88. super(BottleneckBlock, self).__init__()
  89. self.conv0 = ConvBNLayer(
  90. num_channels=num_channels,
  91. num_filters=num_filters,
  92. filter_size=1,
  93. act='relu',
  94. name=name + "_branch2a",
  95. data_format=data_format)
  96. self.conv1 = ConvBNLayer(
  97. num_channels=num_filters,
  98. num_filters=num_filters,
  99. filter_size=3,
  100. groups=cardinality,
  101. stride=stride,
  102. act='relu',
  103. name=name + "_branch2b",
  104. data_format=data_format)
  105. self.conv2 = ConvBNLayer(
  106. num_channels=num_filters,
  107. num_filters=num_filters * 2 if cardinality == 32 else num_filters,
  108. filter_size=1,
  109. act=None,
  110. name=name + "_branch2c",
  111. data_format=data_format)
  112. if not shortcut:
  113. self.short = ConvBNLayer(
  114. num_channels=num_channels,
  115. num_filters=num_filters * 2
  116. if cardinality == 32 else num_filters,
  117. filter_size=1,
  118. stride=stride,
  119. name=name + "_branch1",
  120. data_format=data_format)
  121. self.shortcut = shortcut
  122. def forward(self, inputs):
  123. y = self.conv0(inputs)
  124. conv1 = self.conv1(y)
  125. conv2 = self.conv2(conv1)
  126. if self.shortcut:
  127. short = inputs
  128. else:
  129. short = self.short(inputs)
  130. y = paddle.add(x=short, y=conv2)
  131. y = F.relu(y)
  132. return y
  133. class ResNeXt(nn.Layer):
  134. def __init__(self,
  135. layers=50,
  136. class_num=1000,
  137. cardinality=32,
  138. input_image_channel=3,
  139. data_format="NCHW"):
  140. super(ResNeXt, self).__init__()
  141. self.layers = layers
  142. self.data_format = data_format
  143. self.input_image_channel = input_image_channel
  144. self.cardinality = cardinality
  145. supported_layers = [50, 101, 152]
  146. assert layers in supported_layers, \
  147. "supported layers are {} but input layer is {}".format(
  148. supported_layers, layers)
  149. supported_cardinality = [32, 64]
  150. assert cardinality in supported_cardinality, \
  151. "supported cardinality is {} but input cardinality is {}" \
  152. .format(supported_cardinality, cardinality)
  153. if layers == 50:
  154. depth = [3, 4, 6, 3]
  155. elif layers == 101:
  156. depth = [3, 4, 23, 3]
  157. elif layers == 152:
  158. depth = [3, 8, 36, 3]
  159. num_channels = [64, 256, 512, 1024]
  160. num_filters = [128, 256, 512,
  161. 1024] if cardinality == 32 else [256, 512, 1024, 2048]
  162. self.conv = ConvBNLayer(
  163. num_channels=self.input_image_channel,
  164. num_filters=64,
  165. filter_size=7,
  166. stride=2,
  167. act='relu',
  168. name="res_conv1",
  169. data_format=self.data_format)
  170. self.pool2d_max = MaxPool2D(
  171. kernel_size=3, stride=2, padding=1, data_format=self.data_format)
  172. self.block_list = []
  173. for block in range(len(depth)):
  174. shortcut = False
  175. for i in range(depth[block]):
  176. if layers in [101, 152] and block == 2:
  177. if i == 0:
  178. conv_name = "res" + str(block + 2) + "a"
  179. else:
  180. conv_name = "res" + str(block + 2) + "b" + str(i)
  181. else:
  182. conv_name = "res" + str(block + 2) + chr(97 + i)
  183. bottleneck_block = self.add_sublayer(
  184. 'bb_%d_%d' % (block, i),
  185. BottleneckBlock(
  186. num_channels=num_channels[block] if i == 0 else
  187. num_filters[block] * int(64 // self.cardinality),
  188. num_filters=num_filters[block],
  189. stride=2 if i == 0 and block != 0 else 1,
  190. cardinality=self.cardinality,
  191. shortcut=shortcut,
  192. name=conv_name,
  193. data_format=self.data_format))
  194. self.block_list.append(bottleneck_block)
  195. shortcut = True
  196. self.pool2d_avg = AdaptiveAvgPool2D(1, data_format=self.data_format)
  197. self.pool2d_avg_channels = num_channels[-1] * 2
  198. stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
  199. self.out = Linear(
  200. self.pool2d_avg_channels,
  201. class_num,
  202. weight_attr=ParamAttr(
  203. initializer=Uniform(-stdv, stdv), name="fc_weights"),
  204. bias_attr=ParamAttr(name="fc_offset"))
  205. def forward(self, inputs):
  206. with paddle.static.amp.fp16_guard():
  207. if self.data_format == "NHWC":
  208. inputs = paddle.tensor.transpose(inputs, [0, 2, 3, 1])
  209. inputs.stop_gradient = True
  210. y = self.conv(inputs)
  211. y = self.pool2d_max(y)
  212. for block in self.block_list:
  213. y = block(y)
  214. y = self.pool2d_avg(y)
  215. y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
  216. y = self.out(y)
  217. return y
  218. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  219. if pretrained is False:
  220. pass
  221. elif pretrained is True:
  222. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  223. elif isinstance(pretrained, str):
  224. load_dygraph_pretrain(model, pretrained)
  225. else:
  226. raise RuntimeError(
  227. "pretrained type is not available. Please use `string` or `boolean` type."
  228. )
  229. def ResNeXt50_32x4d(pretrained=False, use_ssld=False, **kwargs):
  230. model = ResNeXt(layers=50, cardinality=32, **kwargs)
  231. _load_pretrained(
  232. pretrained, model, MODEL_URLS["ResNeXt50_32x4d"], use_ssld=use_ssld)
  233. return model
  234. def ResNeXt50_64x4d(pretrained=False, use_ssld=False, **kwargs):
  235. model = ResNeXt(layers=50, cardinality=64, **kwargs)
  236. _load_pretrained(
  237. pretrained, model, MODEL_URLS["ResNeXt50_64x4d"], use_ssld=use_ssld)
  238. return model
  239. def ResNeXt101_32x4d(pretrained=False, use_ssld=False, **kwargs):
  240. model = ResNeXt(layers=101, cardinality=32, **kwargs)
  241. _load_pretrained(
  242. pretrained, model, MODEL_URLS["ResNeXt101_32x4d"], use_ssld=use_ssld)
  243. return model
  244. def ResNeXt101_64x4d(pretrained=False, use_ssld=False, **kwargs):
  245. model = ResNeXt(layers=101, cardinality=64, **kwargs)
  246. _load_pretrained(
  247. pretrained, model, MODEL_URLS["ResNeXt101_64x4d"], use_ssld=use_ssld)
  248. return model
  249. def ResNeXt152_32x4d(pretrained=False, use_ssld=False, **kwargs):
  250. model = ResNeXt(layers=152, cardinality=32, **kwargs)
  251. _load_pretrained(
  252. pretrained, model, MODEL_URLS["ResNeXt152_32x4d"], use_ssld=use_ssld)
  253. return model
  254. def ResNeXt152_64x4d(pretrained=False, use_ssld=False, **kwargs):
  255. model = ResNeXt(layers=152, cardinality=64, **kwargs)
  256. _load_pretrained(
  257. pretrained, model, MODEL_URLS["ResNeXt152_64x4d"], use_ssld=use_ssld)
  258. return model