levit.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/facebookresearch/LeViT
  15. import itertools
  16. import math
  17. import warnings
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddle.nn.initializer import TruncatedNormal, Constant
  22. from paddle.regularizer import L2Decay
  23. from .vision_transformer import trunc_normal_, zeros_, ones_, Identity
  24. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  25. MODEL_URLS = {
  26. "LeViT_128S":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/LeViT_128S_pretrained.pdparams",
  28. "LeViT_128":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/LeViT_128_pretrained.pdparams",
  30. "LeViT_192":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/LeViT_192_pretrained.pdparams",
  32. "LeViT_256":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/LeViT_256_pretrained.pdparams",
  34. "LeViT_384":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/LeViT_384_pretrained.pdparams",
  36. }
  37. __all__ = list(MODEL_URLS.keys())
  38. def cal_attention_biases(attention_biases, attention_bias_idxs):
  39. gather_list = []
  40. attention_bias_t = paddle.transpose(attention_biases, (1, 0))
  41. nums = attention_bias_idxs.shape[0]
  42. for idx in range(nums):
  43. gather = paddle.gather(attention_bias_t, attention_bias_idxs[idx])
  44. gather_list.append(gather)
  45. shape0, shape1 = attention_bias_idxs.shape
  46. gather = paddle.concat(gather_list)
  47. return paddle.transpose(gather, (1, 0)).reshape((0, shape0, shape1))
  48. class Conv2d_BN(nn.Sequential):
  49. def __init__(self,
  50. a,
  51. b,
  52. ks=1,
  53. stride=1,
  54. pad=0,
  55. dilation=1,
  56. groups=1,
  57. bn_weight_init=1,
  58. resolution=-10000):
  59. super().__init__()
  60. self.add_sublayer(
  61. 'c',
  62. nn.Conv2D(
  63. a, b, ks, stride, pad, dilation, groups, bias_attr=False))
  64. bn = nn.BatchNorm2D(b)
  65. ones_(bn.weight)
  66. zeros_(bn.bias)
  67. self.add_sublayer('bn', bn)
  68. class Linear_BN(nn.Sequential):
  69. def __init__(self, a, b, bn_weight_init=1):
  70. super().__init__()
  71. self.add_sublayer('c', nn.Linear(a, b, bias_attr=False))
  72. bn = nn.BatchNorm1D(b)
  73. if bn_weight_init == 0:
  74. zeros_(bn.weight)
  75. else:
  76. ones_(bn.weight)
  77. zeros_(bn.bias)
  78. self.add_sublayer('bn', bn)
  79. def forward(self, x):
  80. l, bn = self._sub_layers.values()
  81. x = l(x)
  82. return paddle.reshape(bn(x.flatten(0, 1)), x.shape)
  83. class BN_Linear(nn.Sequential):
  84. def __init__(self, a, b, bias=True, std=0.02):
  85. super().__init__()
  86. self.add_sublayer('bn', nn.BatchNorm1D(a))
  87. l = nn.Linear(a, b, bias_attr=bias)
  88. trunc_normal_(l.weight)
  89. if bias:
  90. zeros_(l.bias)
  91. self.add_sublayer('l', l)
  92. def b16(n, activation, resolution=224):
  93. return nn.Sequential(
  94. Conv2d_BN(
  95. 3, n // 8, 3, 2, 1, resolution=resolution),
  96. activation(),
  97. Conv2d_BN(
  98. n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
  99. activation(),
  100. Conv2d_BN(
  101. n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
  102. activation(),
  103. Conv2d_BN(
  104. n // 2, n, 3, 2, 1, resolution=resolution // 8))
  105. class Residual(nn.Layer):
  106. def __init__(self, m, drop):
  107. super().__init__()
  108. self.m = m
  109. self.drop = drop
  110. def forward(self, x):
  111. if self.training and self.drop > 0:
  112. y = paddle.rand(
  113. shape=[x.shape[0], 1, 1]).__ge__(self.drop).astype("float32")
  114. y = y.divide(paddle.full_like(y, 1 - self.drop))
  115. return paddle.add(x, y)
  116. else:
  117. return paddle.add(x, self.m(x))
  118. class Attention(nn.Layer):
  119. def __init__(self,
  120. dim,
  121. key_dim,
  122. num_heads=8,
  123. attn_ratio=4,
  124. activation=None,
  125. resolution=14):
  126. super().__init__()
  127. self.num_heads = num_heads
  128. self.scale = key_dim**-0.5
  129. self.key_dim = key_dim
  130. self.nh_kd = nh_kd = key_dim * num_heads
  131. self.d = int(attn_ratio * key_dim)
  132. self.dh = int(attn_ratio * key_dim) * num_heads
  133. self.attn_ratio = attn_ratio
  134. self.h = self.dh + nh_kd * 2
  135. self.qkv = Linear_BN(dim, self.h)
  136. self.proj = nn.Sequential(
  137. activation(), Linear_BN(
  138. self.dh, dim, bn_weight_init=0))
  139. points = list(itertools.product(range(resolution), range(resolution)))
  140. N = len(points)
  141. attention_offsets = {}
  142. idxs = []
  143. for p1 in points:
  144. for p2 in points:
  145. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  146. if offset not in attention_offsets:
  147. attention_offsets[offset] = len(attention_offsets)
  148. idxs.append(attention_offsets[offset])
  149. self.attention_biases = self.create_parameter(
  150. shape=(num_heads, len(attention_offsets)),
  151. default_initializer=zeros_,
  152. attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
  153. tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
  154. self.register_buffer('attention_bias_idxs',
  155. paddle.reshape(tensor_idxs, [N, N]))
  156. @paddle.no_grad()
  157. def train(self, mode=True):
  158. if mode:
  159. super().train()
  160. else:
  161. super().eval()
  162. if mode and hasattr(self, 'ab'):
  163. del self.ab
  164. else:
  165. self.ab = cal_attention_biases(self.attention_biases,
  166. self.attention_bias_idxs)
  167. def forward(self, x):
  168. self.training = True
  169. B, N, C = x.shape
  170. qkv = self.qkv(x)
  171. qkv = paddle.reshape(qkv,
  172. [B, N, self.num_heads, self.h // self.num_heads])
  173. q, k, v = paddle.split(
  174. qkv, [self.key_dim, self.key_dim, self.d], axis=3)
  175. q = paddle.transpose(q, perm=[0, 2, 1, 3])
  176. k = paddle.transpose(k, perm=[0, 2, 1, 3])
  177. v = paddle.transpose(v, perm=[0, 2, 1, 3])
  178. k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2])
  179. if self.training:
  180. attention_biases = cal_attention_biases(self.attention_biases,
  181. self.attention_bias_idxs)
  182. else:
  183. attention_biases = self.ab
  184. attn = (paddle.matmul(q, k_transpose) * self.scale + attention_biases)
  185. attn = F.softmax(attn)
  186. x = paddle.transpose(paddle.matmul(attn, v), perm=[0, 2, 1, 3])
  187. x = paddle.reshape(x, [B, N, self.dh])
  188. x = self.proj(x)
  189. return x
  190. class Subsample(nn.Layer):
  191. def __init__(self, stride, resolution):
  192. super().__init__()
  193. self.stride = stride
  194. self.resolution = resolution
  195. def forward(self, x):
  196. B, N, C = x.shape
  197. x = paddle.reshape(x, [B, self.resolution, self.resolution, C])
  198. end1, end2 = x.shape[1], x.shape[2]
  199. x = x[:, 0:end1:self.stride, 0:end2:self.stride]
  200. x = paddle.reshape(x, [B, -1, C])
  201. return x
  202. class AttentionSubsample(nn.Layer):
  203. def __init__(self,
  204. in_dim,
  205. out_dim,
  206. key_dim,
  207. num_heads=8,
  208. attn_ratio=2,
  209. activation=None,
  210. stride=2,
  211. resolution=14,
  212. resolution_=7):
  213. super().__init__()
  214. self.num_heads = num_heads
  215. self.scale = key_dim**-0.5
  216. self.key_dim = key_dim
  217. self.nh_kd = nh_kd = key_dim * num_heads
  218. self.d = int(attn_ratio * key_dim)
  219. self.dh = int(attn_ratio * key_dim) * self.num_heads
  220. self.attn_ratio = attn_ratio
  221. self.resolution_ = resolution_
  222. self.resolution_2 = resolution_**2
  223. self.training = True
  224. h = self.dh + nh_kd
  225. self.kv = Linear_BN(in_dim, h)
  226. self.q = nn.Sequential(
  227. Subsample(stride, resolution), Linear_BN(in_dim, nh_kd))
  228. self.proj = nn.Sequential(activation(), Linear_BN(self.dh, out_dim))
  229. self.stride = stride
  230. self.resolution = resolution
  231. points = list(itertools.product(range(resolution), range(resolution)))
  232. points_ = list(
  233. itertools.product(range(resolution_), range(resolution_)))
  234. N = len(points)
  235. N_ = len(points_)
  236. attention_offsets = {}
  237. idxs = []
  238. i = 0
  239. j = 0
  240. for p1 in points_:
  241. i += 1
  242. for p2 in points:
  243. j += 1
  244. size = 1
  245. offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
  246. abs(p1[1] * stride - p2[1] + (size - 1) / 2))
  247. if offset not in attention_offsets:
  248. attention_offsets[offset] = len(attention_offsets)
  249. idxs.append(attention_offsets[offset])
  250. self.attention_biases = self.create_parameter(
  251. shape=(num_heads, len(attention_offsets)),
  252. default_initializer=zeros_,
  253. attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
  254. tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
  255. self.register_buffer('attention_bias_idxs',
  256. paddle.reshape(tensor_idxs_, [N_, N]))
  257. @paddle.no_grad()
  258. def train(self, mode=True):
  259. if mode:
  260. super().train()
  261. else:
  262. super().eval()
  263. if mode and hasattr(self, 'ab'):
  264. del self.ab
  265. else:
  266. self.ab = cal_attention_biases(self.attention_biases,
  267. self.attention_bias_idxs)
  268. def forward(self, x):
  269. self.training = True
  270. B, N, C = x.shape
  271. kv = self.kv(x)
  272. kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
  273. k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
  274. k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC
  275. v = paddle.transpose(v, perm=[0, 2, 1, 3])
  276. q = paddle.reshape(
  277. self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
  278. q = paddle.transpose(q, perm=[0, 2, 1, 3])
  279. if self.training:
  280. attention_biases = cal_attention_biases(self.attention_biases,
  281. self.attention_bias_idxs)
  282. else:
  283. attention_biases = self.ab
  284. attn = (paddle.matmul(
  285. q, paddle.transpose(
  286. k, perm=[0, 1, 3, 2]))) * self.scale + attention_biases
  287. attn = F.softmax(attn)
  288. x = paddle.reshape(
  289. paddle.transpose(
  290. paddle.matmul(attn, v), perm=[0, 2, 1, 3]), [B, -1, self.dh])
  291. x = self.proj(x)
  292. return x
  293. class LeViT(nn.Layer):
  294. """ Vision Transformer with support for patch or hybrid CNN input stage
  295. """
  296. def __init__(self,
  297. img_size=224,
  298. patch_size=16,
  299. in_chans=3,
  300. class_num=1000,
  301. embed_dim=[192],
  302. key_dim=[64],
  303. depth=[12],
  304. num_heads=[3],
  305. attn_ratio=[2],
  306. mlp_ratio=[2],
  307. hybrid_backbone=None,
  308. down_ops=[],
  309. attention_activation=nn.Hardswish,
  310. mlp_activation=nn.Hardswish,
  311. distillation=True,
  312. drop_path=0):
  313. super().__init__()
  314. self.class_num = class_num
  315. self.num_features = embed_dim[-1]
  316. self.embed_dim = embed_dim
  317. self.distillation = distillation
  318. self.patch_embed = hybrid_backbone
  319. self.blocks = []
  320. down_ops.append([''])
  321. resolution = img_size // patch_size
  322. for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
  323. zip(embed_dim, key_dim, depth, num_heads, attn_ratio,
  324. mlp_ratio, down_ops)):
  325. for _ in range(dpth):
  326. self.blocks.append(
  327. Residual(
  328. Attention(
  329. ed,
  330. kd,
  331. nh,
  332. attn_ratio=ar,
  333. activation=attention_activation,
  334. resolution=resolution, ),
  335. drop_path))
  336. if mr > 0:
  337. h = int(ed * mr)
  338. self.blocks.append(
  339. Residual(
  340. nn.Sequential(
  341. Linear_BN(ed, h),
  342. mlp_activation(),
  343. Linear_BN(
  344. h, ed, bn_weight_init=0), ),
  345. drop_path))
  346. if do[0] == 'Subsample':
  347. #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
  348. resolution_ = (resolution - 1) // do[5] + 1
  349. self.blocks.append(
  350. AttentionSubsample(
  351. *embed_dim[i:i + 2],
  352. key_dim=do[1],
  353. num_heads=do[2],
  354. attn_ratio=do[3],
  355. activation=attention_activation,
  356. stride=do[5],
  357. resolution=resolution,
  358. resolution_=resolution_))
  359. resolution = resolution_
  360. if do[4] > 0: # mlp_ratio
  361. h = int(embed_dim[i + 1] * do[4])
  362. self.blocks.append(
  363. Residual(
  364. nn.Sequential(
  365. Linear_BN(embed_dim[i + 1], h),
  366. mlp_activation(),
  367. Linear_BN(
  368. h, embed_dim[i + 1], bn_weight_init=0), ),
  369. drop_path))
  370. self.blocks = nn.Sequential(*self.blocks)
  371. # Classifier head
  372. self.head = BN_Linear(embed_dim[-1],
  373. class_num) if class_num > 0 else Identity()
  374. if distillation:
  375. self.head_dist = BN_Linear(
  376. embed_dim[-1], class_num) if class_num > 0 else Identity()
  377. def forward(self, x):
  378. x = self.patch_embed(x)
  379. x = x.flatten(2)
  380. x = paddle.transpose(x, perm=[0, 2, 1])
  381. x = self.blocks(x)
  382. x = x.mean(1)
  383. x = paddle.reshape(x, [-1, self.embed_dim[-1]])
  384. if self.distillation:
  385. x = self.head(x), self.head_dist(x)
  386. if not self.training:
  387. x = (x[0] + x[1]) / 2
  388. else:
  389. x = self.head(x)
  390. return x
  391. def model_factory(C, D, X, N, drop_path, class_num, distillation):
  392. embed_dim = [int(x) for x in C.split('_')]
  393. num_heads = [int(x) for x in N.split('_')]
  394. depth = [int(x) for x in X.split('_')]
  395. act = nn.Hardswish
  396. model = LeViT(
  397. patch_size=16,
  398. embed_dim=embed_dim,
  399. num_heads=num_heads,
  400. key_dim=[D] * 3,
  401. depth=depth,
  402. attn_ratio=[2, 2, 2],
  403. mlp_ratio=[2, 2, 2],
  404. down_ops=[
  405. #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
  406. ['Subsample', D, embed_dim[0] // D, 4, 2, 2],
  407. ['Subsample', D, embed_dim[1] // D, 4, 2, 2],
  408. ],
  409. attention_activation=act,
  410. mlp_activation=act,
  411. hybrid_backbone=b16(embed_dim[0], activation=act),
  412. class_num=class_num,
  413. drop_path=drop_path,
  414. distillation=distillation)
  415. return model
  416. specification = {
  417. 'LeViT_128S': {
  418. 'C': '128_256_384',
  419. 'D': 16,
  420. 'N': '4_6_8',
  421. 'X': '2_3_4',
  422. 'drop_path': 0
  423. },
  424. 'LeViT_128': {
  425. 'C': '128_256_384',
  426. 'D': 16,
  427. 'N': '4_8_12',
  428. 'X': '4_4_4',
  429. 'drop_path': 0
  430. },
  431. 'LeViT_192': {
  432. 'C': '192_288_384',
  433. 'D': 32,
  434. 'N': '3_5_6',
  435. 'X': '4_4_4',
  436. 'drop_path': 0
  437. },
  438. 'LeViT_256': {
  439. 'C': '256_384_512',
  440. 'D': 32,
  441. 'N': '4_6_8',
  442. 'X': '4_4_4',
  443. 'drop_path': 0
  444. },
  445. 'LeViT_384': {
  446. 'C': '384_512_768',
  447. 'D': 32,
  448. 'N': '6_9_12',
  449. 'X': '4_4_4',
  450. 'drop_path': 0.1
  451. },
  452. }
  453. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  454. if pretrained is False:
  455. pass
  456. elif pretrained is True:
  457. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  458. elif isinstance(pretrained, str):
  459. load_dygraph_pretrain(model, pretrained)
  460. else:
  461. raise RuntimeError(
  462. "pretrained type is not available. Please use `string` or `boolean` type."
  463. )
  464. def LeViT_128S(pretrained=False,
  465. use_ssld=False,
  466. class_num=1000,
  467. distillation=False,
  468. **kwargs):
  469. model = model_factory(
  470. **specification['LeViT_128S'],
  471. class_num=class_num,
  472. distillation=distillation)
  473. _load_pretrained(
  474. pretrained, model, MODEL_URLS["LeViT_128S"], use_ssld=use_ssld)
  475. return model
  476. def LeViT_128(pretrained=False,
  477. use_ssld=False,
  478. class_num=1000,
  479. distillation=False,
  480. **kwargs):
  481. model = model_factory(
  482. **specification['LeViT_128'],
  483. class_num=class_num,
  484. distillation=distillation)
  485. _load_pretrained(
  486. pretrained, model, MODEL_URLS["LeViT_128"], use_ssld=use_ssld)
  487. return model
  488. def LeViT_192(pretrained=False,
  489. use_ssld=False,
  490. class_num=1000,
  491. distillation=False,
  492. **kwargs):
  493. model = model_factory(
  494. **specification['LeViT_192'],
  495. class_num=class_num,
  496. distillation=distillation)
  497. _load_pretrained(
  498. pretrained, model, MODEL_URLS["LeViT_192"], use_ssld=use_ssld)
  499. return model
  500. def LeViT_256(pretrained=False,
  501. use_ssld=False,
  502. class_num=1000,
  503. distillation=False,
  504. **kwargs):
  505. model = model_factory(
  506. **specification['LeViT_256'],
  507. class_num=class_num,
  508. distillation=distillation)
  509. _load_pretrained(
  510. pretrained, model, MODEL_URLS["LeViT_256"], use_ssld=use_ssld)
  511. return model
  512. def LeViT_384(pretrained=False,
  513. use_ssld=False,
  514. class_num=1000,
  515. distillation=False,
  516. **kwargs):
  517. model = model_factory(
  518. **specification['LeViT_384'],
  519. class_num=class_num,
  520. distillation=distillation)
  521. _load_pretrained(
  522. pretrained, model, MODEL_URLS["LeViT_384"], use_ssld=use_ssld)
  523. return model