# dpn.py
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import sys
  19. import paddle
  20. from paddle import ParamAttr
  21. import paddle.nn as nn
  22. from paddle.nn import Conv2D, BatchNorm, Linear
  23. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  24. from paddle.nn.initializer import Uniform
  25. import math
  26. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  27. MODEL_URLS = {
  28. "DPN68":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DPN68_pretrained.pdparams",
  30. "DPN92":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DPN92_pretrained.pdparams",
  32. "DPN98":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DPN98_pretrained.pdparams",
  34. "DPN107":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DPN107_pretrained.pdparams",
  36. "DPN131":
  37. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DPN131_pretrained.pdparams",
  38. }
  39. __all__ = list(MODEL_URLS.keys())
  40. class ConvBNLayer(nn.Layer):
  41. def __init__(self,
  42. num_channels,
  43. num_filters,
  44. filter_size,
  45. stride=1,
  46. pad=0,
  47. groups=1,
  48. act="relu",
  49. name=None):
  50. super(ConvBNLayer, self).__init__()
  51. self._conv = Conv2D(
  52. in_channels=num_channels,
  53. out_channels=num_filters,
  54. kernel_size=filter_size,
  55. stride=stride,
  56. padding=pad,
  57. groups=groups,
  58. weight_attr=ParamAttr(name=name + "_weights"),
  59. bias_attr=False)
  60. self._batch_norm = BatchNorm(
  61. num_filters,
  62. act=act,
  63. param_attr=ParamAttr(name=name + '_bn_scale'),
  64. bias_attr=ParamAttr(name + '_bn_offset'),
  65. moving_mean_name=name + '_bn_mean',
  66. moving_variance_name=name + '_bn_variance')
  67. def forward(self, input):
  68. y = self._conv(input)
  69. y = self._batch_norm(y)
  70. return y
  71. class BNACConvLayer(nn.Layer):
  72. def __init__(self,
  73. num_channels,
  74. num_filters,
  75. filter_size,
  76. stride=1,
  77. pad=0,
  78. groups=1,
  79. act="relu",
  80. name=None):
  81. super(BNACConvLayer, self).__init__()
  82. self.num_channels = num_channels
  83. self._batch_norm = BatchNorm(
  84. num_channels,
  85. act=act,
  86. param_attr=ParamAttr(name=name + '_bn_scale'),
  87. bias_attr=ParamAttr(name + '_bn_offset'),
  88. moving_mean_name=name + '_bn_mean',
  89. moving_variance_name=name + '_bn_variance')
  90. self._conv = Conv2D(
  91. in_channels=num_channels,
  92. out_channels=num_filters,
  93. kernel_size=filter_size,
  94. stride=stride,
  95. padding=pad,
  96. groups=groups,
  97. weight_attr=ParamAttr(name=name + "_weights"),
  98. bias_attr=False)
  99. def forward(self, input):
  100. y = self._batch_norm(input)
  101. y = self._conv(y)
  102. return y
  103. class DualPathFactory(nn.Layer):
  104. def __init__(self,
  105. num_channels,
  106. num_1x1_a,
  107. num_3x3_b,
  108. num_1x1_c,
  109. inc,
  110. G,
  111. _type='normal',
  112. name=None):
  113. super(DualPathFactory, self).__init__()
  114. self.num_1x1_c = num_1x1_c
  115. self.inc = inc
  116. self.name = name
  117. kw = 3
  118. kh = 3
  119. pw = (kw - 1) // 2
  120. ph = (kh - 1) // 2
  121. # type
  122. if _type == 'proj':
  123. key_stride = 1
  124. self.has_proj = True
  125. elif _type == 'down':
  126. key_stride = 2
  127. self.has_proj = True
  128. elif _type == 'normal':
  129. key_stride = 1
  130. self.has_proj = False
  131. else:
  132. print("not implemented now!!!")
  133. sys.exit(1)
  134. data_in_ch = sum(num_channels) if isinstance(num_channels,
  135. list) else num_channels
  136. if self.has_proj:
  137. self.c1x1_w_func = BNACConvLayer(
  138. num_channels=data_in_ch,
  139. num_filters=num_1x1_c + 2 * inc,
  140. filter_size=(1, 1),
  141. pad=(0, 0),
  142. stride=(key_stride, key_stride),
  143. name=name + "_match")
  144. self.c1x1_a_func = BNACConvLayer(
  145. num_channels=data_in_ch,
  146. num_filters=num_1x1_a,
  147. filter_size=(1, 1),
  148. pad=(0, 0),
  149. name=name + "_conv1")
  150. self.c3x3_b_func = BNACConvLayer(
  151. num_channels=num_1x1_a,
  152. num_filters=num_3x3_b,
  153. filter_size=(kw, kh),
  154. pad=(pw, ph),
  155. stride=(key_stride, key_stride),
  156. groups=G,
  157. name=name + "_conv2")
  158. self.c1x1_c_func = BNACConvLayer(
  159. num_channels=num_3x3_b,
  160. num_filters=num_1x1_c + inc,
  161. filter_size=(1, 1),
  162. pad=(0, 0),
  163. name=name + "_conv3")
  164. def forward(self, input):
  165. # PROJ
  166. if isinstance(input, list):
  167. data_in = paddle.concat([input[0], input[1]], axis=1)
  168. else:
  169. data_in = input
  170. if self.has_proj:
  171. c1x1_w = self.c1x1_w_func(data_in)
  172. data_o1, data_o2 = paddle.split(
  173. c1x1_w, num_or_sections=[self.num_1x1_c, 2 * self.inc], axis=1)
  174. else:
  175. data_o1 = input[0]
  176. data_o2 = input[1]
  177. c1x1_a = self.c1x1_a_func(data_in)
  178. c3x3_b = self.c3x3_b_func(c1x1_a)
  179. c1x1_c = self.c1x1_c_func(c3x3_b)
  180. c1x1_c1, c1x1_c2 = paddle.split(
  181. c1x1_c, num_or_sections=[self.num_1x1_c, self.inc], axis=1)
  182. # OUTPUTS
  183. summ = paddle.add(x=data_o1, y=c1x1_c1)
  184. dense = paddle.concat([data_o2, c1x1_c2], axis=1)
  185. # tensor, channels
  186. return [summ, dense]
  187. class DPN(nn.Layer):
  188. def __init__(self, layers=68, class_num=1000):
  189. super(DPN, self).__init__()
  190. self._class_num = class_num
  191. args = self.get_net_args(layers)
  192. bws = args['bw']
  193. inc_sec = args['inc_sec']
  194. rs = args['r']
  195. k_r = args['k_r']
  196. k_sec = args['k_sec']
  197. G = args['G']
  198. init_num_filter = args['init_num_filter']
  199. init_filter_size = args['init_filter_size']
  200. init_padding = args['init_padding']
  201. self.k_sec = k_sec
  202. self.conv1_x_1_func = ConvBNLayer(
  203. num_channels=3,
  204. num_filters=init_num_filter,
  205. filter_size=init_filter_size,
  206. stride=2,
  207. pad=init_padding,
  208. act='relu',
  209. name="conv1")
  210. self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
  211. num_channel_dpn = init_num_filter
  212. self.dpn_func_list = []
  213. #conv2 - conv5
  214. match_list, num = [], 0
  215. for gc in range(4):
  216. bw = bws[gc]
  217. inc = inc_sec[gc]
  218. R = (k_r * bw) // rs[gc]
  219. if gc == 0:
  220. _type1 = 'proj'
  221. _type2 = 'normal'
  222. match = 1
  223. else:
  224. _type1 = 'down'
  225. _type2 = 'normal'
  226. match = match + k_sec[gc - 1]
  227. match_list.append(match)
  228. self.dpn_func_list.append(
  229. self.add_sublayer(
  230. "dpn{}".format(match),
  231. DualPathFactory(
  232. num_channels=num_channel_dpn,
  233. num_1x1_a=R,
  234. num_3x3_b=R,
  235. num_1x1_c=bw,
  236. inc=inc,
  237. G=G,
  238. _type=_type1,
  239. name="dpn" + str(match))))
  240. num_channel_dpn = [bw, 3 * inc]
  241. for i_ly in range(2, k_sec[gc] + 1):
  242. num += 1
  243. if num in match_list:
  244. num += 1
  245. self.dpn_func_list.append(
  246. self.add_sublayer(
  247. "dpn{}".format(num),
  248. DualPathFactory(
  249. num_channels=num_channel_dpn,
  250. num_1x1_a=R,
  251. num_3x3_b=R,
  252. num_1x1_c=bw,
  253. inc=inc,
  254. G=G,
  255. _type=_type2,
  256. name="dpn" + str(num))))
  257. num_channel_dpn = [
  258. num_channel_dpn[0], num_channel_dpn[1] + inc
  259. ]
  260. out_channel = sum(num_channel_dpn)
  261. self.conv5_x_x_bn = BatchNorm(
  262. num_channels=sum(num_channel_dpn),
  263. act="relu",
  264. param_attr=ParamAttr(name='final_concat_bn_scale'),
  265. bias_attr=ParamAttr('final_concat_bn_offset'),
  266. moving_mean_name='final_concat_bn_mean',
  267. moving_variance_name='final_concat_bn_variance')
  268. self.pool2d_avg = AdaptiveAvgPool2D(1)
  269. stdv = 0.01
  270. self.out = Linear(
  271. out_channel,
  272. class_num,
  273. weight_attr=ParamAttr(
  274. initializer=Uniform(-stdv, stdv), name="fc_weights"),
  275. bias_attr=ParamAttr(name="fc_offset"))
  276. def forward(self, input):
  277. conv1_x_1 = self.conv1_x_1_func(input)
  278. convX_x_x = self.pool2d_max(conv1_x_1)
  279. dpn_idx = 0
  280. for gc in range(4):
  281. convX_x_x = self.dpn_func_list[dpn_idx](convX_x_x)
  282. dpn_idx += 1
  283. for i_ly in range(2, self.k_sec[gc] + 1):
  284. convX_x_x = self.dpn_func_list[dpn_idx](convX_x_x)
  285. dpn_idx += 1
  286. conv5_x_x = paddle.concat(convX_x_x, axis=1)
  287. conv5_x_x = self.conv5_x_x_bn(conv5_x_x)
  288. y = self.pool2d_avg(conv5_x_x)
  289. y = paddle.flatten(y, start_axis=1, stop_axis=-1)
  290. y = self.out(y)
  291. return y
  292. def get_net_args(self, layers):
  293. if layers == 68:
  294. k_r = 128
  295. G = 32
  296. k_sec = [3, 4, 12, 3]
  297. inc_sec = [16, 32, 32, 64]
  298. bw = [64, 128, 256, 512]
  299. r = [64, 64, 64, 64]
  300. init_num_filter = 10
  301. init_filter_size = 3
  302. init_padding = 1
  303. elif layers == 92:
  304. k_r = 96
  305. G = 32
  306. k_sec = [3, 4, 20, 3]
  307. inc_sec = [16, 32, 24, 128]
  308. bw = [256, 512, 1024, 2048]
  309. r = [256, 256, 256, 256]
  310. init_num_filter = 64
  311. init_filter_size = 7
  312. init_padding = 3
  313. elif layers == 98:
  314. k_r = 160
  315. G = 40
  316. k_sec = [3, 6, 20, 3]
  317. inc_sec = [16, 32, 32, 128]
  318. bw = [256, 512, 1024, 2048]
  319. r = [256, 256, 256, 256]
  320. init_num_filter = 96
  321. init_filter_size = 7
  322. init_padding = 3
  323. elif layers == 107:
  324. k_r = 200
  325. G = 50
  326. k_sec = [4, 8, 20, 3]
  327. inc_sec = [20, 64, 64, 128]
  328. bw = [256, 512, 1024, 2048]
  329. r = [256, 256, 256, 256]
  330. init_num_filter = 128
  331. init_filter_size = 7
  332. init_padding = 3
  333. elif layers == 131:
  334. k_r = 160
  335. G = 40
  336. k_sec = [4, 8, 28, 3]
  337. inc_sec = [16, 32, 32, 128]
  338. bw = [256, 512, 1024, 2048]
  339. r = [256, 256, 256, 256]
  340. init_num_filter = 128
  341. init_filter_size = 7
  342. init_padding = 3
  343. else:
  344. raise NotImplementedError
  345. net_arg = {
  346. 'k_r': k_r,
  347. 'G': G,
  348. 'k_sec': k_sec,
  349. 'inc_sec': inc_sec,
  350. 'bw': bw,
  351. 'r': r
  352. }
  353. net_arg['init_num_filter'] = init_num_filter
  354. net_arg['init_filter_size'] = init_filter_size
  355. net_arg['init_padding'] = init_padding
  356. return net_arg
  357. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  358. if pretrained is False:
  359. pass
  360. elif pretrained is True:
  361. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  362. elif isinstance(pretrained, str):
  363. load_dygraph_pretrain(model, pretrained)
  364. else:
  365. raise RuntimeError(
  366. "pretrained type is not available. Please use `string` or `boolean` type."
  367. )
  368. def DPN68(pretrained=False, use_ssld=False, **kwargs):
  369. model = DPN(layers=68, **kwargs)
  370. _load_pretrained(pretrained, model, MODEL_URLS["DPN68"])
  371. return model
  372. def DPN92(pretrained=False, use_ssld=False, **kwargs):
  373. model = DPN(layers=92, **kwargs)
  374. _load_pretrained(pretrained, model, MODEL_URLS["DPN92"])
  375. return model
  376. def DPN98(pretrained=False, use_ssld=False, **kwargs):
  377. model = DPN(layers=98, **kwargs)
  378. _load_pretrained(pretrained, model, MODEL_URLS["DPN98"])
  379. return model
  380. def DPN107(pretrained=False, use_ssld=False, **kwargs):
  381. model = DPN(layers=107, **kwargs)
  382. _load_pretrained(pretrained, model, MODEL_URLS["DPN107"])
  383. return model
  384. def DPN131(pretrained=False, use_ssld=False, **kwargs):
  385. model = DPN(layers=131, **kwargs)
  386. _load_pretrained(pretrained, model, MODEL_URLS["DPN131"])
  387. return model