# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlex.paddleseg.cvlibs import manager, param_init
from paddlex.paddleseg.models import layers
from paddlex.paddleseg.utils import utils

__all__ = [
    "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
    "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60",
    "HRNet_W64"
]


class HRNet(nn.Layer):
    """
    The HRNet implementation based on PaddlePaddle.

    The original article refers to
    Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
    (https://arxiv.org/pdf/1908.07919.pdf).

    Args:
        pretrained (str, optional): The path of the pretrained model.
        stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
        stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4,).
        stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64,).
        stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
        stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
        stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
        stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
        stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
        stage3_num_channels (list, optional): Number of channels per branch for stage3. Default (18, 36, 72).
        stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
        stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
        stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72, 144).
        has_se (bool, optional): Whether to use the Squeeze-and-Excitation module. Default False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False
            when the feature size is even, e.g. 1024x512; otherwise set it to True, e.g. 769x769.
            Default: False.
    """
    def __init__(self,
                 pretrained=None,
                 stage1_num_modules=1,
                 stage1_num_blocks=(4, ),
                 stage1_num_channels=(64, ),
                 stage2_num_modules=1,
                 stage2_num_blocks=(4, 4),
                 stage2_num_channels=(18, 36),
                 stage3_num_modules=4,
                 stage3_num_blocks=(4, 4, 4),
                 stage3_num_channels=(18, 36, 72),
                 stage4_num_modules=3,
                 stage4_num_blocks=(4, 4, 4, 4),
                 stage4_num_channels=(18, 36, 72, 144),
                 has_se=False,
                 align_corners=False):
        super(HRNet, self).__init__()
        self.pretrained = pretrained
        self.stage1_num_modules = stage1_num_modules
        self.stage1_num_blocks = stage1_num_blocks
        self.stage1_num_channels = stage1_num_channels
        self.stage2_num_modules = stage2_num_modules
        self.stage2_num_blocks = stage2_num_blocks
        self.stage2_num_channels = stage2_num_channels
        self.stage3_num_modules = stage3_num_modules
        self.stage3_num_blocks = stage3_num_blocks
        self.stage3_num_channels = stage3_num_channels
        self.stage4_num_modules = stage4_num_modules
        self.stage4_num_blocks = stage4_num_blocks
        self.stage4_num_channels = stage4_num_channels
        self.has_se = has_se
        self.align_corners = align_corners
        self.feat_channels = [sum(stage4_num_channels)]

        self.conv_layer1_1 = layers.ConvBNReLU(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding='same',
            bias_attr=False)

        self.conv_layer1_2 = layers.ConvBNReLU(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding='same',
            bias_attr=False)

        self.la1 = Layer1(
            num_channels=64,
            num_blocks=self.stage1_num_blocks[0],
            num_filters=self.stage1_num_channels[0],
            has_se=has_se,
            name="layer2")

        self.tr1 = TransitionLayer(
            in_channels=[self.stage1_num_channels[0] * 4],
            out_channels=self.stage2_num_channels,
            name="tr1")

        self.st2 = Stage(
            num_channels=self.stage2_num_channels,
            num_modules=self.stage2_num_modules,
            num_blocks=self.stage2_num_blocks,
            num_filters=self.stage2_num_channels,
            has_se=self.has_se,
            name="st2",
            align_corners=align_corners)

        self.tr2 = TransitionLayer(
            in_channels=self.stage2_num_channels,
            out_channels=self.stage3_num_channels,
            name="tr2")

        self.st3 = Stage(
            num_channels=self.stage3_num_channels,
            num_modules=self.stage3_num_modules,
            num_blocks=self.stage3_num_blocks,
            num_filters=self.stage3_num_channels,
            has_se=self.has_se,
            name="st3",
            align_corners=align_corners)

        self.tr3 = TransitionLayer(
            in_channels=self.stage3_num_channels,
            out_channels=self.stage4_num_channels,
            name="tr3")

        self.st4 = Stage(
            num_channels=self.stage4_num_channels,
            num_modules=self.stage4_num_modules,
            num_blocks=self.stage4_num_blocks,
            num_filters=self.stage4_num_channels,
            has_se=self.has_se,
            name="st4",
            align_corners=align_corners)

        self.init_weight()
    def forward(self, x):
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)
        la1 = self.la1(conv2)
        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)
        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)
        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)

        # Upsample the three lower-resolution branches to the size of the
        # highest-resolution branch and concatenate along the channel axis.
        size = paddle.shape(st4[0])[2:]
        x1 = F.interpolate(
            st4[1], size, mode='bilinear', align_corners=self.align_corners)
        x2 = F.interpolate(
            st4[2], size, mode='bilinear', align_corners=self.align_corners)
        x3 = F.interpolate(
            st4[3], size, mode='bilinear', align_corners=self.align_corners)
        x = paddle.concat([st4[0], x1, x2, x3], axis=1)
        return [x]
    def init_weight(self):
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                param_init.normal_init(layer.weight, std=0.001)
            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
                param_init.constant_init(layer.weight, value=1.0)
                param_init.constant_init(layer.bias, value=0.0)
        if self.pretrained is not None:
            utils.load_pretrained_model(self, self.pretrained)
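
# Data flow (default W18 configuration): two stride-2 stem convs bring the
# input to 1/4 resolution; Layer1 stacks bottleneck blocks on a single branch;
# each TransitionLayer adds one new branch at half the previous resolution;
# each Stage runs the parallel branches and fuses them. forward() upsamples
# the three lower-resolution stage-4 outputs to the 1/4-resolution branch and
# concatenates them, so feat_channels = [sum(stage4_num_channels)].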


class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 num_blocks,
                 has_se=False,
                 name=None):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []
        for i in range(num_blocks):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else num_filters * 4,
                    num_filters=num_filters,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, x):
        conv = x
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv


class TransitionLayer(nn.Layer):
    def __init__(self, in_channels, out_channels, name=None):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        layers.ConvBNReLU(
                            in_channels=in_channels[i],
                            out_channels=out_channels[i],
                            kernel_size=3,
                            padding='same',
                            bias_attr=False))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    layers.ConvBNReLU(
                        in_channels=in_channels[-1],
                        out_channels=out_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding='same',
                        bias_attr=False))
            self.conv_bn_func_list.append(residual)

    def forward(self, x):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(x[idx])
            else:
                if idx < len(x):
                    outs.append(conv_bn_func(x[idx]))
                else:
                    outs.append(conv_bn_func(x[-1]))
        return outs
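
# Concretely: a TransitionLayer keeps each existing branch (identity, or a
# 3x3 ConvBNReLU when its channel count changes) and creates each new,
# lower-resolution branch with a stride-2 3x3 ConvBNReLU from the last input
# branch. For HRNet_W18, tr1 maps the 256-channel stage-1 output to the
# (18, 36)-channel two-branch input of stage 2.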


class Branches(nn.Layer):
    def __init__(self,
                 num_blocks,
                 in_channels,
                 out_channels,
                 has_se=False,
                 name=None):
        super(Branches, self).__init__()

        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(num_blocks[i]):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, x):
        outs = []
        for idx, input in enumerate(x):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=1,
            padding='same',
            bias_attr=False)

        self.conv2 = layers.ConvBNReLU(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding='same',
            bias_attr=False)

        self.conv3 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters * 4,
            kernel_size=1,
            padding='same',
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBN(
                in_channels=num_channels,
                out_channels=num_filters * 4,
                kernel_size=1,
                padding='same',
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name=name + '_fc')

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv3 = self.se(conv3)

        y = conv3 + residual
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 name=None):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding='same',
            bias_attr=False)

        self.conv2 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            padding='same',
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBNReLU(
                in_channels=num_channels,
                out_channels=num_filters,
                kernel_size=1,
                padding='same',
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name=name + '_fc')

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv2 = self.se(conv2)

        y = conv2 + residual
        y = F.relu(y)
        return y


class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()

        self.pool2d_gap = nn.AdaptiveAvgPool2D(1)

        self._num_channels = num_channels

        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = nn.Linear(
            num_channels,
            med_ch,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = nn.Linear(
            med_ch,
            num_filters,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

    def forward(self, x):
        pool = self.pool2d_gap(x)
        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.reshape(
            excitation, shape=[-1, self._num_channels, 1, 1])
        out = x * excitation
        return out
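
# In effect, SELayer computes channel-wise attention:
#     out = x * sigmoid(W2 @ relu(W1 @ global_avg_pool(x)))
# where W1 maps C -> C / reduction_ratio and W2 maps back up to num_filters
# (callers in this file always pass num_filters == num_channels).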


class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners))
            self.stage_func_list.append(stage_func)

    def forward(self, x):
        out = x
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False):
        super(HighResolutionModule, self).__init__()

        self.branches_func = Branches(
            num_blocks=num_blocks,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            name=name)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            name=name,
            align_corners=align_corners)

    def forward(self, x):
        out = self.branches_func(x)
        out = self.fuse_func(out)
        return out


class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.align_corners = align_corners

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        layers.ConvBN(
                            in_channels=in_channels[j],
                            out_channels=out_channels[i],
                            kernel_size=1,
                            padding='same',
                            bias_attr=False))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBN(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[i],
                                    kernel_size=3,
                                    stride=2,
                                    padding='same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBNReLU(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[j],
                                    kernel_size=3,
                                    stride=2,
                                    padding='same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, x):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = x[i]
            residual_shape = paddle.shape(residual)[-2:]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](x[j])
                    residual_func_idx += 1
                    y = F.interpolate(
                        y,
                        residual_shape,
                        mode='bilinear',
                        align_corners=self.align_corners)
                    residual = residual + y
                elif j < i:
                    y = x[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1
                    residual = residual + y
            residual = F.relu(residual)
            outs.append(residual)
        return outs
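
# Fusion rule: for target branch i, a lower-resolution source (j > i) passes
# through a 1x1 ConvBN and is bilinearly upsampled to branch i's size; a
# higher-resolution source (j < i) is downsampled by a chain of i - j
# stride-2 3x3 convs (ConvBNReLU except for the last one, which is ConvBN).
# The accumulated sum is then passed through ReLU.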


@manager.BACKBONES.add_component
def HRNet_W18_Small_V1(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[1],
        stage1_num_channels=[32],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[16, 32],
        stage3_num_modules=1,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[16, 32, 64],
        stage4_num_modules=1,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[16, 32, 64, 128],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18_Small_V2(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[2],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[18, 36],
        stage3_num_modules=3,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=2,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[18, 36],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W30(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[30, 60],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[30, 60, 120],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[30, 60, 120, 240],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W32(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[32, 64],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[32, 64, 128],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[32, 64, 128, 256],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W40(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[40, 80],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[40, 80, 160],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[40, 80, 160, 320],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W44(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[44, 88],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[44, 88, 176],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[44, 88, 176, 352],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W48(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[48, 96],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[48, 96, 192],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[48, 96, 192, 384],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W60(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[60, 120],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[60, 120, 240],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[60, 120, 240, 480],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W64(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[64, 128],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[64, 128, 256],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[64, 128, 256, 512],
        **kwargs)
    return model
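

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the input shape is an arbitrary example).
    # The fused output of HRNet_W18 has sum(stage4_num_channels) channels,
    # i.e. 18 + 36 + 72 + 144 = 270, at 1/4 of the input resolution.
    backbone = HRNet_W18()
    feat = backbone(paddle.randn([1, 3, 512, 512]))[0]
    print(feat.shape)  # expected: [1, 270, 128, 128]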