# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlex.paddleseg.cvlibs import manager, param_init
from paddlex.paddleseg.models import layers
from paddlex.paddleseg.utils import utils

__all__ = [
    "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
    "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
]


class HRNet(nn.Layer):
    """
    The HRNet implementation based on PaddlePaddle.

    The original article refers to
    Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
    (https://arxiv.org/pdf/1908.07919.pdf).

    Args:
        pretrained (str, optional): The path of pretrained model.
        stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
        stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4,).
        stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64,).
        stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
        stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
        stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
        stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
        stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
        stage3_num_channels (list, optional): Number of channels per branch for stage3. Default (18, 36, 72).
        stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
        stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
        stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72, 144).
        has_se (bool, optional): Whether to use Squeeze-and-Excitation module. Default False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False
            when the feature size is even, e.g. 1024x512; otherwise it should be True, e.g. 769x769.
            Default: False.
    """

    def __init__(self,
                 pretrained=None,
                 stage1_num_modules=1,
                 stage1_num_blocks=(4, ),
                 stage1_num_channels=(64, ),
                 stage2_num_modules=1,
                 stage2_num_blocks=(4, 4),
                 stage2_num_channels=(18, 36),
                 stage3_num_modules=4,
                 stage3_num_blocks=(4, 4, 4),
                 stage3_num_channels=(18, 36, 72),
                 stage4_num_modules=3,
                 stage4_num_blocks=(4, 4, 4, 4),
                 stage4_num_channels=(18, 36, 72, 144),
                 has_se=False,
                 align_corners=False):
        super(HRNet, self).__init__()
        self.pretrained = pretrained
        self.stage1_num_modules = stage1_num_modules
        self.stage1_num_blocks = stage1_num_blocks
        self.stage1_num_channels = stage1_num_channels
        self.stage2_num_modules = stage2_num_modules
        self.stage2_num_blocks = stage2_num_blocks
        self.stage2_num_channels = stage2_num_channels
        self.stage3_num_modules = stage3_num_modules
        self.stage3_num_blocks = stage3_num_blocks
        self.stage3_num_channels = stage3_num_channels
        self.stage4_num_modules = stage4_num_modules
        self.stage4_num_blocks = stage4_num_blocks
        self.stage4_num_channels = stage4_num_channels
        self.has_se = has_se
        self.align_corners = align_corners
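
        # The backbone returns a single fused feature map whose channel count
        # is the sum of the stage-4 branch widths
        # (e.g. 18 + 36 + 72 + 144 = 270 for HRNet_W18).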
        self.feat_channels = [sum(stage4_num_channels)]

        self.conv_layer1_1 = layers.ConvBNReLU(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding='same',
            bias_attr=False)

        self.conv_layer1_2 = layers.ConvBNReLU(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding='same',
            bias_attr=False)

        self.la1 = Layer1(
            num_channels=64,
            num_blocks=self.stage1_num_blocks[0],
            num_filters=self.stage1_num_channels[0],
            has_se=has_se,
            name="layer2")

        self.tr1 = TransitionLayer(
            in_channels=[self.stage1_num_channels[0] * 4],
            out_channels=self.stage2_num_channels,
            name="tr1")

        self.st2 = Stage(
            num_channels=self.stage2_num_channels,
            num_modules=self.stage2_num_modules,
            num_blocks=self.stage2_num_blocks,
            num_filters=self.stage2_num_channels,
            has_se=self.has_se,
            name="st2",
            align_corners=align_corners)

        self.tr2 = TransitionLayer(
            in_channels=self.stage2_num_channels,
            out_channels=self.stage3_num_channels,
            name="tr2")

        self.st3 = Stage(
            num_channels=self.stage3_num_channels,
            num_modules=self.stage3_num_modules,
            num_blocks=self.stage3_num_blocks,
            num_filters=self.stage3_num_channels,
            has_se=self.has_se,
            name="st3",
            align_corners=align_corners)

        self.tr3 = TransitionLayer(
            in_channels=self.stage3_num_channels,
            out_channels=self.stage4_num_channels,
            name="tr3")

        self.st4 = Stage(
            num_channels=self.stage4_num_channels,
            num_modules=self.stage4_num_modules,
            num_blocks=self.stage4_num_blocks,
            num_filters=self.stage4_num_channels,
            has_se=self.has_se,
            name="st4",
            align_corners=align_corners)

        self.init_weight()

    def forward(self, x):
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)
        la1 = self.la1(conv2)
        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)
        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)
        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)
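
        # Fuse the four stage-4 outputs: bilinearly upsample the three
        # lower-resolution branches to the highest-resolution branch's size,
        # then concatenate along the channel axis.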
        size = paddle.shape(st4[0])[2:]
        x1 = F.interpolate(
            st4[1], size, mode='bilinear', align_corners=self.align_corners)
        x2 = F.interpolate(
            st4[2], size, mode='bilinear', align_corners=self.align_corners)
        x3 = F.interpolate(
            st4[3], size, mode='bilinear', align_corners=self.align_corners)
        x = paddle.concat([st4[0], x1, x2, x3], axis=1)
        return [x]

    def init_weight(self):
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                param_init.normal_init(layer.weight, std=0.001)
            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
                param_init.constant_init(layer.weight, value=1.0)
                param_init.constant_init(layer.bias, value=0.0)
        if self.pretrained is not None:
            utils.load_pretrained_model(self, self.pretrained)
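

# Stage-1 stem: a plain stack of bottleneck blocks on a single branch; the
# first block uses a projection shortcut to expand the input to
# num_filters * 4 channels.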
class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 num_blocks,
                 has_se=False,
                 name=None):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []

        for i in range(num_blocks):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else num_filters * 4,
                    num_filters=num_filters,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, x):
        conv = x
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv
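

# Adapts the previous stage's branch list to the next stage's branch count and
# widths: an existing branch gets a 3x3 ConvBNReLU only when its channel width
# changes, and each newly created branch downsamples the last input branch
# with a stride-2 conv.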
class TransitionLayer(nn.Layer):
    def __init__(self, in_channels, out_channels, name=None):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        layers.ConvBNReLU(
                            in_channels=in_channels[i],
                            out_channels=out_channels[i],
                            kernel_size=3,
                            padding='same',
                            bias_attr=False))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    layers.ConvBNReLU(
                        in_channels=in_channels[-1],
                        out_channels=out_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding='same',
                        bias_attr=False))
            self.conv_bn_func_list.append(residual)

    def forward(self, x):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(x[idx])
            else:
                if idx < len(x):
                    outs.append(conv_bn_func(x[idx]))
                else:
                    outs.append(conv_bn_func(x[-1]))
        return outs
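

# Runs an independent stack of BasicBlocks on each resolution branch in
# parallel; branch i keeps its own width out_channels[i].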
class Branches(nn.Layer):
    def __init__(self,
                 num_blocks,
                 in_channels,
                 out_channels,
                 has_se=False,
                 name=None):
        super(Branches, self).__init__()

        self.basic_block_list = []

        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(num_blocks[i]):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, x):
        outs = []
        for idx, input in enumerate(x):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs
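

# ResNet-style bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand (4x), with an
# optional SE module and a projection shortcut when downsample is True.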
class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=1,
            padding='same',
            bias_attr=False)

        self.conv2 = layers.ConvBNReLU(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding='same',
            bias_attr=False)

        self.conv3 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters * 4,
            kernel_size=1,
            padding='same',
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBN(
                in_channels=num_channels,
                out_channels=num_filters * 4,
                kernel_size=1,
                padding='same',
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name=name + '_fc')

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv3 = self.se(conv3)

        y = conv3 + residual
        y = F.relu(y)
        return y
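

# ResNet basic block: two 3x3 convs plus an identity (or 1x1-projected)
# shortcut, as in ResNet-18/34.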
class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 name=None):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding='same',
            bias_attr=False)
        self.conv2 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            padding='same',
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBNReLU(
                in_channels=num_channels,
                out_channels=num_filters,
                kernel_size=1,
                padding='same',
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name=name + '_fc')

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv2 = self.se(conv2)

        y = conv2 + residual
        y = F.relu(y)
        return y
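

# Squeeze-and-Excitation: global-average-pool each channel to a scalar,
# bottleneck through two FC layers (reduction_ratio controls the squeeze),
# and rescale the input channels by the resulting sigmoid gates.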
class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()

        self.pool2d_gap = nn.AdaptiveAvgPool2D(1)

        self._num_channels = num_channels

        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = nn.Linear(
            num_channels,
            med_ch,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = nn.Linear(
            med_ch,
            num_filters,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

    def forward(self, x):
        pool = self.pool2d_gap(x)
        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.reshape(
            excitation, shape=[-1, self._num_channels, 1, 1])
        out = x * excitation
        return out
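

# A Stage chains num_modules HighResolutionModules; multi_scale_output=False
# is only honored by the final module, which then fuses everything into the
# single highest-resolution output.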
class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners))
            self.stage_func_list.append(stage_func)

    def forward(self, x):
        out = x
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False):
        super(HighResolutionModule, self).__init__()

        self.branches_func = Branches(
            num_blocks=num_blocks,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            name=name)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            name=name,
            align_corners=align_corners)

    def forward(self, x):
        out = self.branches_func(x)
        out = self.fuse_func(out)
        return out
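

# Cross-resolution fusion: for each output branch i, lower-resolution inputs
# (j > i) are channel-matched with a 1x1 conv and bilinearly upsampled, while
# higher-resolution inputs (j < i) are downsampled through i - j stride-2
# convs; all paths are summed and passed through ReLU.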
class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.align_corners = align_corners

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        layers.ConvBN(
                            in_channels=in_channels[j],
                            out_channels=out_channels[i],
                            kernel_size=1,
                            padding='same',
                            bias_attr=False))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBN(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[i],
                                    kernel_size=3,
                                    stride=2,
                                    padding='same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBNReLU(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[j],
                                    kernel_size=3,
                                    stride=2,
                                    padding='same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, x):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = x[i]
            residual_shape = paddle.shape(residual)[-2:]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](x[j])
                    residual_func_idx += 1

                    y = F.interpolate(
                        y,
                        residual_shape,
                        mode='bilinear',
                        align_corners=self.align_corners)
                    residual = residual + y
                elif j < i:
                    y = x[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1
                    residual = residual + y

            residual = F.relu(residual)
            outs.append(residual)

        return outs
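

# Registered backbone variants. The W{N} suffix names the nominal width of the
# highest-resolution branch (the Small variants also trim widths and depths);
# each lower-resolution branch doubles the width of the one above it.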
@manager.BACKBONES.add_component
def HRNet_W18_Small_V1(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[1],
        stage1_num_channels=[32],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[16, 32],
        stage3_num_modules=1,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[16, 32, 64],
        stage4_num_modules=1,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[16, 32, 64, 128],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18_Small_V2(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[2],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[18, 36],
        stage3_num_modules=3,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=2,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[18, 36],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W30(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[30, 60],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[30, 60, 120],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[30, 60, 120, 240],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W32(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[32, 64],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[32, 64, 128],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[32, 64, 128, 256],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W40(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[40, 80],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[40, 80, 160],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[40, 80, 160, 320],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W44(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[44, 88],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[44, 88, 176],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[44, 88, 176, 352],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W48(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[48, 96],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[48, 96, 192],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[48, 96, 192, 384],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W60(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[60, 120],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[60, 120, 240],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[60, 120, 240, 480],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W64(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[64, 128],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[64, 128, 256],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[64, 128, 256, 512],
        **kwargs)
    return model
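

# A minimal smoke-test sketch (an illustrative addition, not part of the
# original module): build the W18 backbone and push a dummy batch through it.
if __name__ == "__main__":
    backbone = HRNet_W18()
    # Two stride-2 stem convs put the highest-resolution branch at 1/4 scale,
    # so a 512x512 input yields 128x128 feature maps.
    feats = backbone(paddle.randn([1, 3, 512, 512]))
    # Single-element list; channels = 18 + 36 + 72 + 144 = 270.
    print(feats[0].shape)  # [1, 270, 128, 128]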