resnest.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/zhanghang1989/ResNeSt
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import numpy as np
  19. import paddle
  20. import math
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from paddle import ParamAttr
  24. from paddle.nn.initializer import KaimingNormal
  25. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  26. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  27. from paddle.regularizer import L2Decay
  28. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  29. MODEL_URLS = {
  30. "ResNeSt50_fast_1s1x64d":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeSt50_fast_1s1x64d_pretrained.pdparams",
  32. "ResNeSt50":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeSt50_pretrained.pdparams",
  34. "ResNeSt101":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeSt101_pretrained.pdparams",
  36. }
  37. __all__ = list(MODEL_URLS.keys())
  38. class ConvBNLayer(nn.Layer):
  39. def __init__(self,
  40. num_channels,
  41. num_filters,
  42. filter_size,
  43. stride=1,
  44. dilation=1,
  45. groups=1,
  46. act=None,
  47. name=None):
  48. super(ConvBNLayer, self).__init__()
  49. bn_decay = 0.0
  50. self._conv = Conv2D(
  51. in_channels=num_channels,
  52. out_channels=num_filters,
  53. kernel_size=filter_size,
  54. stride=stride,
  55. padding=(filter_size - 1) // 2,
  56. dilation=dilation,
  57. groups=groups,
  58. weight_attr=ParamAttr(name=name + "_weight"),
  59. bias_attr=False)
  60. self._batch_norm = BatchNorm(
  61. num_filters,
  62. act=act,
  63. param_attr=ParamAttr(
  64. name=name + "_scale", regularizer=L2Decay(bn_decay)),
  65. bias_attr=ParamAttr(
  66. name + "_offset", regularizer=L2Decay(bn_decay)),
  67. moving_mean_name=name + "_mean",
  68. moving_variance_name=name + "_variance")
  69. def forward(self, x):
  70. x = self._conv(x)
  71. x = self._batch_norm(x)
  72. return x
  73. class rSoftmax(nn.Layer):
  74. def __init__(self, radix, cardinality):
  75. super(rSoftmax, self).__init__()
  76. self.radix = radix
  77. self.cardinality = cardinality
  78. def forward(self, x):
  79. cardinality = self.cardinality
  80. radix = self.radix
  81. batch, r, h, w = x.shape
  82. if self.radix > 1:
  83. x = paddle.reshape(
  84. x=x,
  85. shape=[
  86. batch, cardinality, radix,
  87. int(r * h * w / cardinality / radix)
  88. ])
  89. x = paddle.transpose(x=x, perm=[0, 2, 1, 3])
  90. x = nn.functional.softmax(x, axis=1)
  91. x = paddle.reshape(x=x, shape=[batch, r * h * w, 1, 1])
  92. else:
  93. x = nn.functional.sigmoid(x)
  94. return x
  95. class SplatConv(nn.Layer):
  96. def __init__(self,
  97. in_channels,
  98. channels,
  99. kernel_size,
  100. stride=1,
  101. padding=0,
  102. dilation=1,
  103. groups=1,
  104. bias=True,
  105. radix=2,
  106. reduction_factor=4,
  107. rectify_avg=False,
  108. name=None):
  109. super(SplatConv, self).__init__()
  110. self.radix = radix
  111. self.conv1 = ConvBNLayer(
  112. num_channels=in_channels,
  113. num_filters=channels * radix,
  114. filter_size=kernel_size,
  115. stride=stride,
  116. groups=groups * radix,
  117. act="relu",
  118. name=name + "_1_weights")
  119. self.avg_pool2d = AdaptiveAvgPool2D(1)
  120. inter_channels = int(max(in_channels * radix // reduction_factor, 32))
  121. # to calc gap
  122. self.conv2 = ConvBNLayer(
  123. num_channels=channels,
  124. num_filters=inter_channels,
  125. filter_size=1,
  126. stride=1,
  127. groups=groups,
  128. act="relu",
  129. name=name + "_2_weights")
  130. # to calc atten
  131. self.conv3 = Conv2D(
  132. in_channels=inter_channels,
  133. out_channels=channels * radix,
  134. kernel_size=1,
  135. stride=1,
  136. padding=0,
  137. groups=groups,
  138. weight_attr=ParamAttr(
  139. name=name + "_weights", initializer=KaimingNormal()),
  140. bias_attr=False)
  141. self.rsoftmax = rSoftmax(radix=radix, cardinality=groups)
  142. def forward(self, x):
  143. x = self.conv1(x)
  144. if self.radix > 1:
  145. splited = paddle.split(x, num_or_sections=self.radix, axis=1)
  146. gap = paddle.add_n(splited)
  147. else:
  148. gap = x
  149. gap = self.avg_pool2d(gap)
  150. gap = self.conv2(gap)
  151. atten = self.conv3(gap)
  152. atten = self.rsoftmax(atten)
  153. if self.radix > 1:
  154. attens = paddle.split(atten, num_or_sections=self.radix, axis=1)
  155. y = paddle.add_n([
  156. paddle.multiply(split, att)
  157. for (att, split) in zip(attens, splited)
  158. ])
  159. else:
  160. y = paddle.multiply(x, atten)
  161. return y
  162. class BottleneckBlock(nn.Layer):
  163. def __init__(self,
  164. inplanes,
  165. planes,
  166. stride=1,
  167. radix=1,
  168. cardinality=1,
  169. bottleneck_width=64,
  170. avd=False,
  171. avd_first=False,
  172. dilation=1,
  173. is_first=False,
  174. rectify_avg=False,
  175. last_gamma=False,
  176. avg_down=False,
  177. name=None):
  178. super(BottleneckBlock, self).__init__()
  179. self.inplanes = inplanes
  180. self.planes = planes
  181. self.stride = stride
  182. self.radix = radix
  183. self.cardinality = cardinality
  184. self.avd = avd
  185. self.avd_first = avd_first
  186. self.dilation = dilation
  187. self.is_first = is_first
  188. self.rectify_avg = rectify_avg
  189. self.last_gamma = last_gamma
  190. self.avg_down = avg_down
  191. group_width = int(planes * (bottleneck_width / 64.)) * cardinality
  192. self.conv1 = ConvBNLayer(
  193. num_channels=self.inplanes,
  194. num_filters=group_width,
  195. filter_size=1,
  196. stride=1,
  197. groups=1,
  198. act="relu",
  199. name=name + "_conv1")
  200. if avd and avd_first and (stride > 1 or is_first):
  201. self.avg_pool2d_1 = AvgPool2D(
  202. kernel_size=3, stride=stride, padding=1)
  203. if radix >= 1:
  204. self.conv2 = SplatConv(
  205. in_channels=group_width,
  206. channels=group_width,
  207. kernel_size=3,
  208. stride=1,
  209. padding=dilation,
  210. dilation=dilation,
  211. groups=cardinality,
  212. bias=False,
  213. radix=radix,
  214. rectify_avg=rectify_avg,
  215. name=name + "_splat")
  216. else:
  217. self.conv2 = ConvBNLayer(
  218. num_channels=group_width,
  219. num_filters=group_width,
  220. filter_size=3,
  221. stride=1,
  222. dilation=dilation,
  223. groups=cardinality,
  224. act="relu",
  225. name=name + "_conv2")
  226. if avd and avd_first == False and (stride > 1 or is_first):
  227. self.avg_pool2d_2 = AvgPool2D(
  228. kernel_size=3, stride=stride, padding=1)
  229. self.conv3 = ConvBNLayer(
  230. num_channels=group_width,
  231. num_filters=planes * 4,
  232. filter_size=1,
  233. stride=1,
  234. groups=1,
  235. act=None,
  236. name=name + "_conv3")
  237. if stride != 1 or self.inplanes != self.planes * 4:
  238. if avg_down:
  239. if dilation == 1:
  240. self.avg_pool2d_3 = AvgPool2D(
  241. kernel_size=stride, stride=stride, padding=0)
  242. else:
  243. self.avg_pool2d_3 = AvgPool2D(
  244. kernel_size=1, stride=1, padding=0, ceil_mode=True)
  245. self.conv4 = Conv2D(
  246. in_channels=self.inplanes,
  247. out_channels=planes * 4,
  248. kernel_size=1,
  249. stride=1,
  250. padding=0,
  251. groups=1,
  252. weight_attr=ParamAttr(
  253. name=name + "_weights", initializer=KaimingNormal()),
  254. bias_attr=False)
  255. else:
  256. self.conv4 = Conv2D(
  257. in_channels=self.inplanes,
  258. out_channels=planes * 4,
  259. kernel_size=1,
  260. stride=stride,
  261. padding=0,
  262. groups=1,
  263. weight_attr=ParamAttr(
  264. name=name + "_shortcut_weights",
  265. initializer=KaimingNormal()),
  266. bias_attr=False)
  267. bn_decay = 0.0
  268. self._batch_norm = BatchNorm(
  269. planes * 4,
  270. act=None,
  271. param_attr=ParamAttr(
  272. name=name + "_shortcut_scale",
  273. regularizer=L2Decay(bn_decay)),
  274. bias_attr=ParamAttr(
  275. name + "_shortcut_offset", regularizer=L2Decay(bn_decay)),
  276. moving_mean_name=name + "_shortcut_mean",
  277. moving_variance_name=name + "_shortcut_variance")
  278. def forward(self, x):
  279. short = x
  280. x = self.conv1(x)
  281. if self.avd and self.avd_first and (self.stride > 1 or self.is_first):
  282. x = self.avg_pool2d_1(x)
  283. x = self.conv2(x)
  284. if self.avd and self.avd_first == False and (self.stride > 1 or
  285. self.is_first):
  286. x = self.avg_pool2d_2(x)
  287. x = self.conv3(x)
  288. if self.stride != 1 or self.inplanes != self.planes * 4:
  289. if self.avg_down:
  290. short = self.avg_pool2d_3(short)
  291. short = self.conv4(short)
  292. short = self._batch_norm(short)
  293. y = paddle.add(x=short, y=x)
  294. y = F.relu(y)
  295. return y
  296. class ResNeStLayer(nn.Layer):
  297. def __init__(self,
  298. inplanes,
  299. planes,
  300. blocks,
  301. radix,
  302. cardinality,
  303. bottleneck_width,
  304. avg_down,
  305. avd,
  306. avd_first,
  307. rectify_avg,
  308. last_gamma,
  309. stride=1,
  310. dilation=1,
  311. is_first=True,
  312. name=None):
  313. super(ResNeStLayer, self).__init__()
  314. self.inplanes = inplanes
  315. self.planes = planes
  316. self.blocks = blocks
  317. self.radix = radix
  318. self.cardinality = cardinality
  319. self.bottleneck_width = bottleneck_width
  320. self.avg_down = avg_down
  321. self.avd = avd
  322. self.avd_first = avd_first
  323. self.rectify_avg = rectify_avg
  324. self.last_gamma = last_gamma
  325. self.is_first = is_first
  326. if dilation == 1 or dilation == 2:
  327. bottleneck_func = self.add_sublayer(
  328. name + "_bottleneck_0",
  329. BottleneckBlock(
  330. inplanes=self.inplanes,
  331. planes=planes,
  332. stride=stride,
  333. radix=radix,
  334. cardinality=cardinality,
  335. bottleneck_width=bottleneck_width,
  336. avg_down=self.avg_down,
  337. avd=avd,
  338. avd_first=avd_first,
  339. dilation=1,
  340. is_first=is_first,
  341. rectify_avg=rectify_avg,
  342. last_gamma=last_gamma,
  343. name=name + "_bottleneck_0"))
  344. elif dilation == 4:
  345. bottleneck_func = self.add_sublayer(
  346. name + "_bottleneck_0",
  347. BottleneckBlock(
  348. inplanes=self.inplanes,
  349. planes=planes,
  350. stride=stride,
  351. radix=radix,
  352. cardinality=cardinality,
  353. bottleneck_width=bottleneck_width,
  354. avg_down=self.avg_down,
  355. avd=avd,
  356. avd_first=avd_first,
  357. dilation=2,
  358. is_first=is_first,
  359. rectify_avg=rectify_avg,
  360. last_gamma=last_gamma,
  361. name=name + "_bottleneck_0"))
  362. else:
  363. raise RuntimeError("=>unknown dilation size")
  364. self.inplanes = planes * 4
  365. self.bottleneck_block_list = [bottleneck_func]
  366. for i in range(1, blocks):
  367. curr_name = name + "_bottleneck_" + str(i)
  368. bottleneck_func = self.add_sublayer(
  369. curr_name,
  370. BottleneckBlock(
  371. inplanes=self.inplanes,
  372. planes=planes,
  373. radix=radix,
  374. cardinality=cardinality,
  375. bottleneck_width=bottleneck_width,
  376. avg_down=self.avg_down,
  377. avd=avd,
  378. avd_first=avd_first,
  379. dilation=dilation,
  380. rectify_avg=rectify_avg,
  381. last_gamma=last_gamma,
  382. name=curr_name))
  383. self.bottleneck_block_list.append(bottleneck_func)
  384. def forward(self, x):
  385. for bottleneck_block in self.bottleneck_block_list:
  386. x = bottleneck_block(x)
  387. return x
  388. class ResNeSt(nn.Layer):
  389. def __init__(self,
  390. layers,
  391. radix=1,
  392. groups=1,
  393. bottleneck_width=64,
  394. dilated=False,
  395. dilation=1,
  396. deep_stem=False,
  397. stem_width=64,
  398. avg_down=False,
  399. rectify_avg=False,
  400. avd=False,
  401. avd_first=False,
  402. final_drop=0.0,
  403. last_gamma=False,
  404. class_num=1000):
  405. super(ResNeSt, self).__init__()
  406. self.cardinality = groups
  407. self.bottleneck_width = bottleneck_width
  408. # ResNet-D params
  409. self.inplanes = stem_width * 2 if deep_stem else 64
  410. self.avg_down = avg_down
  411. self.last_gamma = last_gamma
  412. # ResNeSt params
  413. self.radix = radix
  414. self.avd = avd
  415. self.avd_first = avd_first
  416. self.deep_stem = deep_stem
  417. self.stem_width = stem_width
  418. self.layers = layers
  419. self.final_drop = final_drop
  420. self.dilated = dilated
  421. self.dilation = dilation
  422. self.rectify_avg = rectify_avg
  423. if self.deep_stem:
  424. self.stem = nn.Sequential(
  425. ("conv1", ConvBNLayer(
  426. num_channels=3,
  427. num_filters=stem_width,
  428. filter_size=3,
  429. stride=2,
  430. act="relu",
  431. name="conv1")), ("conv2", ConvBNLayer(
  432. num_channels=stem_width,
  433. num_filters=stem_width,
  434. filter_size=3,
  435. stride=1,
  436. act="relu",
  437. name="conv2")), ("conv3", ConvBNLayer(
  438. num_channels=stem_width,
  439. num_filters=stem_width * 2,
  440. filter_size=3,
  441. stride=1,
  442. act="relu",
  443. name="conv3")))
  444. else:
  445. self.stem = ConvBNLayer(
  446. num_channels=3,
  447. num_filters=stem_width,
  448. filter_size=7,
  449. stride=2,
  450. act="relu",
  451. name="conv1")
  452. self.max_pool2d = MaxPool2D(kernel_size=3, stride=2, padding=1)
  453. self.layer1 = ResNeStLayer(
  454. inplanes=self.stem_width * 2
  455. if self.deep_stem else self.stem_width,
  456. planes=64,
  457. blocks=self.layers[0],
  458. radix=radix,
  459. cardinality=self.cardinality,
  460. bottleneck_width=bottleneck_width,
  461. avg_down=self.avg_down,
  462. avd=avd,
  463. avd_first=avd_first,
  464. rectify_avg=rectify_avg,
  465. last_gamma=last_gamma,
  466. stride=1,
  467. dilation=1,
  468. is_first=False,
  469. name="layer1")
  470. # return
  471. self.layer2 = ResNeStLayer(
  472. inplanes=256,
  473. planes=128,
  474. blocks=self.layers[1],
  475. radix=radix,
  476. cardinality=self.cardinality,
  477. bottleneck_width=bottleneck_width,
  478. avg_down=self.avg_down,
  479. avd=avd,
  480. avd_first=avd_first,
  481. rectify_avg=rectify_avg,
  482. last_gamma=last_gamma,
  483. stride=2,
  484. name="layer2")
  485. if self.dilated or self.dilation == 4:
  486. self.layer3 = ResNeStLayer(
  487. inplanes=512,
  488. planes=256,
  489. blocks=self.layers[2],
  490. radix=radix,
  491. cardinality=self.cardinality,
  492. bottleneck_width=bottleneck_width,
  493. avg_down=self.avg_down,
  494. avd=avd,
  495. avd_first=avd_first,
  496. rectify_avg=rectify_avg,
  497. last_gamma=last_gamma,
  498. stride=1,
  499. dilation=2,
  500. name="layer3")
  501. self.layer4 = ResNeStLayer(
  502. inplanes=1024,
  503. planes=512,
  504. blocks=self.layers[3],
  505. radix=radix,
  506. cardinality=self.cardinality,
  507. bottleneck_width=bottleneck_width,
  508. avg_down=self.avg_down,
  509. avd=avd,
  510. avd_first=avd_first,
  511. rectify_avg=rectify_avg,
  512. last_gamma=last_gamma,
  513. stride=1,
  514. dilation=4,
  515. name="layer4")
  516. elif self.dilation == 2:
  517. self.layer3 = ResNeStLayer(
  518. inplanes=512,
  519. planes=256,
  520. blocks=self.layers[2],
  521. radix=radix,
  522. cardinality=self.cardinality,
  523. bottleneck_width=bottleneck_width,
  524. avg_down=self.avg_down,
  525. avd=avd,
  526. avd_first=avd_first,
  527. rectify_avg=rectify_avg,
  528. last_gamma=last_gamma,
  529. stride=2,
  530. dilation=1,
  531. name="layer3")
  532. self.layer4 = ResNeStLayer(
  533. inplanes=1024,
  534. planes=512,
  535. blocks=self.layers[3],
  536. radix=radix,
  537. cardinality=self.cardinality,
  538. bottleneck_width=bottleneck_width,
  539. avg_down=self.avg_down,
  540. avd=avd,
  541. avd_first=avd_first,
  542. rectify_avg=rectify_avg,
  543. last_gamma=last_gamma,
  544. stride=1,
  545. dilation=2,
  546. name="layer4")
  547. else:
  548. self.layer3 = ResNeStLayer(
  549. inplanes=512,
  550. planes=256,
  551. blocks=self.layers[2],
  552. radix=radix,
  553. cardinality=self.cardinality,
  554. bottleneck_width=bottleneck_width,
  555. avg_down=self.avg_down,
  556. avd=avd,
  557. avd_first=avd_first,
  558. rectify_avg=rectify_avg,
  559. last_gamma=last_gamma,
  560. stride=2,
  561. name="layer3")
  562. self.layer4 = ResNeStLayer(
  563. inplanes=1024,
  564. planes=512,
  565. blocks=self.layers[3],
  566. radix=radix,
  567. cardinality=self.cardinality,
  568. bottleneck_width=bottleneck_width,
  569. avg_down=self.avg_down,
  570. avd=avd,
  571. avd_first=avd_first,
  572. rectify_avg=rectify_avg,
  573. last_gamma=last_gamma,
  574. stride=2,
  575. name="layer4")
  576. self.pool2d_avg = AdaptiveAvgPool2D(1)
  577. self.out_channels = 2048
  578. stdv = 1.0 / math.sqrt(self.out_channels * 1.0)
  579. self.out = Linear(
  580. self.out_channels,
  581. class_num,
  582. weight_attr=ParamAttr(
  583. initializer=nn.initializer.Uniform(-stdv, stdv),
  584. name="fc_weights"),
  585. bias_attr=ParamAttr(name="fc_offset"))
  586. def forward(self, x):
  587. x = self.stem(x)
  588. x = self.max_pool2d(x)
  589. x = self.layer1(x)
  590. x = self.layer2(x)
  591. x = self.layer3(x)
  592. x = self.layer4(x)
  593. x = self.pool2d_avg(x)
  594. x = paddle.reshape(x, shape=[-1, self.out_channels])
  595. x = self.out(x)
  596. return x
  597. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  598. if pretrained is False:
  599. pass
  600. elif pretrained is True:
  601. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  602. elif isinstance(pretrained, str):
  603. load_dygraph_pretrain(model, pretrained)
  604. else:
  605. raise RuntimeError(
  606. "pretrained type is not available. Please use `string` or `boolean` type."
  607. )
  608. def ResNeSt50_fast_1s1x64d(pretrained=False, use_ssld=False, **kwargs):
  609. model = ResNeSt(
  610. layers=[3, 4, 6, 3],
  611. radix=1,
  612. groups=1,
  613. bottleneck_width=64,
  614. deep_stem=True,
  615. stem_width=32,
  616. avg_down=True,
  617. avd=True,
  618. avd_first=True,
  619. final_drop=0.0,
  620. **kwargs)
  621. _load_pretrained(
  622. pretrained,
  623. model,
  624. MODEL_URLS["ResNeSt50_fast_1s1x64d"],
  625. use_ssld=use_ssld)
  626. return model
  627. def ResNeSt50(pretrained=False, use_ssld=False, **kwargs):
  628. model = ResNeSt(
  629. layers=[3, 4, 6, 3],
  630. radix=2,
  631. groups=1,
  632. bottleneck_width=64,
  633. deep_stem=True,
  634. stem_width=32,
  635. avg_down=True,
  636. avd=True,
  637. avd_first=False,
  638. final_drop=0.0,
  639. **kwargs)
  640. _load_pretrained(
  641. pretrained, model, MODEL_URLS["ResNeSt50"], use_ssld=use_ssld)
  642. return model
  643. def ResNeSt101(pretrained=False, use_ssld=False, **kwargs):
  644. model = ResNeSt(
  645. layers=[3, 4, 23, 3],
  646. radix=2,
  647. groups=1,
  648. bottleneck_width=64,
  649. deep_stem=True,
  650. stem_width=64,
  651. avg_down=True,
  652. avd=True,
  653. avd_first=False,
  654. final_drop=0.0,
  655. **kwargs)
  656. _load_pretrained(
  657. pretrained, model, MODEL_URLS["ResNeSt101"], use_ssld=use_ssld)
  658. return model