cspnet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was heavily based on https://github.com/rwightman/pytorch-image-models
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddle import ParamAttr
  19. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  20. MODEL_URLS = {
  21. "CSPDarkNet53":
  22. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/CSPDarkNet53_pretrained.pdparams"
  23. }
  24. MODEL_CFGS = {
  25. "CSPDarkNet53": dict(
  26. stem=dict(
  27. out_chs=32, kernel_size=3, stride=1, pool=''),
  28. stage=dict(
  29. out_chs=(64, 128, 256, 512, 1024),
  30. depth=(1, 2, 8, 8, 4),
  31. stride=(2, ) * 5,
  32. exp_ratio=(2., ) + (1., ) * 4,
  33. bottle_ratio=(0.5, ) + (1.0, ) * 4,
  34. block_ratio=(1., ) + (0.5, ) * 4,
  35. down_growth=True, ))
  36. }
  37. __all__ = ['CSPDarkNet53'
  38. ] # model_registry will add each entrypoint fn to this
  39. class ConvBnAct(nn.Layer):
  40. def __init__(self,
  41. input_channels,
  42. output_channels,
  43. kernel_size=1,
  44. stride=1,
  45. padding=None,
  46. dilation=1,
  47. groups=1,
  48. act_layer=nn.LeakyReLU,
  49. norm_layer=nn.BatchNorm2D):
  50. super().__init__()
  51. if padding is None:
  52. padding = (kernel_size - 1) // 2
  53. self.conv = nn.Conv2D(
  54. in_channels=input_channels,
  55. out_channels=output_channels,
  56. kernel_size=kernel_size,
  57. stride=stride,
  58. padding=padding,
  59. dilation=dilation,
  60. groups=groups,
  61. weight_attr=ParamAttr(),
  62. bias_attr=False)
  63. self.bn = norm_layer(num_features=output_channels)
  64. self.act = act_layer()
  65. def forward(self, inputs):
  66. x = self.conv(inputs)
  67. x = self.bn(x)
  68. if self.act is not None:
  69. x = self.act(x)
  70. return x
  71. def create_stem(in_chans=3,
  72. out_chs=32,
  73. kernel_size=3,
  74. stride=2,
  75. pool='',
  76. act_layer=None,
  77. norm_layer=None):
  78. stem = nn.Sequential()
  79. if not isinstance(out_chs, (tuple, list)):
  80. out_chs = [out_chs]
  81. assert len(out_chs)
  82. in_c = in_chans
  83. for i, out_c in enumerate(out_chs):
  84. conv_name = f'conv{i + 1}'
  85. stem.add_sublayer(
  86. conv_name,
  87. ConvBnAct(
  88. in_c,
  89. out_c,
  90. kernel_size,
  91. stride=stride if i == 0 else 1,
  92. act_layer=act_layer,
  93. norm_layer=norm_layer))
  94. in_c = out_c
  95. last_conv = conv_name
  96. if pool:
  97. stem.add_sublayer(
  98. 'pool', nn.MaxPool2D(
  99. kernel_size=3, stride=2, padding=1))
  100. return stem, dict(
  101. num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))
  102. class DarkBlock(nn.Layer):
  103. def __init__(self,
  104. in_chs,
  105. out_chs,
  106. dilation=1,
  107. bottle_ratio=0.5,
  108. groups=1,
  109. act_layer=nn.ReLU,
  110. norm_layer=nn.BatchNorm2D,
  111. attn_layer=None,
  112. drop_block=None):
  113. super(DarkBlock, self).__init__()
  114. mid_chs = int(round(out_chs * bottle_ratio))
  115. ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
  116. self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
  117. self.conv2 = ConvBnAct(
  118. mid_chs,
  119. out_chs,
  120. kernel_size=3,
  121. dilation=dilation,
  122. groups=groups,
  123. **ckwargs)
  124. def forward(self, x):
  125. shortcut = x
  126. x = self.conv1(x)
  127. x = self.conv2(x)
  128. x = x + shortcut
  129. return x
  130. class CrossStage(nn.Layer):
  131. def __init__(self,
  132. in_chs,
  133. out_chs,
  134. stride,
  135. dilation,
  136. depth,
  137. block_ratio=1.,
  138. bottle_ratio=1.,
  139. exp_ratio=1.,
  140. groups=1,
  141. first_dilation=None,
  142. down_growth=False,
  143. cross_linear=False,
  144. block_dpr=None,
  145. block_fn=DarkBlock,
  146. **block_kwargs):
  147. super(CrossStage, self).__init__()
  148. first_dilation = first_dilation or dilation
  149. down_chs = out_chs if down_growth else in_chs
  150. exp_chs = int(round(out_chs * exp_ratio))
  151. block_out_chs = int(round(out_chs * block_ratio))
  152. conv_kwargs = dict(
  153. act_layer=block_kwargs.get('act_layer'),
  154. norm_layer=block_kwargs.get('norm_layer'))
  155. if stride != 1 or first_dilation != dilation:
  156. self.conv_down = ConvBnAct(
  157. in_chs,
  158. down_chs,
  159. kernel_size=3,
  160. stride=stride,
  161. dilation=first_dilation,
  162. groups=groups,
  163. **conv_kwargs)
  164. prev_chs = down_chs
  165. else:
  166. self.conv_down = None
  167. prev_chs = in_chs
  168. self.conv_exp = ConvBnAct(
  169. prev_chs, exp_chs, kernel_size=1, **conv_kwargs)
  170. prev_chs = exp_chs // 2 # output of conv_exp is always split in two
  171. self.blocks = nn.Sequential()
  172. for i in range(depth):
  173. self.blocks.add_sublayer(
  174. str(i),
  175. block_fn(prev_chs, block_out_chs, dilation, bottle_ratio,
  176. groups, **block_kwargs))
  177. prev_chs = block_out_chs
  178. # transition convs
  179. self.conv_transition_b = ConvBnAct(
  180. prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
  181. self.conv_transition = ConvBnAct(
  182. exp_chs, out_chs, kernel_size=1, **conv_kwargs)
  183. def forward(self, x):
  184. if self.conv_down is not None:
  185. x = self.conv_down(x)
  186. x = self.conv_exp(x)
  187. split = x.shape[1] // 2
  188. xs, xb = x[:, :split], x[:, split:]
  189. xb = self.blocks(xb)
  190. xb = self.conv_transition_b(xb)
  191. out = self.conv_transition(paddle.concat([xs, xb], axis=1))
  192. return out
  193. class DarkStage(nn.Layer):
  194. def __init__(self,
  195. in_chs,
  196. out_chs,
  197. stride,
  198. dilation,
  199. depth,
  200. block_ratio=1.,
  201. bottle_ratio=1.,
  202. groups=1,
  203. first_dilation=None,
  204. block_fn=DarkBlock,
  205. block_dpr=None,
  206. **block_kwargs):
  207. super().__init__()
  208. first_dilation = first_dilation or dilation
  209. self.conv_down = ConvBnAct(
  210. in_chs,
  211. out_chs,
  212. kernel_size=3,
  213. stride=stride,
  214. dilation=first_dilation,
  215. groups=groups,
  216. act_layer=block_kwargs.get('act_layer'),
  217. norm_layer=block_kwargs.get('norm_layer'))
  218. prev_chs = out_chs
  219. block_out_chs = int(round(out_chs * block_ratio))
  220. self.blocks = nn.Sequential()
  221. for i in range(depth):
  222. self.blocks.add_sublayer(
  223. str(i),
  224. block_fn(prev_chs, block_out_chs, dilation, bottle_ratio,
  225. groups, **block_kwargs))
  226. prev_chs = block_out_chs
  227. def forward(self, x):
  228. x = self.conv_down(x)
  229. x = self.blocks(x)
  230. return x
  231. def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32):
  232. # get per stage args for stage and containing blocks, calculate strides to meet target output_stride
  233. num_stages = len(cfg['depth'])
  234. if 'groups' not in cfg:
  235. cfg['groups'] = (1, ) * num_stages
  236. if 'down_growth' in cfg and not isinstance(cfg['down_growth'],
  237. (list, tuple)):
  238. cfg['down_growth'] = (cfg['down_growth'], ) * num_stages
  239. stage_strides = []
  240. stage_dilations = []
  241. stage_first_dilations = []
  242. dilation = 1
  243. for cfg_stride in cfg['stride']:
  244. stage_first_dilations.append(dilation)
  245. if curr_stride >= output_stride:
  246. dilation *= cfg_stride
  247. stride = 1
  248. else:
  249. stride = cfg_stride
  250. curr_stride *= stride
  251. stage_strides.append(stride)
  252. stage_dilations.append(dilation)
  253. cfg['stride'] = stage_strides
  254. cfg['dilation'] = stage_dilations
  255. cfg['first_dilation'] = stage_first_dilations
  256. stage_args = [
  257. dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())
  258. ]
  259. return stage_args
  260. class CSPNet(nn.Layer):
  261. def __init__(self,
  262. cfg,
  263. in_chans=3,
  264. class_num=1000,
  265. output_stride=32,
  266. global_pool='avg',
  267. drop_rate=0.,
  268. act_layer=nn.LeakyReLU,
  269. norm_layer=nn.BatchNorm2D,
  270. zero_init_last_bn=True,
  271. stage_fn=CrossStage,
  272. block_fn=DarkBlock):
  273. super().__init__()
  274. self.class_num = class_num
  275. self.drop_rate = drop_rate
  276. assert output_stride in (8, 16, 32)
  277. layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
  278. # Construct the stem
  279. self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'],
  280. **layer_args)
  281. self.feature_info = [stem_feat_info]
  282. prev_chs = stem_feat_info['num_chs']
  283. curr_stride = stem_feat_info[
  284. 'reduction'] # reduction does not include pool
  285. if cfg['stem']['pool']:
  286. curr_stride *= 2
  287. # Construct the stages
  288. per_stage_args = _cfg_to_stage_args(
  289. cfg['stage'], curr_stride=curr_stride, output_stride=output_stride)
  290. self.stages = nn.LayerList()
  291. for i, sa in enumerate(per_stage_args):
  292. self.stages.add_sublayer(
  293. str(i),
  294. stage_fn(
  295. prev_chs, **sa, **layer_args, block_fn=block_fn))
  296. prev_chs = sa['out_chs']
  297. curr_stride *= sa['stride']
  298. self.feature_info += [
  299. dict(
  300. num_chs=prev_chs,
  301. reduction=curr_stride,
  302. module=f'stages.{i}')
  303. ]
  304. # Construct the head
  305. self.num_features = prev_chs
  306. self.pool = nn.AdaptiveAvgPool2D(1)
  307. self.flatten = nn.Flatten(1)
  308. self.fc = nn.Linear(
  309. prev_chs,
  310. class_num,
  311. weight_attr=ParamAttr(),
  312. bias_attr=ParamAttr())
  313. def forward(self, x):
  314. x = self.stem(x)
  315. for stage in self.stages:
  316. x = stage(x)
  317. x = self.pool(x)
  318. x = self.flatten(x)
  319. x = self.fc(x)
  320. return x
  321. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  322. if pretrained is False:
  323. pass
  324. elif pretrained is True:
  325. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  326. elif isinstance(pretrained, str):
  327. load_dygraph_pretrain(model, pretrained)
  328. else:
  329. raise RuntimeError(
  330. "pretrained type is not available. Please use `string` or `boolean` type."
  331. )
  332. def CSPDarkNet53(pretrained=False, use_ssld=False, **kwargs):
  333. model = CSPNet(MODEL_CFGS["CSPDarkNet53"], block_fn=DarkBlock, **kwargs)
  334. _load_pretrained(
  335. pretrained, model, MODEL_URLS["CSPDarkNet53"], use_ssld=use_ssld)
  336. return model