vision_transformer.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  15. from collections.abc import Callable
  16. import numpy as np
  17. import paddle
  18. import paddle.nn as nn
  19. from paddle.nn.initializer import TruncatedNormal, Constant, Normal
  20. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  # Download URLs of the released pretrained weights, keyed by the name of the
  # factory function (defined further down) that builds the matching model.
  MODEL_URLS = {
      "ViT_small_patch16_224":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams",
      "ViT_base_patch16_224":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams",
      "ViT_base_patch16_384":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams",
      "ViT_base_patch32_384":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams",
      "ViT_large_patch16_224":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams",
      "ViT_large_patch16_384":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_384_pretrained.pdparams",
      "ViT_large_patch32_384":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch32_384_pretrained.pdparams",
      "ViT_huge_patch16_224":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_huge_patch16_224_pretrained.pdparams",
      "ViT_huge_patch32_384":
          "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_huge_patch32_384_pretrained.pdparams",
  }
  # Public API: exactly the model factory names that have pretrained URLs.
  __all__ = list(MODEL_URLS)

  # Shared parameter initializers used when building / initializing layers below.
  trunc_normal_ = TruncatedNormal(std=0.02)
  normal_ = Normal  # the initializer class itself, not an instance
  zeros_ = Constant(value=0.0)
  ones_ = Constant(value=1.0)
  def to_2tuple(x):
      """Return ``x`` as a 2-tuple.

      A scalar is duplicated (``3 -> (3, 3)``). A tuple or list of length 2 is
      returned as a tuple, so callers may pass ``(height, width)`` directly
      instead of having it wrapped into a nested tuple.

      Args:
          x: a scalar, or a tuple/list holding exactly two values.

      Returns:
          tuple: a pair of values.

      Raises:
          ValueError: if ``x`` is a tuple/list whose length is not 2.
      """
      if isinstance(x, (tuple, list)):
          if len(x) != 2:
              raise ValueError(
                  f"Expected a scalar or a pair, got {len(x)} values: {x!r}")
          return tuple(x)
      return (x, x)
  def drop_path(x, drop_prob=0., training=False):
      """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

      The original name is misleading as 'Drop Connect' is a different form of
      dropout in a separate paper. See discussion:
      https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

      Each sample along the first axis of ``x`` is either zeroed or kept and
      rescaled by ``1 / (1 - drop_prob)``, so the expected value is unchanged.

      Args:
          x: input tensor; its first axis is treated as the sample axis.
          drop_prob (float | None): probability of dropping a sample's path.
              ``None`` or ``0`` disables dropping.
          training (bool): dropping only happens when this is truthy.

      Returns:
          ``x`` itself when dropping is disabled; otherwise a new tensor.
      """
      # `not drop_prob` covers both 0 and None (DropPath's default); None would
      # otherwise crash in the `1 - drop_prob` arithmetic below.
      if not drop_prob or not training:
          return x
      if drop_prob >= 1.0:
          # keep_prob would be 0 and x / 0 * 0 produces NaN; every path is dropped.
          return paddle.zeros_like(x)
      keep_prob = paddle.to_tensor(1 - drop_prob)
      # One random draw per sample, broadcast over all remaining axes.
      shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
      # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
      random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
      random_tensor = paddle.floor(random_tensor)  # binarize
      output = x.divide(keep_prob) * random_tensor
      return output
  class DropPath(nn.Layer):
      """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

      Layer wrapper around ``drop_path``; dropping is only active while
      ``self.training`` is truthy.

      Args:
          drop_prob (float | None): probability of dropping a sample's path.
              ``None`` (the default) is stored as ``0.0``, i.e. a no-op.
      """

      def __init__(self, drop_prob=None):
          super(DropPath, self).__init__()
          # drop_path() does `1 - drop_prob`, so a raw None default would raise
          # a TypeError as soon as the layer runs in training mode.
          self.drop_prob = 0.0 if drop_prob is None else drop_prob

      def forward(self, x):
          return drop_path(x, self.drop_prob, self.training)
  class Identity(nn.Layer):
      """Pass-through layer whose forward() returns its argument unchanged.

      Used as a stand-in where a real sublayer is disabled (e.g. no
      stochastic depth, or no classifier head).
      """

      def __init__(self):
          super().__init__()

      def forward(self, input):
          # `input` shadows the builtin, but the parameter name is kept so
          # keyword callers are unaffected.
          return input
  class Mlp(nn.Layer):
      """Two-layer feed-forward network used inside each transformer Block.

      Pipeline: Linear -> activation -> Dropout -> Linear -> Dropout. The same
      Dropout layer (rate ``drop``) is applied after both linear stages.

      Args:
          in_features (int): input width.
          hidden_features (int | None): hidden width; falls back to
              ``in_features`` when falsy.
          out_features (int | None): output width; falls back to
              ``in_features`` when falsy.
          act_layer: activation class, instantiated with no arguments.
          drop (float): dropout rate.
      """

      def __init__(self,
                   in_features,
                   hidden_features=None,
                   out_features=None,
                   act_layer=nn.GELU,
                   drop=0.):
          super().__init__()
          width_hidden = hidden_features or in_features
          width_out = out_features or in_features
          self.fc1 = nn.Linear(in_features, width_hidden)
          self.act = act_layer()
          self.fc2 = nn.Linear(width_hidden, width_out)
          self.drop = nn.Dropout(drop)

      def forward(self, x):
          for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
              x = stage(x)
          return x
  class Attention(nn.Layer):
      """Multi-head self-attention over a token sequence.

      Args:
          dim (int): token width ``C``; must be divisible by ``num_heads``.
          num_heads (int): number of attention heads.
          qkv_bias (bool): passed as ``bias_attr`` of the fused q/k/v projection.
          qk_scale (float | None): multiplier applied to the q·k logits;
              ``None`` means ``(dim // num_heads) ** -0.5``.
          attn_drop (float): dropout rate on the attention weights.
          proj_drop (float): dropout rate after the output projection.

      Raises:
          ValueError: if ``dim`` is not divisible by ``num_heads``.
      """

      def __init__(self,
                   dim,
                   num_heads=8,
                   qkv_bias=False,
                   qk_scale=None,
                   attn_drop=0.,
                   proj_drop=0.):
          super().__init__()
          # forward() reshapes C into (num_heads, C // num_heads). With a
          # remainder that reshape either fails or, when element counts happen
          # to divide, silently mixes samples across the -1 batch axis.
          if dim % num_heads != 0:
              raise ValueError(
                  f"dim ({dim}) must be divisible by num_heads ({num_heads}).")
          self.num_heads = num_heads
          head_dim = dim // num_heads
          # Compare against None rather than `qk_scale or ...` so an explicit
          # falsy scale is honored instead of silently replaced.
          self.scale = head_dim**-0.5 if qk_scale is None else qk_scale
          self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
          self.attn_drop = nn.Dropout(attn_drop)
          self.proj = nn.Linear(dim, dim)
          self.proj_drop = nn.Dropout(proj_drop)

      def forward(self, x):
          # x is rank 3: (B, N, C); B stays -1 in the reshapes below.
          N, C = x.shape[1:]
          # (B, N, 3*C) -> (B, N, 3, heads, C/heads) -> (3, B, heads, N, C/heads)
          qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
                                     self.num_heads)).transpose((2, 0, 3, 1, 4))
          q, k, v = qkv[0], qkv[1], qkv[2]
          # Scaled dot-product scores: (B, heads, N, N)
          attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
          attn = nn.functional.softmax(attn, axis=-1)
          attn = self.attn_drop(attn)
          # (B, heads, N, C/heads) -> (B, N, heads, C/heads) -> (B, N, C)
          x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
          x = self.proj(x)
          x = self.proj_drop(x)
          return x
  class Block(nn.Layer):
      """Transformer encoder block with normalization before each sublayer.

      forward computes::

          x = x + drop_path(attn(norm1(x)))
          x = x + drop_path(mlp(norm2(x)))

      Args:
          dim (int): token width.
          num_heads (int): number of attention heads.
          mlp_ratio (float): MLP hidden width is ``int(dim * mlp_ratio)``.
          qkv_bias (bool): bias on the fused q/k/v projection.
          qk_scale (float | None): attention scale override.
          drop (float): dropout rate for the attention output projection and the MLP.
          attn_drop (float): dropout rate on the attention weights.
          drop_path (float): stochastic-depth rate; an Identity is used unless > 0.
          act_layer: activation class passed to Mlp.
          norm_layer (str | callable): either a string that is eval()'d to a
              layer class and called as ``cls(dim, epsilon=epsilon)``, or a
              callable invoked as ``norm_layer(dim)``.
          epsilon (float): only used for string-specified norm layers.

      Raises:
          TypeError: if ``norm_layer`` is neither a str nor callable.
      """

      def __init__(self,
                   dim,
                   num_heads,
                   mlp_ratio=4.,
                   qkv_bias=False,
                   qk_scale=None,
                   drop=0.,
                   attn_drop=0.,
                   drop_path=0.,
                   act_layer=nn.GELU,
                   norm_layer='nn.LayerNorm',
                   epsilon=1e-5):
          super().__init__()
          # Sublayers are created in the same order as before (norm1, attn,
          # drop_path, norm2, mlp) to keep parameter registration order stable.
          self.norm1 = self._make_norm(norm_layer, dim, epsilon)
          self.attn = Attention(
              dim,
              num_heads=num_heads,
              qkv_bias=qkv_bias,
              qk_scale=qk_scale,
              attn_drop=attn_drop,
              proj_drop=drop)
          # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
          self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
          self.norm2 = self._make_norm(norm_layer, dim, epsilon)
          mlp_hidden_dim = int(dim * mlp_ratio)
          self.mlp = Mlp(in_features=dim,
                         hidden_features=mlp_hidden_dim,
                         act_layer=act_layer,
                         drop=drop)

      @staticmethod
      def _make_norm(norm_layer, dim, epsilon):
          """Build one normalization layer from a name string or a callable."""
          if isinstance(norm_layer, str):
              # NOTE(security): eval() executes the string as Python; norm_layer
              # must only ever come from trusted configuration.
              return eval(norm_layer)(dim, epsilon=epsilon)
          if isinstance(norm_layer, Callable):
              # Callables receive only `dim`; `epsilon` is not forwarded.
              return norm_layer(dim)
          raise TypeError(
              "The norm_layer must be str or paddle.nn.layer.Layer class")

      def forward(self, x):
          x = x + self.drop_path(self.attn(self.norm1(x)))
          x = x + self.drop_path(self.mlp(self.norm2(x)))
          return x
  class PatchEmbed(nn.Layer):
      """Image to Patch Embedding.

      A Conv2D whose kernel size and stride both equal the patch size splits
      the image into non-overlapping patches and projects each one to
      ``embed_dim`` channels in a single step.

      Args:
          img_size (int): expected input height/width.
          patch_size (int): patch side length.
          in_chans (int): number of input channels.
          embed_dim (int): output token width.
      """

      def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
          super().__init__()
          img_size = to_2tuple(img_size)
          patch_size = to_2tuple(patch_size)
          patches_per_col = img_size[0] // patch_size[0]
          patches_per_row = img_size[1] // patch_size[1]
          self.img_size = img_size
          self.patch_size = patch_size
          self.num_patches = patches_per_row * patches_per_col
          self.proj = nn.Conv2D(
              in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

      def forward(self, x):
          # Assumes NCHW input (4-D unpack below) — confirm if data_format changes.
          _, _, H, W = x.shape
          assert H == self.img_size[0] and W == self.img_size[1], \
              f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
          # Flatten the spatial axes, then move the channel axis last so each
          # patch becomes one token.
          tokens = self.proj(x).flatten(2)
          return tokens.transpose((0, 2, 1))
  class VisionTransformer(nn.Layer):
      """ Vision Transformer with support for patch input.

      The image is turned into patch tokens by PatchEmbed, a learnable class
      token is prepended, learnable position embeddings are added, and the
      sequence passes through ``depth`` Blocks and a final norm. The class
      token's feature is then fed to ``head``.

      Args:
          img_size (int): input resolution (PatchEmbed asserts it matches).
          patch_size (int): patch side length.
          in_chans (int): input channels.
          class_num (int): number of classes; ``head`` is an Identity when <= 0.
          embed_dim (int): token width.
          depth (int): number of Blocks.
          num_heads (int): attention heads per Block.
          mlp_ratio (float): MLP hidden width as a multiple of embed_dim.
          qkv_bias (bool): bias on the fused q/k/v projections.
          qk_scale (float | None): attention scale override.
          drop_rate (float): dropout after the position embedding and inside Blocks.
          attn_drop_rate (float): dropout on attention weights.
          drop_path_rate (float): stochastic-depth rate reached by the last
              Block; rates increase linearly from 0.
          norm_layer (str | callable): norm spec, e.g. ``'nn.LayerNorm'``; a
              string is eval()'d and called with ``epsilon``, a callable is
              called with the width only.
          epsilon (float): epsilon for string-specified norm layers.
          **kwargs: accepted and ignored.
      """

      def __init__(self,
                   img_size=224,
                   patch_size=16,
                   in_chans=3,
                   class_num=1000,
                   embed_dim=768,
                   depth=12,
                   num_heads=12,
                   mlp_ratio=4,
                   qkv_bias=False,
                   qk_scale=None,
                   drop_rate=0.,
                   attn_drop_rate=0.,
                   drop_path_rate=0.,
                   norm_layer='nn.LayerNorm',
                   epsilon=1e-5,
                   **kwargs):
          super().__init__()
          self.class_num = class_num
          self.num_features = self.embed_dim = embed_dim

          self.patch_embed = PatchEmbed(
              img_size=img_size,
              patch_size=patch_size,
              in_chans=in_chans,
              embed_dim=embed_dim)
          num_patches = self.patch_embed.num_patches

          # One position embedding per patch plus one for the class token.
          self.pos_embed = self.create_parameter(
              shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_)
          self.add_parameter("pos_embed", self.pos_embed)
          self.cls_token = self.create_parameter(
              shape=(1, 1, embed_dim), default_initializer=zeros_)
          self.add_parameter("cls_token", self.cls_token)
          self.pos_drop = nn.Dropout(p=drop_rate)

          # Stochastic-depth rate grows linearly from 0 to drop_path_rate.
          dpr = np.linspace(0, drop_path_rate, depth)
          self.blocks = nn.LayerList([
              Block(
                  dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=rate,
                  norm_layer=norm_layer,
                  epsilon=epsilon) for rate in dpr
          ])
          # Built the same way Block builds its norms, so a callable
          # norm_layer (which Block accepts) no longer crashes here in eval().
          self.norm = self._build_norm(norm_layer, embed_dim, epsilon)

          # Classifier head
          self.head = nn.Linear(embed_dim,
                                class_num) if class_num > 0 else Identity()

          trunc_normal_(self.pos_embed)
          trunc_normal_(self.cls_token)
          self.apply(self._init_weights)

      @staticmethod
      def _build_norm(norm_layer, dim, epsilon):
          """Build a normalization layer from a name string or a callable."""
          if isinstance(norm_layer, str):
              # NOTE(security): eval() executes the string as Python; norm_layer
              # must only ever come from trusted configuration.
              return eval(norm_layer)(dim, epsilon=epsilon)
          if isinstance(norm_layer, Callable):
              return norm_layer(dim)
          raise TypeError(
              "The norm_layer must be str or paddle.nn.layer.Layer class")

      def _init_weights(self, m):
          """Initialize one sublayer; applied to every sublayer via self.apply()."""
          if isinstance(m, nn.Linear):
              trunc_normal_(m.weight)
              if m.bias is not None:
                  zeros_(m.bias)
          elif isinstance(m, nn.LayerNorm):
              # Guarded like the Linear bias: presumably these are None when the
              # LayerNorm is created without affine params — confirm.
              if m.bias is not None:
                  zeros_(m.bias)
              if m.weight is not None:
                  ones_(m.weight)

      def forward_features(self, x):
          """Return the normalized class-token feature for each input image."""
          # Runtime shape tensor rather than x.shape; presumably so the batch
          # size may be unknown when exported — confirm.
          B = paddle.shape(x)[0]
          x = self.patch_embed(x)
          cls_tokens = self.cls_token.expand((B, -1, -1))
          # Class token goes to sequence position 0.
          x = paddle.concat((cls_tokens, x), axis=1)
          x = x + self.pos_embed
          x = self.pos_drop(x)
          for blk in self.blocks:
              x = blk(x)
          x = self.norm(x)
          return x[:, 0]

      def forward(self, x):
          """Return ``head`` applied to the class-token feature."""
          x = self.forward_features(x)
          x = self.head(x)
          return x
  271. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  272. if pretrained is False:
  273. pass
  274. elif pretrained is True:
  275. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  276. elif isinstance(pretrained, str):
  277. load_dygraph_pretrain(model, pretrained)
  278. else:
  279. raise RuntimeError(
  280. "pretrained type is not available. Please use `string` or `boolean` type."
  281. )
  def _create_vit(arch_name, pretrained, use_ssld, arch_kwargs, user_kwargs):
      """Build the VisionTransformer for one named architecture and load weights.

      Args:
          arch_name (str): key into MODEL_URLS (equal to the factory's name).
          pretrained (bool | str): see ``_load_pretrained``.
          use_ssld (bool): forwarded to ``_load_pretrained``.
          arch_kwargs (dict): fixed hyper-parameters defining the architecture.
          user_kwargs (dict): extra VisionTransformer arguments from the caller.
              Repeating a key from ``arch_kwargs`` raises TypeError (duplicate
              keyword argument), so architecture settings cannot be overridden.

      Returns:
          VisionTransformer: the constructed (and possibly weight-loaded) model.
      """
      model = VisionTransformer(**arch_kwargs, **user_kwargs)
      _load_pretrained(
          pretrained, model, MODEL_URLS[arch_name], use_ssld=use_ssld)
      return model


  def ViT_small_patch16_224(pretrained=False, use_ssld=False, **kwargs):
      """ViT-small, 16x16 patches, 224x224 input. Args: see ``_create_vit``.

      NOTE(review): qk_scale is pinned to 768 ** -0.5 instead of being derived
      from the head dim; presumably needed to match the released weights — confirm.
      """
      return _create_vit(
          "ViT_small_patch16_224", pretrained, use_ssld,
          dict(patch_size=16, embed_dim=768, depth=8, num_heads=8,
               mlp_ratio=3, qk_scale=768**-0.5),
          kwargs)


  def ViT_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
      """ViT-base, 16x16 patches, 224x224 input. Args: see ``_create_vit``."""
      return _create_vit(
          "ViT_base_patch16_224", pretrained, use_ssld,
          dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
               mlp_ratio=4, qkv_bias=True, epsilon=1e-6),
          kwargs)


  def ViT_base_patch16_384(pretrained=False, use_ssld=False, **kwargs):
      """ViT-base, 16x16 patches, 384x384 input. Args: see ``_create_vit``."""
      return _create_vit(
          "ViT_base_patch16_384", pretrained, use_ssld,
          dict(img_size=384, patch_size=16, embed_dim=768, depth=12,
               num_heads=12, mlp_ratio=4, qkv_bias=True, epsilon=1e-6),
          kwargs)


  def ViT_base_patch32_384(pretrained=False, use_ssld=False, **kwargs):
      """ViT-base, 32x32 patches, 384x384 input. Args: see ``_create_vit``."""
      return _create_vit(
          "ViT_base_patch32_384", pretrained, use_ssld,
          dict(img_size=384, patch_size=32, embed_dim=768, depth=12,
               num_heads=12, mlp_ratio=4, qkv_bias=True, epsilon=1e-6),
          kwargs)


  def ViT_large_patch16_224(pretrained=False, use_ssld=False, **kwargs):
      """ViT-large, 16x16 patches, 224x224 input. Args: see ``_create_vit``."""
      return _create_vit(
          "ViT_large_patch16_224", pretrained, use_ssld,
          dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16,
               mlp_ratio=4, qkv_bias=True, epsilon=1e-6),
          kwargs)


  def ViT_large_patch16_384(pretrained=False, use_ssld=False, **kwargs):
      """ViT-large, 16x16 patches, 384x384 input. Args: see ``_create_vit``."""
      return _create_vit(
          "ViT_large_patch16_384", pretrained, use_ssld,
          dict(img_size=384, patch_size=16, embed_dim=1024, depth=24,
               num_heads=16, mlp_ratio=4, qkv_bias=True, epsilon=1e-6),
          kwargs)


  def ViT_large_patch32_384(pretrained=False, use_ssld=False, **kwargs):
      """ViT-large, 32x32 patches, 384x384 input. Args: see ``_create_vit``."""
      return _create_vit(
          "ViT_large_patch32_384", pretrained, use_ssld,
          dict(img_size=384, patch_size=32, embed_dim=1024, depth=24,
               num_heads=16, mlp_ratio=4, qkv_bias=True, epsilon=1e-6),
          kwargs)


  def ViT_huge_patch16_224(pretrained=False, use_ssld=False, **kwargs):
      """ViT-huge, 16x16 patches, 224x224 input. Args: see ``_create_vit``.

      Uses VisionTransformer's defaults for qkv_bias (False) and epsilon (1e-5).
      """
      return _create_vit(
          "ViT_huge_patch16_224", pretrained, use_ssld,
          dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16,
               mlp_ratio=4),
          kwargs)


  def ViT_huge_patch32_384(pretrained=False, use_ssld=False, **kwargs):
      """ViT-huge, 32x32 patches, 384x384 input. Args: see ``_create_vit``.

      Uses VisionTransformer's defaults for qkv_bias (False) and epsilon (1e-5).
      """
      return _create_vit(
          "ViT_huge_patch32_384", pretrained, use_ssld,
          dict(img_size=384, patch_size=32, embed_dim=1280, depth=32,
               num_heads=16, mlp_ratio=4),
          kwargs)