regnet.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/facebookresearch/pycls
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import numpy as np
  19. import paddle
  20. from paddle import ParamAttr
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  24. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  25. from paddle.nn.initializer import Uniform
  26. import math
  27. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  28. MODEL_URLS = {
  29. "RegNetX_200MF":
  30. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_200MF_pretrained.pdparams",
  31. "RegNetX_4GF":
  32. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_4GF_pretrained.pdparams",
  33. "RegNetX_32GF":
  34. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetX_32GF_pretrained.pdparams",
  35. "RegNetY_200MF":
  36. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetY_200MF_pretrained.pdparams",
  37. "RegNetY_4GF":
  38. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetY_4GF_pretrained.pdparams",
  39. "RegNetY_32GF":
  40. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RegNetY_32GF_pretrained.pdparams",
  41. }
  42. __all__ = list(MODEL_URLS.keys())
  43. def quantize_float(f, q):
  44. """Converts a float to closest non-zero int divisible by q."""
  45. return int(round(f / q) * q)
  46. def adjust_ws_gs_comp(ws, bms, gs):
  47. """Adjusts the compatibility of widths and groups."""
  48. ws_bot = [int(w * b) for w, b in zip(ws, bms)]
  49. gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
  50. ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
  51. ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
  52. return ws, gs
  53. def get_stages_from_blocks(ws, rs):
  54. """Gets ws/ds of network at each stage from per block values."""
  55. ts = [
  56. w != wp or r != rp
  57. for w, wp, r, rp in zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
  58. ]
  59. s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
  60. s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
  61. return s_ws, s_ds
  62. def generate_regnet(w_a, w_0, w_m, d, q=8):
  63. """Generates per block ws from RegNet parameters."""
  64. assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
  65. ws_cont = np.arange(d) * w_a + w_0
  66. ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
  67. ws = w_0 * np.power(w_m, ks)
  68. ws = np.round(np.divide(ws, q)) * q
  69. num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
  70. ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
  71. return ws, num_stages, max_stage, ws_cont
  72. class ConvBNLayer(nn.Layer):
  73. def __init__(self,
  74. num_channels,
  75. num_filters,
  76. filter_size,
  77. stride=1,
  78. groups=1,
  79. padding=0,
  80. act=None,
  81. name=None):
  82. super(ConvBNLayer, self).__init__()
  83. self._conv = Conv2D(
  84. in_channels=num_channels,
  85. out_channels=num_filters,
  86. kernel_size=filter_size,
  87. stride=stride,
  88. padding=padding,
  89. groups=groups,
  90. weight_attr=ParamAttr(name=name + ".conv2d.output.1.w_0"),
  91. bias_attr=ParamAttr(name=name + ".conv2d.output.1.b_0"))
  92. bn_name = name + "_bn"
  93. self._batch_norm = BatchNorm(
  94. num_filters,
  95. act=act,
  96. param_attr=ParamAttr(name=bn_name + ".output.1.w_0"),
  97. bias_attr=ParamAttr(bn_name + ".output.1.b_0"),
  98. moving_mean_name=bn_name + "_mean",
  99. moving_variance_name=bn_name + "_variance")
  100. def forward(self, inputs):
  101. y = self._conv(inputs)
  102. y = self._batch_norm(y)
  103. return y
  104. class BottleneckBlock(nn.Layer):
  105. def __init__(self,
  106. num_channels,
  107. num_filters,
  108. stride,
  109. bm,
  110. gw,
  111. se_on,
  112. se_r,
  113. shortcut=True,
  114. name=None):
  115. super(BottleneckBlock, self).__init__()
  116. # Compute the bottleneck width
  117. w_b = int(round(num_filters * bm))
  118. # Compute the number of groups
  119. num_gs = w_b // gw
  120. self.se_on = se_on
  121. self.conv0 = ConvBNLayer(
  122. num_channels=num_channels,
  123. num_filters=w_b,
  124. filter_size=1,
  125. padding=0,
  126. act="relu",
  127. name=name + "_branch2a")
  128. self.conv1 = ConvBNLayer(
  129. num_channels=w_b,
  130. num_filters=w_b,
  131. filter_size=3,
  132. stride=stride,
  133. padding=1,
  134. groups=num_gs,
  135. act="relu",
  136. name=name + "_branch2b")
  137. if se_on:
  138. w_se = int(round(num_channels * se_r))
  139. self.se_block = SELayer(
  140. num_channels=w_b,
  141. num_filters=w_b,
  142. reduction_ratio=w_se,
  143. name=name + "_branch2se")
  144. self.conv2 = ConvBNLayer(
  145. num_channels=w_b,
  146. num_filters=num_filters,
  147. filter_size=1,
  148. act=None,
  149. name=name + "_branch2c")
  150. if not shortcut:
  151. self.short = ConvBNLayer(
  152. num_channels=num_channels,
  153. num_filters=num_filters,
  154. filter_size=1,
  155. stride=stride,
  156. name=name + "_branch1")
  157. self.shortcut = shortcut
  158. def forward(self, inputs):
  159. y = self.conv0(inputs)
  160. conv1 = self.conv1(y)
  161. if self.se_on:
  162. conv1 = self.se_block(conv1)
  163. conv2 = self.conv2(conv1)
  164. if self.shortcut:
  165. short = inputs
  166. else:
  167. short = self.short(inputs)
  168. y = paddle.add(x=short, y=conv2)
  169. y = F.relu(y)
  170. return y
  171. class SELayer(nn.Layer):
  172. def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
  173. super(SELayer, self).__init__()
  174. self.pool2d_gap = AdaptiveAvgPool2D(1)
  175. self._num_channels = num_channels
  176. med_ch = int(num_channels / reduction_ratio)
  177. stdv = 1.0 / math.sqrt(num_channels * 1.0)
  178. self.squeeze = Linear(
  179. num_channels,
  180. med_ch,
  181. weight_attr=ParamAttr(
  182. initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
  183. bias_attr=ParamAttr(name=name + "_sqz_offset"))
  184. stdv = 1.0 / math.sqrt(med_ch * 1.0)
  185. self.excitation = Linear(
  186. med_ch,
  187. num_filters,
  188. weight_attr=ParamAttr(
  189. initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
  190. bias_attr=ParamAttr(name=name + "_exc_offset"))
  191. def forward(self, input):
  192. pool = self.pool2d_gap(input)
  193. pool = paddle.reshape(pool, shape=[-1, self._num_channels])
  194. squeeze = self.squeeze(pool)
  195. squeeze = F.relu(squeeze)
  196. excitation = self.excitation(squeeze)
  197. excitation = F.sigmoid(excitation)
  198. excitation = paddle.reshape(
  199. excitation, shape=[-1, self._num_channels, 1, 1])
  200. out = input * excitation
  201. return out
  202. class RegNet(nn.Layer):
  203. def __init__(self,
  204. w_a,
  205. w_0,
  206. w_m,
  207. d,
  208. group_w,
  209. bot_mul,
  210. q=8,
  211. se_on=False,
  212. class_num=1000):
  213. super(RegNet, self).__init__()
  214. # Generate RegNet ws per block
  215. b_ws, num_s, max_s, ws_cont = generate_regnet(w_a, w_0, w_m, d, q)
  216. # Convert to per stage format
  217. ws, ds = get_stages_from_blocks(b_ws, b_ws)
  218. # Generate group widths and bot muls
  219. gws = [group_w for _ in range(num_s)]
  220. bms = [bot_mul for _ in range(num_s)]
  221. # Adjust the compatibility of ws and gws
  222. ws, gws = adjust_ws_gs_comp(ws, bms, gws)
  223. # Use the same stride for each stage
  224. ss = [2 for _ in range(num_s)]
  225. # Use SE for RegNetY
  226. se_r = 0.25
  227. # Construct the model
  228. # Group params by stage
  229. stage_params = list(zip(ds, ws, ss, bms, gws))
  230. # Construct the stem
  231. stem_type = "simple_stem_in"
  232. stem_w = 32
  233. block_type = "res_bottleneck_block"
  234. self.conv = ConvBNLayer(
  235. num_channels=3,
  236. num_filters=stem_w,
  237. filter_size=3,
  238. stride=2,
  239. padding=1,
  240. act="relu",
  241. name="stem_conv")
  242. self.block_list = []
  243. for block, (d, w_out, stride, bm, gw) in enumerate(stage_params):
  244. shortcut = False
  245. for i in range(d):
  246. num_channels = stem_w if block == i == 0 else in_channels
  247. # Stride apply to the first block of the stage
  248. b_stride = stride if i == 0 else 1
  249. conv_name = "s" + str(block + 1) + "_b" + str(i +
  250. 1) # chr(97 + i)
  251. bottleneck_block = self.add_sublayer(
  252. conv_name,
  253. BottleneckBlock(
  254. num_channels=num_channels,
  255. num_filters=w_out,
  256. stride=b_stride,
  257. bm=bm,
  258. gw=gw,
  259. se_on=se_on,
  260. se_r=se_r,
  261. shortcut=shortcut,
  262. name=conv_name))
  263. in_channels = w_out
  264. self.block_list.append(bottleneck_block)
  265. shortcut = True
  266. self.pool2d_avg = AdaptiveAvgPool2D(1)
  267. self.pool2d_avg_channels = w_out
  268. stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
  269. self.out = Linear(
  270. self.pool2d_avg_channels,
  271. class_num,
  272. weight_attr=ParamAttr(
  273. initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
  274. bias_attr=ParamAttr(name="fc_0.b_0"))
  275. def forward(self, inputs):
  276. y = self.conv(inputs)
  277. for block in self.block_list:
  278. y = block(y)
  279. y = self.pool2d_avg(y)
  280. y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
  281. y = self.out(y)
  282. return y
  283. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  284. if pretrained is False:
  285. pass
  286. elif pretrained is True:
  287. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  288. elif isinstance(pretrained, str):
  289. load_dygraph_pretrain(model, pretrained)
  290. else:
  291. raise RuntimeError(
  292. "pretrained type is not available. Please use `string` or `boolean` type."
  293. )
  294. def RegNetX_200MF(pretrained=False, use_ssld=False, **kwargs):
  295. model = RegNet(
  296. w_a=36.44,
  297. w_0=24,
  298. w_m=2.49,
  299. d=13,
  300. group_w=8,
  301. bot_mul=1.0,
  302. q=8,
  303. **kwargs)
  304. _load_pretrained(
  305. pretrained, model, MODEL_URLS["RegNetX_200MF"], use_ssld=use_ssld)
  306. return model
  307. def RegNetX_4GF(pretrained=False, use_ssld=False, **kwargs):
  308. model = RegNet(
  309. w_a=38.65,
  310. w_0=96,
  311. w_m=2.43,
  312. d=23,
  313. group_w=40,
  314. bot_mul=1.0,
  315. q=8,
  316. **kwargs)
  317. _load_pretrained(
  318. pretrained, model, MODEL_URLS["RegNetX_4GF"], use_ssld=use_ssld)
  319. return model
  320. def RegNetX_32GF(pretrained=False, use_ssld=False, **kwargs):
  321. model = RegNet(
  322. w_a=69.86,
  323. w_0=320,
  324. w_m=2.0,
  325. d=23,
  326. group_w=168,
  327. bot_mul=1.0,
  328. q=8,
  329. **kwargs)
  330. _load_pretrained(
  331. pretrained, model, MODEL_URLS["RegNetX_32GF"], use_ssld=use_ssld)
  332. return model
  333. def RegNetY_200MF(pretrained=False, use_ssld=False, **kwargs):
  334. model = RegNet(
  335. w_a=36.44,
  336. w_0=24,
  337. w_m=2.49,
  338. d=13,
  339. group_w=8,
  340. bot_mul=1.0,
  341. q=8,
  342. se_on=True,
  343. **kwargs)
  344. _load_pretrained(
  345. pretrained, model, MODEL_URLS["RegNetX_32GF"], use_ssld=use_ssld)
  346. return model
  347. def RegNetY_4GF(pretrained=False, use_ssld=False, **kwargs):
  348. model = RegNet(
  349. w_a=31.41,
  350. w_0=96,
  351. w_m=2.24,
  352. d=22,
  353. group_w=64,
  354. bot_mul=1.0,
  355. q=8,
  356. se_on=True,
  357. **kwargs)
  358. _load_pretrained(
  359. pretrained, model, MODEL_URLS["RegNetX_32GF"], use_ssld=use_ssld)
  360. return model
  361. def RegNetY_32GF(pretrained=False, use_ssld=False, **kwargs):
  362. model = RegNet(
  363. w_a=115.89,
  364. w_0=232,
  365. w_m=2.53,
  366. d=20,
  367. group_w=232,
  368. bot_mul=1.0,
  369. q=8,
  370. se_on=True,
  371. **kwargs)
  372. _load_pretrained(
  373. pretrained, model, MODEL_URLS["RegNetX_32GF"], use_ssld=use_ssld)
  374. return model