csp_darknet.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle import ParamAttr
  18. from paddle.regularizer import L2Decay
  19. from paddlex.ppdet.core.workspace import register, serializable
  20. from paddlex.ppdet.modeling.initializer import conv_init_
  21. from ..shape_spec import ShapeSpec
  22. __all__ = [
  23. 'CSPDarkNet', 'BaseConv', 'DWConv', 'BottleNeck', 'SPPLayer', 'SPPFLayer'
  24. ]
  25. class BaseConv(nn.Layer):
  26. def __init__(self,
  27. in_channels,
  28. out_channels,
  29. ksize,
  30. stride,
  31. groups=1,
  32. bias=False,
  33. act="silu"):
  34. super(BaseConv, self).__init__()
  35. self.conv = nn.Conv2D(
  36. in_channels,
  37. out_channels,
  38. kernel_size=ksize,
  39. stride=stride,
  40. padding=(ksize - 1) // 2,
  41. groups=groups,
  42. bias_attr=bias)
  43. self.bn = nn.BatchNorm2D(
  44. out_channels,
  45. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  46. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  47. self._init_weights()
  48. def _init_weights(self):
  49. conv_init_(self.conv)
  50. def forward(self, x):
  51. # use 'x * F.sigmoid(x)' replace 'silu'
  52. x = self.bn(self.conv(x))
  53. y = x * F.sigmoid(x)
  54. return y
  55. class DWConv(nn.Layer):
  56. """Depthwise Conv"""
  57. def __init__(self,
  58. in_channels,
  59. out_channels,
  60. ksize,
  61. stride=1,
  62. bias=False,
  63. act="silu"):
  64. super(DWConv, self).__init__()
  65. self.dw_conv = BaseConv(
  66. in_channels,
  67. in_channels,
  68. ksize=ksize,
  69. stride=stride,
  70. groups=in_channels,
  71. bias=bias,
  72. act=act)
  73. self.pw_conv = BaseConv(
  74. in_channels,
  75. out_channels,
  76. ksize=1,
  77. stride=1,
  78. groups=1,
  79. bias=bias,
  80. act=act)
  81. def forward(self, x):
  82. return self.pw_conv(self.dw_conv(x))
  83. class Focus(nn.Layer):
  84. """Focus width and height information into channel space, used in YOLOX."""
  85. def __init__(self,
  86. in_channels,
  87. out_channels,
  88. ksize=3,
  89. stride=1,
  90. bias=False,
  91. act="silu"):
  92. super(Focus, self).__init__()
  93. self.conv = BaseConv(
  94. in_channels * 4,
  95. out_channels,
  96. ksize=ksize,
  97. stride=stride,
  98. bias=bias,
  99. act=act)
  100. def forward(self, inputs):
  101. # inputs [bs, C, H, W] -> outputs [bs, 4C, W/2, H/2]
  102. top_left = inputs[:, :, 0::2, 0::2]
  103. top_right = inputs[:, :, 0::2, 1::2]
  104. bottom_left = inputs[:, :, 1::2, 0::2]
  105. bottom_right = inputs[:, :, 1::2, 1::2]
  106. outputs = paddle.concat(
  107. [top_left, bottom_left, top_right, bottom_right], 1)
  108. return self.conv(outputs)
  109. class BottleNeck(nn.Layer):
  110. def __init__(self,
  111. in_channels,
  112. out_channels,
  113. shortcut=True,
  114. expansion=0.5,
  115. depthwise=False,
  116. bias=False,
  117. act="silu"):
  118. super(BottleNeck, self).__init__()
  119. hidden_channels = int(out_channels * expansion)
  120. Conv = DWConv if depthwise else BaseConv
  121. self.conv1 = BaseConv(
  122. in_channels,
  123. hidden_channels,
  124. ksize=1,
  125. stride=1,
  126. bias=bias,
  127. act=act)
  128. self.conv2 = Conv(
  129. hidden_channels,
  130. out_channels,
  131. ksize=3,
  132. stride=1,
  133. bias=bias,
  134. act=act)
  135. self.add_shortcut = shortcut and in_channels == out_channels
  136. def forward(self, x):
  137. y = self.conv2(self.conv1(x))
  138. if self.add_shortcut:
  139. y = y + x
  140. return y
  141. class SPPLayer(nn.Layer):
  142. """Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX"""
  143. def __init__(self,
  144. in_channels,
  145. out_channels,
  146. kernel_sizes=(5, 9, 13),
  147. bias=False,
  148. act="silu"):
  149. super(SPPLayer, self).__init__()
  150. hidden_channels = in_channels // 2
  151. self.conv1 = BaseConv(
  152. in_channels,
  153. hidden_channels,
  154. ksize=1,
  155. stride=1,
  156. bias=bias,
  157. act=act)
  158. self.maxpoolings = nn.LayerList([
  159. nn.MaxPool2D(
  160. kernel_size=ks, stride=1, padding=ks // 2)
  161. for ks in kernel_sizes
  162. ])
  163. conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
  164. self.conv2 = BaseConv(
  165. conv2_channels,
  166. out_channels,
  167. ksize=1,
  168. stride=1,
  169. bias=bias,
  170. act=act)
  171. def forward(self, x):
  172. x = self.conv1(x)
  173. x = paddle.concat([x] + [mp(x) for mp in self.maxpoolings], axis=1)
  174. x = self.conv2(x)
  175. return x
  176. class SPPFLayer(nn.Layer):
  177. """ Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher,
  178. equivalent to SPP(k=(5, 9, 13))
  179. """
  180. def __init__(self,
  181. in_channels,
  182. out_channels,
  183. ksize=5,
  184. bias=False,
  185. act='silu'):
  186. super(SPPFLayer, self).__init__()
  187. hidden_channels = in_channels // 2
  188. self.conv1 = BaseConv(
  189. in_channels,
  190. hidden_channels,
  191. ksize=1,
  192. stride=1,
  193. bias=bias,
  194. act=act)
  195. self.maxpooling = nn.MaxPool2D(
  196. kernel_size=ksize, stride=1, padding=ksize // 2)
  197. conv2_channels = hidden_channels * 4
  198. self.conv2 = BaseConv(
  199. conv2_channels,
  200. out_channels,
  201. ksize=1,
  202. stride=1,
  203. bias=bias,
  204. act=act)
  205. def forward(self, x):
  206. x = self.conv1(x)
  207. y1 = self.maxpooling(x)
  208. y2 = self.maxpooling(y1)
  209. y3 = self.maxpooling(y2)
  210. concats = paddle.concat([x, y1, y2, y3], axis=1)
  211. out = self.conv2(concats)
  212. return out
  213. class CSPLayer(nn.Layer):
  214. """CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5"""
  215. def __init__(self,
  216. in_channels,
  217. out_channels,
  218. num_blocks=1,
  219. shortcut=True,
  220. expansion=0.5,
  221. depthwise=False,
  222. bias=False,
  223. act="silu"):
  224. super(CSPLayer, self).__init__()
  225. hidden_channels = int(out_channels * expansion)
  226. self.conv1 = BaseConv(
  227. in_channels,
  228. hidden_channels,
  229. ksize=1,
  230. stride=1,
  231. bias=bias,
  232. act=act)
  233. self.conv2 = BaseConv(
  234. in_channels,
  235. hidden_channels,
  236. ksize=1,
  237. stride=1,
  238. bias=bias,
  239. act=act)
  240. self.bottlenecks = nn.Sequential(* [
  241. BottleNeck(
  242. hidden_channels,
  243. hidden_channels,
  244. shortcut=shortcut,
  245. expansion=1.0,
  246. depthwise=depthwise,
  247. bias=bias,
  248. act=act) for _ in range(num_blocks)
  249. ])
  250. self.conv3 = BaseConv(
  251. hidden_channels * 2,
  252. out_channels,
  253. ksize=1,
  254. stride=1,
  255. bias=bias,
  256. act=act)
  257. def forward(self, x):
  258. x_1 = self.conv1(x)
  259. x_1 = self.bottlenecks(x_1)
  260. x_2 = self.conv2(x)
  261. x = paddle.concat([x_1, x_2], axis=1)
  262. x = self.conv3(x)
  263. return x
  264. @register
  265. @serializable
  266. class CSPDarkNet(nn.Layer):
  267. """
  268. CSPDarkNet backbone.
  269. Args:
  270. arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X,
  271. and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5.
  272. depth_mult (float): Depth multiplier, multiply number of channels in
  273. each layer, default as 1.0.
  274. width_mult (float): Width multiplier, multiply number of blocks in
  275. CSPLayer, default as 1.0.
  276. depthwise (bool): Whether to use depth-wise conv layer.
  277. act (str): Activation function type, default as 'silu'.
  278. return_idx (list): Index of stages whose feature maps are returned.
  279. """
  280. __shared__ = ['depth_mult', 'width_mult', 'act', 'trt']
  281. # in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
  282. # 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
  283. arch_settings = {
  284. 'X': [[64, 128, 3, True, False], [128, 256, 9, True, False],
  285. [256, 512, 9, True, False], [512, 1024, 3, False, True]],
  286. 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
  287. [256, 512, 9, True, False], [512, 1024, 3, True, True]],
  288. 'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
  289. [256, 512, 9, True, False], [512, 768, 3, True, False],
  290. [768, 1024, 3, True, True]],
  291. }
  292. def __init__(self,
  293. arch='X',
  294. depth_mult=1.0,
  295. width_mult=1.0,
  296. depthwise=False,
  297. act='silu',
  298. trt=False,
  299. return_idx=[2, 3, 4]):
  300. super(CSPDarkNet, self).__init__()
  301. self.arch = arch
  302. self.return_idx = return_idx
  303. Conv = DWConv if depthwise else BaseConv
  304. arch_setting = self.arch_settings[arch]
  305. base_channels = int(arch_setting[0][0] * width_mult)
  306. # Note: differences between the latest YOLOv5 and the original YOLOX
  307. # 1. self.stem, use SPPF(in YOLOv5) or SPP(in YOLOX)
  308. # 2. use SPPF(in YOLOv5) or SPP(in YOLOX)
  309. # 3. put SPPF before(YOLOv5) or SPP after(YOLOX) the last cspdark block's CSPLayer
  310. # 4. whether SPPF(SPP)'CSPLayer add shortcut, True in YOLOv5, False in YOLOX
  311. if arch in ['P5', 'P6']:
  312. # in the latest YOLOv5, use Conv stem, and SPPF (fast, only single spp kernal size)
  313. self.stem = Conv(
  314. 3, base_channels, ksize=6, stride=2, bias=False, act=act)
  315. spp_kernal_sizes = 5
  316. elif arch in ['X']:
  317. # in the original YOLOX, use Focus stem, and SPP (three spp kernal sizes)
  318. self.stem = Focus(
  319. 3, base_channels, ksize=3, stride=1, bias=False, act=act)
  320. spp_kernal_sizes = (5, 9, 13)
  321. else:
  322. raise AttributeError("Unsupported arch type: {}".format(arch))
  323. _out_channels = [base_channels]
  324. layers_num = 1
  325. self.csp_dark_blocks = []
  326. for i, (in_channels, out_channels, num_blocks, shortcut,
  327. use_spp) in enumerate(arch_setting):
  328. in_channels = int(in_channels * width_mult)
  329. out_channels = int(out_channels * width_mult)
  330. _out_channels.append(out_channels)
  331. num_blocks = max(round(num_blocks * depth_mult), 1)
  332. stage = []
  333. conv_layer = self.add_sublayer(
  334. 'layers{}.stage{}.conv_layer'.format(layers_num, i + 1),
  335. Conv(
  336. in_channels, out_channels, 3, 2, bias=False, act=act))
  337. stage.append(conv_layer)
  338. layers_num += 1
  339. if use_spp and arch in ['X']:
  340. # in YOLOX use SPPLayer
  341. spp_layer = self.add_sublayer(
  342. 'layers{}.stage{}.spp_layer'.format(layers_num, i + 1),
  343. SPPLayer(
  344. out_channels,
  345. out_channels,
  346. kernel_sizes=spp_kernal_sizes,
  347. bias=False,
  348. act=act))
  349. stage.append(spp_layer)
  350. layers_num += 1
  351. csp_layer = self.add_sublayer(
  352. 'layers{}.stage{}.csp_layer'.format(layers_num, i + 1),
  353. CSPLayer(
  354. out_channels,
  355. out_channels,
  356. num_blocks=num_blocks,
  357. shortcut=shortcut,
  358. depthwise=depthwise,
  359. bias=False,
  360. act=act))
  361. stage.append(csp_layer)
  362. layers_num += 1
  363. if use_spp and arch in ['P5', 'P6']:
  364. # in latest YOLOv5 use SPPFLayer instead of SPPLayer
  365. sppf_layer = self.add_sublayer(
  366. 'layers{}.stage{}.sppf_layer'.format(layers_num, i + 1),
  367. SPPFLayer(
  368. out_channels,
  369. out_channels,
  370. ksize=5,
  371. bias=False,
  372. act=act))
  373. stage.append(sppf_layer)
  374. layers_num += 1
  375. self.csp_dark_blocks.append(nn.Sequential(*stage))
  376. self._out_channels = [_out_channels[i] for i in self.return_idx]
  377. self.strides = [[2, 4, 8, 16, 32, 64][i] for i in self.return_idx]
  378. def forward(self, inputs):
  379. x = inputs['image']
  380. outputs = []
  381. x = self.stem(x)
  382. for i, layer in enumerate(self.csp_dark_blocks):
  383. x = layer(x)
  384. if i + 1 in self.return_idx:
  385. outputs.append(x)
  386. return outputs
  387. @property
  388. def out_shape(self):
  389. return [
  390. ShapeSpec(
  391. channels=c, stride=s)
  392. for c, s in zip(self._out_channels, self.strides)
  393. ]