resnet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import, division, print_function
  15. import numpy as np
  16. import paddle
  17. from paddle import ParamAttr
  18. import paddle.nn as nn
  19. from paddle.nn import Conv2D, BatchNorm, Linear
  20. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  21. from paddle.nn.initializer import Uniform
  22. import math
  23. from paddlex.ppcls.arch.backbone.base.theseus_layer import TheseusLayer
  24. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  25. MODEL_URLS = {
  26. "ResNet18":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
  28. "ResNet18_vd":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
  30. "ResNet34":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
  32. "ResNet34_vd":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
  34. "ResNet50":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
  36. "ResNet50_vd":
  37. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
  38. "ResNet101":
  39. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
  40. "ResNet101_vd":
  41. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
  42. "ResNet152":
  43. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
  44. "ResNet152_vd":
  45. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
  46. "ResNet200_vd":
  47. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
  48. }
  49. __all__ = MODEL_URLS.keys()
  50. '''
  51. ResNet config: dict.
  52. key: depth of ResNet.
  53. values: config's dict of specific model.
  54. keys:
  55. block_type: Two different blocks in ResNet, BasicBlock and BottleneckBlock are optional.
  56. block_depth: The number of blocks in different stages in ResNet.
  57. num_channels: The number of channels to enter the next stage.
  58. '''
  59. NET_CONFIG = {
  60. "18": {
  61. "block_type": "BasicBlock",
  62. "block_depth": [2, 2, 2, 2],
  63. "num_channels": [64, 64, 128, 256]
  64. },
  65. "34": {
  66. "block_type": "BasicBlock",
  67. "block_depth": [3, 4, 6, 3],
  68. "num_channels": [64, 64, 128, 256]
  69. },
  70. "50": {
  71. "block_type": "BottleneckBlock",
  72. "block_depth": [3, 4, 6, 3],
  73. "num_channels": [64, 256, 512, 1024]
  74. },
  75. "101": {
  76. "block_type": "BottleneckBlock",
  77. "block_depth": [3, 4, 23, 3],
  78. "num_channels": [64, 256, 512, 1024]
  79. },
  80. "152": {
  81. "block_type": "BottleneckBlock",
  82. "block_depth": [3, 8, 36, 3],
  83. "num_channels": [64, 256, 512, 1024]
  84. },
  85. "200": {
  86. "block_type": "BottleneckBlock",
  87. "block_depth": [3, 12, 48, 3],
  88. "num_channels": [64, 256, 512, 1024]
  89. },
  90. }
  91. class ConvBNLayer(TheseusLayer):
  92. def __init__(self,
  93. num_channels,
  94. num_filters,
  95. filter_size,
  96. stride=1,
  97. groups=1,
  98. is_vd_mode=False,
  99. act=None,
  100. lr_mult=1.0,
  101. data_format="NCHW"):
  102. super().__init__()
  103. self.is_vd_mode = is_vd_mode
  104. self.act = act
  105. self.avg_pool = AvgPool2D(
  106. kernel_size=2, stride=2, padding=0, ceil_mode=True)
  107. self.conv = Conv2D(
  108. in_channels=num_channels,
  109. out_channels=num_filters,
  110. kernel_size=filter_size,
  111. stride=stride,
  112. padding=(filter_size - 1) // 2,
  113. groups=groups,
  114. weight_attr=ParamAttr(learning_rate=lr_mult),
  115. bias_attr=False,
  116. data_format=data_format)
  117. self.bn = BatchNorm(
  118. num_filters,
  119. param_attr=ParamAttr(learning_rate=lr_mult),
  120. bias_attr=ParamAttr(learning_rate=lr_mult),
  121. data_layout=data_format)
  122. self.relu = nn.ReLU()
  123. def forward(self, x):
  124. if self.is_vd_mode:
  125. x = self.avg_pool(x)
  126. x = self.conv(x)
  127. x = self.bn(x)
  128. if self.act:
  129. x = self.relu(x)
  130. return x
  131. class BottleneckBlock(TheseusLayer):
  132. def __init__(self,
  133. num_channels,
  134. num_filters,
  135. stride,
  136. shortcut=True,
  137. if_first=False,
  138. lr_mult=1.0,
  139. data_format="NCHW"):
  140. super().__init__()
  141. self.conv0 = ConvBNLayer(
  142. num_channels=num_channels,
  143. num_filters=num_filters,
  144. filter_size=1,
  145. act="relu",
  146. lr_mult=lr_mult,
  147. data_format=data_format)
  148. self.conv1 = ConvBNLayer(
  149. num_channels=num_filters,
  150. num_filters=num_filters,
  151. filter_size=3,
  152. stride=stride,
  153. act="relu",
  154. lr_mult=lr_mult,
  155. data_format=data_format)
  156. self.conv2 = ConvBNLayer(
  157. num_channels=num_filters,
  158. num_filters=num_filters * 4,
  159. filter_size=1,
  160. act=None,
  161. lr_mult=lr_mult,
  162. data_format=data_format)
  163. if not shortcut:
  164. self.short = ConvBNLayer(
  165. num_channels=num_channels,
  166. num_filters=num_filters * 4,
  167. filter_size=1,
  168. stride=stride if if_first else 1,
  169. is_vd_mode=False if if_first else True,
  170. lr_mult=lr_mult,
  171. data_format=data_format)
  172. self.relu = nn.ReLU()
  173. self.shortcut = shortcut
  174. def forward(self, x):
  175. identity = x
  176. x = self.conv0(x)
  177. x = self.conv1(x)
  178. x = self.conv2(x)
  179. if self.shortcut:
  180. short = identity
  181. else:
  182. short = self.short(identity)
  183. x = paddle.add(x=x, y=short)
  184. x = self.relu(x)
  185. return x
  186. class BasicBlock(TheseusLayer):
  187. def __init__(self,
  188. num_channels,
  189. num_filters,
  190. stride,
  191. shortcut=True,
  192. if_first=False,
  193. lr_mult=1.0,
  194. data_format="NCHW"):
  195. super().__init__()
  196. self.stride = stride
  197. self.conv0 = ConvBNLayer(
  198. num_channels=num_channels,
  199. num_filters=num_filters,
  200. filter_size=3,
  201. stride=stride,
  202. act="relu",
  203. lr_mult=lr_mult,
  204. data_format=data_format)
  205. self.conv1 = ConvBNLayer(
  206. num_channels=num_filters,
  207. num_filters=num_filters,
  208. filter_size=3,
  209. act=None,
  210. lr_mult=lr_mult,
  211. data_format=data_format)
  212. if not shortcut:
  213. self.short = ConvBNLayer(
  214. num_channels=num_channels,
  215. num_filters=num_filters,
  216. filter_size=1,
  217. stride=stride if if_first else 1,
  218. is_vd_mode=False if if_first else True,
  219. lr_mult=lr_mult,
  220. data_format=data_format)
  221. self.shortcut = shortcut
  222. self.relu = nn.ReLU()
  223. def forward(self, x):
  224. identity = x
  225. x = self.conv0(x)
  226. x = self.conv1(x)
  227. if self.shortcut:
  228. short = identity
  229. else:
  230. short = self.short(identity)
  231. x = paddle.add(x=x, y=short)
  232. x = self.relu(x)
  233. return x
  234. class ResNet(TheseusLayer):
  235. """
  236. ResNet
  237. Args:
  238. config: dict. config of ResNet.
  239. version: str="vb". Different version of ResNet, version vd can perform better.
  240. class_num: int=1000. The number of classes.
  241. lr_mult_list: list. Control the learning rate of different stages.
  242. Returns:
  243. model: nn.Layer. Specific ResNet model depends on args.
  244. """
  245. def __init__(self,
  246. config,
  247. version="vb",
  248. class_num=1000,
  249. lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
  250. data_format="NCHW",
  251. input_image_channel=3,
  252. return_patterns=None):
  253. super().__init__()
  254. self.cfg = config
  255. self.lr_mult_list = lr_mult_list
  256. self.is_vd_mode = version == "vd"
  257. self.class_num = class_num
  258. self.num_filters = [64, 128, 256, 512]
  259. self.block_depth = self.cfg["block_depth"]
  260. self.block_type = self.cfg["block_type"]
  261. self.num_channels = self.cfg["num_channels"]
  262. self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
  263. assert isinstance(self.lr_mult_list, (
  264. list, tuple
  265. )), "lr_mult_list should be in (list, tuple) but got {}".format(
  266. type(self.lr_mult_list))
  267. assert len(self.lr_mult_list
  268. ) == 5, "lr_mult_list length should be 5 but got {}".format(
  269. len(self.lr_mult_list))
  270. self.stem_cfg = {
  271. #num_channels, num_filters, filter_size, stride
  272. "vb": [[input_image_channel, 64, 7, 2]],
  273. "vd":
  274. [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
  275. }
  276. self.stem = nn.Sequential(*[
  277. ConvBNLayer(
  278. num_channels=in_c,
  279. num_filters=out_c,
  280. filter_size=k,
  281. stride=s,
  282. act="relu",
  283. lr_mult=self.lr_mult_list[0],
  284. data_format=data_format)
  285. for in_c, out_c, k, s in self.stem_cfg[version]
  286. ])
  287. self.max_pool = MaxPool2D(
  288. kernel_size=3, stride=2, padding=1, data_format=data_format)
  289. block_list = []
  290. for block_idx in range(len(self.block_depth)):
  291. shortcut = False
  292. for i in range(self.block_depth[block_idx]):
  293. block_list.append(globals()[self.block_type](
  294. num_channels=self.num_channels[block_idx] if i == 0 else
  295. self.num_filters[block_idx] * self.channels_mult,
  296. num_filters=self.num_filters[block_idx],
  297. stride=2 if i == 0 and block_idx != 0 else 1,
  298. shortcut=shortcut,
  299. if_first=block_idx == i == 0 if version == "vd" else True,
  300. lr_mult=self.lr_mult_list[block_idx + 1],
  301. data_format=data_format))
  302. shortcut = True
  303. self.blocks = nn.Sequential(*block_list)
  304. self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
  305. self.flatten = nn.Flatten()
  306. self.avg_pool_channels = self.num_channels[-1] * 2
  307. stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
  308. self.fc = Linear(
  309. self.avg_pool_channels,
  310. self.class_num,
  311. weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
  312. self.data_format = data_format
  313. if return_patterns is not None:
  314. self.update_res(return_patterns)
  315. self.register_forward_post_hook(self._return_dict_hook)
  316. def forward(self, x):
  317. with paddle.static.amp.fp16_guard():
  318. if self.data_format == "NHWC":
  319. x = paddle.transpose(x, [0, 2, 3, 1])
  320. x.stop_gradient = True
  321. x = self.stem(x)
  322. x = self.max_pool(x)
  323. x = self.blocks(x)
  324. x = self.avg_pool(x)
  325. x = self.flatten(x)
  326. x = self.fc(x)
  327. return x
  328. def _load_pretrained(pretrained, model, model_url, use_ssld):
  329. if pretrained is False:
  330. pass
  331. elif pretrained is True:
  332. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  333. elif isinstance(pretrained, str):
  334. load_dygraph_pretrain(model, pretrained)
  335. else:
  336. raise RuntimeError(
  337. "pretrained type is not available. Please use `string` or `boolean` type."
  338. )
  339. def ResNet18(pretrained=False, use_ssld=False, **kwargs):
  340. """
  341. ResNet18
  342. Args:
  343. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  344. If str, means the path of the pretrained model.
  345. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  346. Returns:
  347. model: nn.Layer. Specific `ResNet18` model depends on args.
  348. """
  349. model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
  350. _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
  351. return model
  352. def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
  353. """
  354. ResNet18_vd
  355. Args:
  356. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  357. If str, means the path of the pretrained model.
  358. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  359. Returns:
  360. model: nn.Layer. Specific `ResNet18_vd` model depends on args.
  361. """
  362. model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
  363. _load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
  364. return model
  365. def ResNet34(pretrained=False, use_ssld=False, **kwargs):
  366. """
  367. ResNet34
  368. Args:
  369. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  370. If str, means the path of the pretrained model.
  371. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  372. Returns:
  373. model: nn.Layer. Specific `ResNet34` model depends on args.
  374. """
  375. model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
  376. _load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
  377. return model
  378. def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
  379. """
  380. ResNet34_vd
  381. Args:
  382. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  383. If str, means the path of the pretrained model.
  384. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  385. Returns:
  386. model: nn.Layer. Specific `ResNet34_vd` model depends on args.
  387. """
  388. model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
  389. _load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
  390. return model
  391. def ResNet50(pretrained=False, use_ssld=False, **kwargs):
  392. """
  393. ResNet50
  394. Args:
  395. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  396. If str, means the path of the pretrained model.
  397. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  398. Returns:
  399. model: nn.Layer. Specific `ResNet50` model depends on args.
  400. """
  401. model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
  402. _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
  403. return model
  404. def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
  405. """
  406. ResNet50_vd
  407. Args:
  408. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  409. If str, means the path of the pretrained model.
  410. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  411. Returns:
  412. model: nn.Layer. Specific `ResNet50_vd` model depends on args.
  413. """
  414. model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
  415. _load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
  416. return model
  417. def ResNet101(pretrained=False, use_ssld=False, **kwargs):
  418. """
  419. ResNet101
  420. Args:
  421. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  422. If str, means the path of the pretrained model.
  423. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  424. Returns:
  425. model: nn.Layer. Specific `ResNet101` model depends on args.
  426. """
  427. model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
  428. _load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
  429. return model
  430. def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
  431. """
  432. ResNet101_vd
  433. Args:
  434. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  435. If str, means the path of the pretrained model.
  436. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  437. Returns:
  438. model: nn.Layer. Specific `ResNet101_vd` model depends on args.
  439. """
  440. model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
  441. _load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
  442. return model
  443. def ResNet152(pretrained=False, use_ssld=False, **kwargs):
  444. """
  445. ResNet152
  446. Args:
  447. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  448. If str, means the path of the pretrained model.
  449. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  450. Returns:
  451. model: nn.Layer. Specific `ResNet152` model depends on args.
  452. """
  453. model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
  454. _load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
  455. return model
  456. def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
  457. """
  458. ResNet152_vd
  459. Args:
  460. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  461. If str, means the path of the pretrained model.
  462. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  463. Returns:
  464. model: nn.Layer. Specific `ResNet152_vd` model depends on args.
  465. """
  466. model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
  467. _load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
  468. return model
  469. def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
  470. """
  471. ResNet200_vd
  472. Args:
  473. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  474. If str, means the path of the pretrained model.
  475. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  476. Returns:
  477. model: nn.Layer. Specific `ResNet200_vd` model depends on args.
  478. """
  479. model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
  480. _load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
  481. return model