dla.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/ucbdrive/dla
  15. import math
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from paddle.nn.initializer import Normal, Constant
  20. from paddlex.ppcls.arch.backbone.base.theseus_layer import Identity
  21. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  22. MODEL_URLS = {
  23. "DLA34":
  24. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA34_pretrained.pdparams",
  25. "DLA46_c":
  26. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA46_c_pretrained.pdparams",
  27. "DLA46x_c":
  28. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA46x_c_pretrained.pdparams",
  29. "DLA60":
  30. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA60_pretrained.pdparams",
  31. "DLA60x":
  32. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA60x_pretrained.pdparams",
  33. "DLA60x_c":
  34. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA60x_c_pretrained.pdparams",
  35. "DLA102":
  36. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA102_pretrained.pdparams",
  37. "DLA102x":
  38. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA102x_pretrained.pdparams",
  39. "DLA102x2":
  40. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA102x2_pretrained.pdparams",
  41. "DLA169":
  42. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DLA169_pretrained.pdparams"
  43. }
  44. __all__ = MODEL_URLS.keys()
  45. zeros_ = Constant(value=0.)
  46. ones_ = Constant(value=1.)
  47. class DlaBasic(nn.Layer):
  48. def __init__(self, inplanes, planes, stride=1, dilation=1, **cargs):
  49. super(DlaBasic, self).__init__()
  50. self.conv1 = nn.Conv2D(
  51. inplanes,
  52. planes,
  53. kernel_size=3,
  54. stride=stride,
  55. padding=dilation,
  56. bias_attr=False,
  57. dilation=dilation)
  58. self.bn1 = nn.BatchNorm2D(planes)
  59. self.relu = nn.ReLU()
  60. self.conv2 = nn.Conv2D(
  61. planes,
  62. planes,
  63. kernel_size=3,
  64. stride=1,
  65. padding=dilation,
  66. bias_attr=False,
  67. dilation=dilation)
  68. self.bn2 = nn.BatchNorm2D(planes)
  69. self.stride = stride
  70. def forward(self, x, residual=None):
  71. if residual is None:
  72. residual = x
  73. out = self.conv1(x)
  74. out = self.bn1(out)
  75. out = self.relu(out)
  76. out = self.conv2(out)
  77. out = self.bn2(out)
  78. out += residual
  79. out = self.relu(out)
  80. return out
  81. class DlaBottleneck(nn.Layer):
  82. expansion = 2
  83. def __init__(self,
  84. inplanes,
  85. outplanes,
  86. stride=1,
  87. dilation=1,
  88. cardinality=1,
  89. base_width=64):
  90. super(DlaBottleneck, self).__init__()
  91. self.stride = stride
  92. mid_planes = int(
  93. math.floor(outplanes * (base_width / 64)) * cardinality)
  94. mid_planes = mid_planes // self.expansion
  95. self.conv1 = nn.Conv2D(
  96. inplanes, mid_planes, kernel_size=1, bias_attr=False)
  97. self.bn1 = nn.BatchNorm2D(mid_planes)
  98. self.conv2 = nn.Conv2D(
  99. mid_planes,
  100. mid_planes,
  101. kernel_size=3,
  102. stride=stride,
  103. padding=dilation,
  104. bias_attr=False,
  105. dilation=dilation,
  106. groups=cardinality)
  107. self.bn2 = nn.BatchNorm2D(mid_planes)
  108. self.conv3 = nn.Conv2D(
  109. mid_planes, outplanes, kernel_size=1, bias_attr=False)
  110. self.bn3 = nn.BatchNorm2D(outplanes)
  111. self.relu = nn.ReLU()
  112. def forward(self, x, residual=None):
  113. if residual is None:
  114. residual = x
  115. out = self.conv1(x)
  116. out = self.bn1(out)
  117. out = self.relu(out)
  118. out = self.conv2(out)
  119. out = self.bn2(out)
  120. out = self.relu(out)
  121. out = self.conv3(out)
  122. out = self.bn3(out)
  123. out += residual
  124. out = self.relu(out)
  125. return out
  126. class DlaRoot(nn.Layer):
  127. def __init__(self, in_channels, out_channels, kernel_size, residual):
  128. super(DlaRoot, self).__init__()
  129. self.conv = nn.Conv2D(
  130. in_channels,
  131. out_channels,
  132. 1,
  133. stride=1,
  134. bias_attr=False,
  135. padding=(kernel_size - 1) // 2)
  136. self.bn = nn.BatchNorm2D(out_channels)
  137. self.relu = nn.ReLU()
  138. self.residual = residual
  139. def forward(self, *x):
  140. children = x
  141. x = self.conv(paddle.concat(x, 1))
  142. x = self.bn(x)
  143. if self.residual:
  144. x += children[0]
  145. x = self.relu(x)
  146. return x
  147. class DlaTree(nn.Layer):
  148. def __init__(self,
  149. levels,
  150. block,
  151. in_channels,
  152. out_channels,
  153. stride=1,
  154. dilation=1,
  155. cardinality=1,
  156. base_width=64,
  157. level_root=False,
  158. root_dim=0,
  159. root_kernel_size=1,
  160. root_residual=False):
  161. super(DlaTree, self).__init__()
  162. if root_dim == 0:
  163. root_dim = 2 * out_channels
  164. if level_root:
  165. root_dim += in_channels
  166. self.downsample = nn.MaxPool2D(
  167. stride, stride=stride) if stride > 1 else Identity()
  168. self.project = Identity()
  169. cargs = dict(
  170. dilation=dilation, cardinality=cardinality, base_width=base_width)
  171. if levels == 1:
  172. self.tree1 = block(in_channels, out_channels, stride, **cargs)
  173. self.tree2 = block(out_channels, out_channels, 1, **cargs)
  174. if in_channels != out_channels:
  175. self.project = nn.Sequential(
  176. nn.Conv2D(
  177. in_channels,
  178. out_channels,
  179. kernel_size=1,
  180. stride=1,
  181. bias_attr=False),
  182. nn.BatchNorm2D(out_channels))
  183. else:
  184. cargs.update(
  185. dict(
  186. root_kernel_size=root_kernel_size,
  187. root_residual=root_residual))
  188. self.tree1 = DlaTree(
  189. levels - 1,
  190. block,
  191. in_channels,
  192. out_channels,
  193. stride,
  194. root_dim=0,
  195. **cargs)
  196. self.tree2 = DlaTree(
  197. levels - 1,
  198. block,
  199. out_channels,
  200. out_channels,
  201. root_dim=root_dim + out_channels,
  202. **cargs)
  203. if levels == 1:
  204. self.root = DlaRoot(root_dim, out_channels, root_kernel_size,
  205. root_residual)
  206. self.level_root = level_root
  207. self.root_dim = root_dim
  208. self.levels = levels
  209. def forward(self, x, residual=None, children=None):
  210. children = [] if children is None else children
  211. bottom = self.downsample(x)
  212. residual = self.project(bottom)
  213. if self.level_root:
  214. children.append(bottom)
  215. x1 = self.tree1(x, residual)
  216. if self.levels == 1:
  217. x2 = self.tree2(x1)
  218. x = self.root(x2, x1, *children)
  219. else:
  220. children.append(x1)
  221. x = self.tree2(x1, children=children)
  222. return x
  223. class DLA(nn.Layer):
  224. def __init__(self,
  225. levels,
  226. channels,
  227. in_chans=3,
  228. cardinality=1,
  229. base_width=64,
  230. block=DlaBottleneck,
  231. residual_root=False,
  232. drop_rate=0.0,
  233. class_num=1000,
  234. with_pool=True):
  235. super(DLA, self).__init__()
  236. self.channels = channels
  237. self.class_num = class_num
  238. self.with_pool = with_pool
  239. self.cardinality = cardinality
  240. self.base_width = base_width
  241. self.drop_rate = drop_rate
  242. self.base_layer = nn.Sequential(
  243. nn.Conv2D(
  244. in_chans,
  245. channels[0],
  246. kernel_size=7,
  247. stride=1,
  248. padding=3,
  249. bias_attr=False),
  250. nn.BatchNorm2D(channels[0]),
  251. nn.ReLU())
  252. self.level0 = self._make_conv_level(channels[0], channels[0],
  253. levels[0])
  254. self.level1 = self._make_conv_level(
  255. channels[0], channels[1], levels[1], stride=2)
  256. cargs = dict(
  257. cardinality=cardinality,
  258. base_width=base_width,
  259. root_residual=residual_root)
  260. self.level2 = DlaTree(
  261. levels[2],
  262. block,
  263. channels[1],
  264. channels[2],
  265. 2,
  266. level_root=False,
  267. **cargs)
  268. self.level3 = DlaTree(
  269. levels[3],
  270. block,
  271. channels[2],
  272. channels[3],
  273. 2,
  274. level_root=True,
  275. **cargs)
  276. self.level4 = DlaTree(
  277. levels[4],
  278. block,
  279. channels[3],
  280. channels[4],
  281. 2,
  282. level_root=True,
  283. **cargs)
  284. self.level5 = DlaTree(
  285. levels[5],
  286. block,
  287. channels[4],
  288. channels[5],
  289. 2,
  290. level_root=True,
  291. **cargs)
  292. self.feature_info = [
  293. # rare to have a meaningful stride 1 level
  294. dict(
  295. num_chs=channels[0], reduction=1, module='level0'),
  296. dict(
  297. num_chs=channels[1], reduction=2, module='level1'),
  298. dict(
  299. num_chs=channels[2], reduction=4, module='level2'),
  300. dict(
  301. num_chs=channels[3], reduction=8, module='level3'),
  302. dict(
  303. num_chs=channels[4], reduction=16, module='level4'),
  304. dict(
  305. num_chs=channels[5], reduction=32, module='level5'),
  306. ]
  307. self.num_features = channels[-1]
  308. if with_pool:
  309. self.global_pool = nn.AdaptiveAvgPool2D(1)
  310. if class_num > 0:
  311. self.fc = nn.Conv2D(self.num_features, class_num, 1)
  312. for m in self.sublayers():
  313. if isinstance(m, nn.Conv2D):
  314. n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  315. normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
  316. normal_(m.weight)
  317. elif isinstance(m, nn.BatchNorm2D):
  318. ones_(m.weight)
  319. zeros_(m.bias)
  320. def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
  321. modules = []
  322. for i in range(convs):
  323. modules.extend([
  324. nn.Conv2D(
  325. inplanes,
  326. planes,
  327. kernel_size=3,
  328. stride=stride if i == 0 else 1,
  329. padding=dilation,
  330. bias_attr=False,
  331. dilation=dilation), nn.BatchNorm2D(planes), nn.ReLU()
  332. ])
  333. inplanes = planes
  334. return nn.Sequential(*modules)
  335. def forward_features(self, x):
  336. x = self.base_layer(x)
  337. x = self.level0(x)
  338. x = self.level1(x)
  339. x = self.level2(x)
  340. x = self.level3(x)
  341. x = self.level4(x)
  342. x = self.level5(x)
  343. return x
  344. def forward(self, x):
  345. x = self.forward_features(x)
  346. if self.with_pool:
  347. x = self.global_pool(x)
  348. if self.drop_rate > 0.:
  349. x = F.dropout(x, p=self.drop_rate, training=self.training)
  350. if self.class_num > 0:
  351. x = self.fc(x)
  352. x = x.flatten(1)
  353. return x
  354. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  355. if pretrained is False:
  356. pass
  357. elif pretrained is True:
  358. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  359. elif isinstance(pretrained, str):
  360. load_dygraph_pretrain(model, pretrained)
  361. else:
  362. raise RuntimeError(
  363. "pretrained type is not available. Please use `string` or `boolean` type."
  364. )
  365. def DLA34(pretrained=False, **kwargs):
  366. model = DLA(levels=(1, 1, 1, 2, 2, 1),
  367. channels=(16, 32, 64, 128, 256, 512),
  368. block=DlaBasic,
  369. **kwargs)
  370. _load_pretrained(pretrained, model, MODEL_URLS["DLA34"])
  371. return model
  372. def DLA46_c(pretrained=False, **kwargs):
  373. model = DLA(levels=(1, 1, 1, 2, 2, 1),
  374. channels=(16, 32, 64, 64, 128, 256),
  375. block=DlaBottleneck,
  376. **kwargs)
  377. _load_pretrained(pretrained, model, MODEL_URLS["DLA46_c"])
  378. return model
  379. def DLA46x_c(pretrained=False, **kwargs):
  380. model = DLA(levels=(1, 1, 1, 2, 2, 1),
  381. channels=(16, 32, 64, 64, 128, 256),
  382. block=DlaBottleneck,
  383. cardinality=32,
  384. base_width=4,
  385. **kwargs)
  386. _load_pretrained(pretrained, model, MODEL_URLS["DLA46x_c"])
  387. return model
  388. def DLA60(pretrained=False, **kwargs):
  389. model = DLA(levels=(1, 1, 1, 2, 3, 1),
  390. channels=(16, 32, 128, 256, 512, 1024),
  391. block=DlaBottleneck,
  392. **kwargs)
  393. _load_pretrained(pretrained, model, MODEL_URLS["DLA60"])
  394. return model
  395. def DLA60x(pretrained=False, **kwargs):
  396. model = DLA(levels=(1, 1, 1, 2, 3, 1),
  397. channels=(16, 32, 128, 256, 512, 1024),
  398. block=DlaBottleneck,
  399. cardinality=32,
  400. base_width=4,
  401. **kwargs)
  402. _load_pretrained(pretrained, model, MODEL_URLS["DLA60x"])
  403. return model
  404. def DLA60x_c(pretrained=False, **kwargs):
  405. model = DLA(levels=(1, 1, 1, 2, 3, 1),
  406. channels=(16, 32, 64, 64, 128, 256),
  407. block=DlaBottleneck,
  408. cardinality=32,
  409. base_width=4,
  410. **kwargs)
  411. _load_pretrained(pretrained, model, MODEL_URLS["DLA60x_c"])
  412. return model
  413. def DLA102(pretrained=False, **kwargs):
  414. model = DLA(levels=(1, 1, 1, 3, 4, 1),
  415. channels=(16, 32, 128, 256, 512, 1024),
  416. block=DlaBottleneck,
  417. residual_root=True,
  418. **kwargs)
  419. _load_pretrained(pretrained, model, MODEL_URLS["DLA102"])
  420. return model
  421. def DLA102x(pretrained=False, **kwargs):
  422. model = DLA(levels=(1, 1, 1, 3, 4, 1),
  423. channels=(16, 32, 128, 256, 512, 1024),
  424. block=DlaBottleneck,
  425. cardinality=32,
  426. base_width=4,
  427. residual_root=True,
  428. **kwargs)
  429. _load_pretrained(pretrained, model, MODEL_URLS["DLA102x"])
  430. return model
  431. def DLA102x2(pretrained=False, **kwargs):
  432. model = DLA(levels=(1, 1, 1, 3, 4, 1),
  433. channels=(16, 32, 128, 256, 512, 1024),
  434. block=DlaBottleneck,
  435. cardinality=64,
  436. base_width=4,
  437. residual_root=True,
  438. **kwargs)
  439. _load_pretrained(pretrained, model, MODEL_URLS["DLA102x2"])
  440. return model
  441. def DLA169(pretrained=False, **kwargs):
  442. model = DLA(levels=(1, 1, 2, 3, 5, 1),
  443. channels=(16, 32, 128, 256, 512, 1024),
  444. block=DlaBottleneck,
  445. residual_root=True,
  446. **kwargs)
  447. _load_pretrained(pretrained, model, MODEL_URLS["DLA169"])
  448. return model