swin_transformer.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/microsoft/Swin-Transformer
  15. import numpy as np
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from paddle.nn.initializer import TruncatedNormal, Constant
  20. from .vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity
  21. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  22. MODEL_URLS = {
  23. "SwinTransformer_tiny_patch4_window7_224":
  24. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams",
  25. "SwinTransformer_small_patch4_window7_224":
  26. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_small_patch4_window7_224_pretrained.pdparams",
  27. "SwinTransformer_base_patch4_window7_224":
  28. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window7_224_pretrained.pdparams",
  29. "SwinTransformer_base_patch4_window12_384":
  30. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams",
  31. "SwinTransformer_large_patch4_window7_224":
  32. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams",
  33. "SwinTransformer_large_patch4_window12_384":
  34. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams",
  35. }
  36. __all__ = list(MODEL_URLS.keys())
  37. class Mlp(nn.Layer):
  38. def __init__(self,
  39. in_features,
  40. hidden_features=None,
  41. out_features=None,
  42. act_layer=nn.GELU,
  43. drop=0.):
  44. super().__init__()
  45. out_features = out_features or in_features
  46. hidden_features = hidden_features or in_features
  47. self.fc1 = nn.Linear(in_features, hidden_features)
  48. self.act = act_layer()
  49. self.fc2 = nn.Linear(hidden_features, out_features)
  50. self.drop = nn.Dropout(drop)
  51. def forward(self, x):
  52. x = self.fc1(x)
  53. x = self.act(x)
  54. x = self.drop(x)
  55. x = self.fc2(x)
  56. x = self.drop(x)
  57. return x
  58. def window_partition(x, window_size):
  59. """
  60. Args:
  61. x: (B, H, W, C)
  62. window_size (int): window size
  63. Returns:
  64. windows: (num_windows*B, window_size, window_size, C)
  65. """
  66. B, H, W, C = x.shape
  67. x = x.reshape(
  68. [B, H // window_size, window_size, W // window_size, window_size, C])
  69. windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
  70. [-1, window_size, window_size, C])
  71. return windows
  72. def window_reverse(windows, window_size, H, W, C):
  73. """
  74. Args:
  75. windows: (num_windows*B, window_size, window_size, C)
  76. window_size (int): Window size
  77. H (int): Height of image
  78. W (int): Width of image
  79. Returns:
  80. x: (B, H, W, C)
  81. """
  82. x = windows.reshape(
  83. [-1, H // window_size, W // window_size, window_size, window_size, C])
  84. x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H, W, C])
  85. return x
  86. class WindowAttention(nn.Layer):
  87. r""" Window based multi-head self attention (W-MSA) module with relative position bias.
  88. It supports both of shifted and non-shifted window.
  89. Args:
  90. dim (int): Number of input channels.
  91. window_size (tuple[int]): The height and width of the window.
  92. num_heads (int): Number of attention heads.
  93. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  94. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  95. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  96. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  97. """
  98. def __init__(self,
  99. dim,
  100. window_size,
  101. num_heads,
  102. qkv_bias=True,
  103. qk_scale=None,
  104. attn_drop=0.,
  105. proj_drop=0.):
  106. super().__init__()
  107. self.dim = dim
  108. self.window_size = window_size # Wh, Ww
  109. self.num_heads = num_heads
  110. head_dim = dim // num_heads
  111. self.scale = qk_scale or head_dim**-0.5
  112. # define a parameter table of relative position bias
  113. # 2*Wh-1 * 2*Ww-1, nH
  114. self.relative_position_bias_table = self.create_parameter(
  115. shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
  116. num_heads),
  117. default_initializer=zeros_)
  118. self.add_parameter("relative_position_bias_table",
  119. self.relative_position_bias_table)
  120. # get pair-wise relative position index for each token inside the window
  121. coords_h = paddle.arange(self.window_size[0])
  122. coords_w = paddle.arange(self.window_size[1])
  123. coords = paddle.stack(paddle.meshgrid(
  124. [coords_h, coords_w])) # 2, Wh, Ww
  125. coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
  126. coords_flatten_1 = coords_flatten.unsqueeze(axis=2)
  127. coords_flatten_2 = coords_flatten.unsqueeze(axis=1)
  128. relative_coords = coords_flatten_1 - coords_flatten_2
  129. relative_coords = relative_coords.transpose(
  130. [1, 2, 0]) # Wh*Ww, Wh*Ww, 2
  131. relative_coords[:, :, 0] += self.window_size[
  132. 0] - 1 # shift to start from 0
  133. relative_coords[:, :, 1] += self.window_size[1] - 1
  134. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  135. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  136. self.register_buffer("relative_position_index",
  137. relative_position_index)
  138. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  139. self.attn_drop = nn.Dropout(attn_drop)
  140. self.proj = nn.Linear(dim, dim)
  141. self.proj_drop = nn.Dropout(proj_drop)
  142. trunc_normal_(self.relative_position_bias_table)
  143. self.softmax = nn.Softmax(axis=-1)
  144. def forward(self, x, mask=None):
  145. """
  146. Args:
  147. x: input features with shape of (num_windows*B, N, C)
  148. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  149. """
  150. B_, N, C = x.shape
  151. qkv = self.qkv(x).reshape(
  152. [B_, N, 3, self.num_heads, C // self.num_heads]).transpose(
  153. [2, 0, 3, 1, 4])
  154. q, k, v = qkv[0], qkv[1], qkv[2]
  155. q = q * self.scale
  156. attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
  157. index = self.relative_position_index.reshape([-1])
  158. relative_position_bias = paddle.index_select(
  159. self.relative_position_bias_table, index)
  160. relative_position_bias = relative_position_bias.reshape([
  161. self.window_size[0] * self.window_size[1],
  162. self.window_size[0] * self.window_size[1], -1
  163. ]) # Wh*Ww,Wh*Ww,nH
  164. relative_position_bias = relative_position_bias.transpose(
  165. [2, 0, 1]) # nH, Wh*Ww, Wh*Ww
  166. attn = attn + relative_position_bias.unsqueeze(0)
  167. if mask is not None:
  168. nW = mask.shape[0]
  169. attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N
  170. ]) + mask.unsqueeze(1).unsqueeze(0)
  171. attn = attn.reshape([-1, self.num_heads, N, N])
  172. attn = self.softmax(attn)
  173. else:
  174. attn = self.softmax(attn)
  175. attn = self.attn_drop(attn)
  176. # x = (attn @ v).transpose(1, 2).reshape([B_, N, C])
  177. x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([B_, N, C])
  178. x = self.proj(x)
  179. x = self.proj_drop(x)
  180. return x
  181. def extra_repr(self):
  182. return "dim={}, window_size={}, num_heads={}".format(
  183. self.dim, self.window_size, self.num_heads)
  184. def flops(self, N):
  185. # calculate flops for 1 window with token length of N
  186. flops = 0
  187. # qkv = self.qkv(x)
  188. flops += N * self.dim * 3 * self.dim
  189. # attn = (q @ k.transpose(-2, -1))
  190. flops += self.num_heads * N * (self.dim // self.num_heads) * N
  191. # x = (attn @ v)
  192. flops += self.num_heads * N * N * (self.dim // self.num_heads)
  193. # x = self.proj(x)
  194. flops += N * self.dim * self.dim
  195. return flops
  196. class SwinTransformerBlock(nn.Layer):
  197. r""" Swin Transformer Block.
  198. Args:
  199. dim (int): Number of input channels.
  200. input_resolution (tuple[int]): Input resulotion.
  201. num_heads (int): Number of attention heads.
  202. window_size (int): Window size.
  203. shift_size (int): Shift size for SW-MSA.
  204. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  205. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  206. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  207. drop (float, optional): Dropout rate. Default: 0.0
  208. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  209. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  210. act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
  211. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  212. """
  213. def __init__(self,
  214. dim,
  215. input_resolution,
  216. num_heads,
  217. window_size=7,
  218. shift_size=0,
  219. mlp_ratio=4.,
  220. qkv_bias=True,
  221. qk_scale=None,
  222. drop=0.,
  223. attn_drop=0.,
  224. drop_path=0.,
  225. act_layer=nn.GELU,
  226. norm_layer=nn.LayerNorm):
  227. super().__init__()
  228. self.dim = dim
  229. self.input_resolution = input_resolution
  230. self.num_heads = num_heads
  231. self.window_size = window_size
  232. self.shift_size = shift_size
  233. self.mlp_ratio = mlp_ratio
  234. if min(self.input_resolution) <= self.window_size:
  235. # if window size is larger than input resolution, we don't partition windows
  236. self.shift_size = 0
  237. self.window_size = min(self.input_resolution)
  238. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  239. self.norm1 = norm_layer(dim)
  240. self.attn = WindowAttention(
  241. dim,
  242. window_size=to_2tuple(self.window_size),
  243. num_heads=num_heads,
  244. qkv_bias=qkv_bias,
  245. qk_scale=qk_scale,
  246. attn_drop=attn_drop,
  247. proj_drop=drop)
  248. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  249. self.norm2 = norm_layer(dim)
  250. mlp_hidden_dim = int(dim * mlp_ratio)
  251. self.mlp = Mlp(in_features=dim,
  252. hidden_features=mlp_hidden_dim,
  253. act_layer=act_layer,
  254. drop=drop)
  255. if self.shift_size > 0:
  256. # calculate attention mask for SW-MSA
  257. H, W = self.input_resolution
  258. img_mask = paddle.zeros((1, H, W, 1)) # 1 H W 1
  259. h_slices = (slice(0, -self.window_size),
  260. slice(-self.window_size, -self.shift_size),
  261. slice(-self.shift_size, None))
  262. w_slices = (slice(0, -self.window_size),
  263. slice(-self.window_size, -self.shift_size),
  264. slice(-self.shift_size, None))
  265. cnt = 0
  266. for h in h_slices:
  267. for w in w_slices:
  268. img_mask[:, h, w, :] = cnt
  269. cnt += 1
  270. mask_windows = window_partition(
  271. img_mask, self.window_size) # nW, window_size, window_size, 1
  272. mask_windows = mask_windows.reshape(
  273. [-1, self.window_size * self.window_size])
  274. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  275. huns = -100.0 * paddle.ones_like(attn_mask)
  276. attn_mask = huns * (attn_mask != 0).astype("float32")
  277. else:
  278. attn_mask = None
  279. self.register_buffer("attn_mask", attn_mask)
  280. def forward(self, x):
  281. H, W = self.input_resolution
  282. B, L, C = x.shape
  283. assert L == H * W, "input feature has wrong size"
  284. shortcut = x
  285. x = self.norm1(x)
  286. x = x.reshape([B, H, W, C])
  287. # cyclic shift
  288. if self.shift_size > 0:
  289. shifted_x = paddle.roll(
  290. x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
  291. else:
  292. shifted_x = x
  293. # partition windows
  294. x_windows = window_partition(
  295. shifted_x, self.window_size) # nW*B, window_size, window_size, C
  296. x_windows = x_windows.reshape(
  297. [-1, self.window_size * self.window_size,
  298. C]) # nW*B, window_size*window_size, C
  299. # W-MSA/SW-MSA
  300. attn_windows = self.attn(
  301. x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
  302. # merge windows
  303. attn_windows = attn_windows.reshape(
  304. [-1, self.window_size, self.window_size, C])
  305. shifted_x = window_reverse(attn_windows, self.window_size, H, W,
  306. C) # B H' W' C
  307. # reverse cyclic shift
  308. if self.shift_size > 0:
  309. x = paddle.roll(
  310. shifted_x,
  311. shifts=(self.shift_size, self.shift_size),
  312. axis=(1, 2))
  313. else:
  314. x = shifted_x
  315. x = x.reshape([B, H * W, C])
  316. # FFN
  317. x = shortcut + self.drop_path(x)
  318. x = x + self.drop_path(self.mlp(self.norm2(x)))
  319. return x
  320. def extra_repr(self):
  321. return "dim={}, input_resolution={}, num_heads={}, window_size={}, shift_size={}, mlp_ratio={}".format(
  322. self.dim, self.input_resolution, self.num_heads, self.window_size,
  323. self.shift_size, self.mlp_ratio)
  324. def flops(self):
  325. flops = 0
  326. H, W = self.input_resolution
  327. # norm1
  328. flops += self.dim * H * W
  329. # W-MSA/SW-MSA
  330. nW = H * W / self.window_size / self.window_size
  331. flops += nW * self.attn.flops(self.window_size * self.window_size)
  332. # mlp
  333. flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
  334. # norm2
  335. flops += self.dim * H * W
  336. return flops
  337. class PatchMerging(nn.Layer):
  338. r""" Patch Merging Layer.
  339. Args:
  340. input_resolution (tuple[int]): Resolution of input feature.
  341. dim (int): Number of input channels.
  342. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  343. """
  344. def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
  345. super().__init__()
  346. self.input_resolution = input_resolution
  347. self.dim = dim
  348. self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
  349. self.norm = norm_layer(4 * dim)
  350. def forward(self, x):
  351. """
  352. x: B, H*W, C
  353. """
  354. H, W = self.input_resolution
  355. B, L, C = x.shape
  356. assert L == H * W, "input feature has wrong size"
  357. assert H % 2 == 0 and W % 2 == 0, "x size ({}*{}) are not even.".format(
  358. H, W)
  359. x = x.reshape([B, H, W, C])
  360. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  361. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  362. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  363. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  364. x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  365. x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
  366. x = self.norm(x)
  367. x = self.reduction(x)
  368. return x
  369. def extra_repr(self):
  370. return "input_resolution={}, dim={}".format(self.input_resolution,
  371. self.dim)
  372. def flops(self):
  373. H, W = self.input_resolution
  374. flops = H * W * self.dim
  375. flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
  376. return flops
  377. class BasicLayer(nn.Layer):
  378. """ A basic Swin Transformer layer for one stage.
  379. Args:
  380. dim (int): Number of input channels.
  381. input_resolution (tuple[int]): Input resolution.
  382. depth (int): Number of blocks.
  383. num_heads (int): Number of attention heads.
  384. window_size (int): Local window size.
  385. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  386. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  387. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  388. drop (float, optional): Dropout rate. Default: 0.0
  389. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  390. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  391. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  392. downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
  393. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  394. """
  395. def __init__(self,
  396. dim,
  397. input_resolution,
  398. depth,
  399. num_heads,
  400. window_size,
  401. mlp_ratio=4.,
  402. qkv_bias=True,
  403. qk_scale=None,
  404. drop=0.,
  405. attn_drop=0.,
  406. drop_path=0.,
  407. norm_layer=nn.LayerNorm,
  408. downsample=None,
  409. use_checkpoint=False):
  410. super().__init__()
  411. self.dim = dim
  412. self.input_resolution = input_resolution
  413. self.depth = depth
  414. self.use_checkpoint = use_checkpoint
  415. # build blocks
  416. self.blocks = nn.LayerList([
  417. SwinTransformerBlock(
  418. dim=dim,
  419. input_resolution=input_resolution,
  420. num_heads=num_heads,
  421. window_size=window_size,
  422. shift_size=0 if (i % 2 == 0) else window_size // 2,
  423. mlp_ratio=mlp_ratio,
  424. qkv_bias=qkv_bias,
  425. qk_scale=qk_scale,
  426. drop=drop,
  427. attn_drop=attn_drop,
  428. drop_path=drop_path[i]
  429. if isinstance(drop_path, list) else drop_path,
  430. norm_layer=norm_layer) for i in range(depth)
  431. ])
  432. # patch merging layer
  433. if downsample is not None:
  434. self.downsample = downsample(
  435. input_resolution, dim=dim, norm_layer=norm_layer)
  436. else:
  437. self.downsample = None
  438. def forward(self, x):
  439. for blk in self.blocks:
  440. x = blk(x)
  441. if self.downsample is not None:
  442. x = self.downsample(x)
  443. return x
  444. def extra_repr(self):
  445. return "dim={}, input_resolution={}, depth={}".format(
  446. self.dim, self.input_resolution, self.depth)
  447. def flops(self):
  448. flops = 0
  449. for blk in self.blocks:
  450. flops += blk.flops()
  451. if self.downsample is not None:
  452. flops += self.downsample.flops()
  453. return flops
  454. class PatchEmbed(nn.Layer):
  455. """ Image to Patch Embedding
  456. Args:
  457. img_size (int): Image size. Default: 224.
  458. patch_size (int): Patch token size. Default: 4.
  459. in_chans (int): Number of input image channels. Default: 3.
  460. embed_dim (int): Number of linear projection output channels. Default: 96.
  461. norm_layer (nn.Layer, optional): Normalization layer. Default: None
  462. """
  463. def __init__(self,
  464. img_size=224,
  465. patch_size=4,
  466. in_chans=3,
  467. embed_dim=96,
  468. norm_layer=None):
  469. super().__init__()
  470. img_size = to_2tuple(img_size)
  471. patch_size = to_2tuple(patch_size)
  472. patches_resolution = [
  473. img_size[0] // patch_size[0], img_size[1] // patch_size[1]
  474. ]
  475. self.img_size = img_size
  476. self.patch_size = patch_size
  477. self.patches_resolution = patches_resolution
  478. self.num_patches = patches_resolution[0] * patches_resolution[1]
  479. self.in_chans = in_chans
  480. self.embed_dim = embed_dim
  481. self.proj = nn.Conv2D(
  482. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  483. if norm_layer is not None:
  484. self.norm = norm_layer(embed_dim)
  485. else:
  486. self.norm = None
  487. def forward(self, x):
  488. B, C, H, W = x.shape
  489. # TODO (littletomatodonkey), uncomment the line will cause failure of jit.save
  490. # assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
  491. x = self.proj(x)
  492. x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
  493. if self.norm is not None:
  494. x = self.norm(x)
  495. return x
  496. def flops(self):
  497. Ho, Wo = self.patches_resolution
  498. flops = Ho * Wo * self.embed_dim * self.in_chans * (
  499. self.patch_size[0] * self.patch_size[1])
  500. if self.norm is not None:
  501. flops += Ho * Wo * self.embed_dim
  502. return flops
  503. class SwinTransformer(nn.Layer):
  504. """ Swin Transformer
  505. A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  506. https://arxiv.org/pdf/2103.14030
  507. Args:
  508. img_size (int | tuple(int)): Input image size. Default 224
  509. patch_size (int | tuple(int)): Patch size. Default: 4
  510. in_chans (int): Number of input image channels. Default: 3
  511. num_classes (int): Number of classes for classification head. Default: 1000
  512. embed_dim (int): Patch embedding dimension. Default: 96
  513. depths (tuple(int)): Depth of each Swin Transformer layer.
  514. num_heads (tuple(int)): Number of attention heads in different layers.
  515. window_size (int): Window size. Default: 7
  516. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  517. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  518. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  519. drop_rate (float): Dropout rate. Default: 0
  520. attn_drop_rate (float): Attention dropout rate. Default: 0
  521. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  522. norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
  523. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  524. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  525. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
  526. """
  527. def __init__(self,
  528. img_size=224,
  529. patch_size=4,
  530. in_chans=3,
  531. class_num=1000,
  532. embed_dim=96,
  533. depths=[2, 2, 6, 2],
  534. num_heads=[3, 6, 12, 24],
  535. window_size=7,
  536. mlp_ratio=4.,
  537. qkv_bias=True,
  538. qk_scale=None,
  539. drop_rate=0.,
  540. attn_drop_rate=0.,
  541. drop_path_rate=0.1,
  542. norm_layer=nn.LayerNorm,
  543. ape=False,
  544. patch_norm=True,
  545. use_checkpoint=False,
  546. **kwargs):
  547. super(SwinTransformer, self).__init__()
  548. self.num_classes = num_classes = class_num
  549. self.num_layers = len(depths)
  550. self.embed_dim = embed_dim
  551. self.ape = ape
  552. self.patch_norm = patch_norm
  553. self.num_features = int(embed_dim * 2**(self.num_layers - 1))
  554. self.mlp_ratio = mlp_ratio
  555. # split image into non-overlapping patches
  556. self.patch_embed = PatchEmbed(
  557. img_size=img_size,
  558. patch_size=patch_size,
  559. in_chans=in_chans,
  560. embed_dim=embed_dim,
  561. norm_layer=norm_layer if self.patch_norm else None)
  562. num_patches = self.patch_embed.num_patches
  563. patches_resolution = self.patch_embed.patches_resolution
  564. self.patches_resolution = patches_resolution
  565. # absolute position embedding
  566. if self.ape:
  567. self.absolute_pos_embed = self.create_parameter(
  568. shape=(1, num_patches, embed_dim), default_initializer=zeros_)
  569. self.add_parameter("absolute_pos_embed", self.absolute_pos_embed)
  570. trunc_normal_(self.absolute_pos_embed)
  571. self.pos_drop = nn.Dropout(p=drop_rate)
  572. # stochastic depth
  573. dpr = np.linspace(0, drop_path_rate,
  574. sum(depths)).tolist() # stochastic depth decay rule
  575. # build layers
  576. self.layers = nn.LayerList()
  577. for i_layer in range(self.num_layers):
  578. layer = BasicLayer(
  579. dim=int(embed_dim * 2**i_layer),
  580. input_resolution=(patches_resolution[0] // (2**i_layer),
  581. patches_resolution[1] // (2**i_layer)),
  582. depth=depths[i_layer],
  583. num_heads=num_heads[i_layer],
  584. window_size=window_size,
  585. mlp_ratio=self.mlp_ratio,
  586. qkv_bias=qkv_bias,
  587. qk_scale=qk_scale,
  588. drop=drop_rate,
  589. attn_drop=attn_drop_rate,
  590. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  591. norm_layer=norm_layer,
  592. downsample=PatchMerging
  593. if (i_layer < self.num_layers - 1) else None,
  594. use_checkpoint=use_checkpoint)
  595. self.layers.append(layer)
  596. self.norm = norm_layer(self.num_features)
  597. self.avgpool = nn.AdaptiveAvgPool1D(1)
  598. self.head = nn.Linear(
  599. self.num_features,
  600. num_classes) if self.num_classes > 0 else nn.Identity()
  601. self.apply(self._init_weights)
  602. def _init_weights(self, m):
  603. if isinstance(m, nn.Linear):
  604. trunc_normal_(m.weight)
  605. if isinstance(m, nn.Linear) and m.bias is not None:
  606. zeros_(m.bias)
  607. elif isinstance(m, nn.LayerNorm):
  608. zeros_(m.bias)
  609. ones_(m.weight)
  610. def forward_features(self, x):
  611. x = self.patch_embed(x)
  612. if self.ape:
  613. x = x + self.absolute_pos_embed
  614. x = self.pos_drop(x)
  615. for layer in self.layers:
  616. x = layer(x)
  617. x = self.norm(x) # B L C
  618. x = self.avgpool(x.transpose([0, 2, 1])) # B C 1
  619. x = paddle.flatten(x, 1)
  620. return x
  621. def forward(self, x):
  622. x = self.forward_features(x)
  623. x = self.head(x)
  624. return x
  625. def flops(self):
  626. flops = 0
  627. flops += self.patch_embed.flops()
  628. for _, layer in enumerate(self.layers):
  629. flops += layer.flops()
  630. flops += self.num_features * self.patches_resolution[
  631. 0] * self.patches_resolution[1] // (2**self.num_layers)
  632. flops += self.num_features * self.num_classes
  633. return flops
  634. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  635. if pretrained is False:
  636. pass
  637. elif pretrained is True:
  638. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  639. elif isinstance(pretrained, str):
  640. load_dygraph_pretrain(model, pretrained)
  641. else:
  642. raise RuntimeError(
  643. "pretrained type is not available. Please use `string` or `boolean` type."
  644. )
  645. def SwinTransformer_tiny_patch4_window7_224(pretrained=False,
  646. use_ssld=False,
  647. **kwargs):
  648. model = SwinTransformer(
  649. embed_dim=96,
  650. depths=[2, 2, 6, 2],
  651. num_heads=[3, 6, 12, 24],
  652. window_size=7,
  653. drop_path_rate=0.2,
  654. **kwargs)
  655. _load_pretrained(
  656. pretrained,
  657. model,
  658. MODEL_URLS["SwinTransformer_tiny_patch4_window7_224"],
  659. use_ssld=use_ssld)
  660. return model
  661. def SwinTransformer_small_patch4_window7_224(pretrained=False,
  662. use_ssld=False,
  663. **kwargs):
  664. model = SwinTransformer(
  665. embed_dim=96,
  666. depths=[2, 2, 18, 2],
  667. num_heads=[3, 6, 12, 24],
  668. window_size=7,
  669. **kwargs)
  670. _load_pretrained(
  671. pretrained,
  672. model,
  673. MODEL_URLS["SwinTransformer_small_patch4_window7_224"],
  674. use_ssld=use_ssld)
  675. return model
  676. def SwinTransformer_base_patch4_window7_224(pretrained=False,
  677. use_ssld=False,
  678. **kwargs):
  679. model = SwinTransformer(
  680. embed_dim=128,
  681. depths=[2, 2, 18, 2],
  682. num_heads=[4, 8, 16, 32],
  683. window_size=7,
  684. drop_path_rate=0.5,
  685. **kwargs)
  686. _load_pretrained(
  687. pretrained,
  688. model,
  689. MODEL_URLS["SwinTransformer_base_patch4_window7_224"],
  690. use_ssld=use_ssld)
  691. return model
  692. def SwinTransformer_base_patch4_window12_384(pretrained=False,
  693. use_ssld=False,
  694. **kwargs):
  695. model = SwinTransformer(
  696. img_size=384,
  697. embed_dim=128,
  698. depths=[2, 2, 18, 2],
  699. num_heads=[4, 8, 16, 32],
  700. window_size=12,
  701. drop_path_rate=0.5, # NOTE: do not appear in offical code
  702. **kwargs)
  703. _load_pretrained(
  704. pretrained,
  705. model,
  706. MODEL_URLS["SwinTransformer_base_patch4_window12_384"],
  707. use_ssld=use_ssld)
  708. return model
  709. def SwinTransformer_large_patch4_window7_224(pretrained=False,
  710. use_ssld=False,
  711. **kwargs):
  712. model = SwinTransformer(
  713. embed_dim=192,
  714. depths=[2, 2, 18, 2],
  715. num_heads=[6, 12, 24, 48],
  716. window_size=7,
  717. **kwargs)
  718. _load_pretrained(
  719. pretrained,
  720. model,
  721. MODEL_URLS["SwinTransformer_large_patch4_window7_224"],
  722. use_ssld=use_ssld)
  723. return model
  724. def SwinTransformer_large_patch4_window12_384(pretrained=False,
  725. use_ssld=False,
  726. **kwargs):
  727. model = SwinTransformer(
  728. img_size=384,
  729. embed_dim=192,
  730. depths=[2, 2, 18, 2],
  731. num_heads=[6, 12, 24, 48],
  732. window_size=12,
  733. **kwargs)
  734. _load_pretrained(
  735. pretrained,
  736. model,
  737. MODEL_URLS["SwinTransformer_large_patch4_window12_384"],
  738. use_ssld=use_ssld)
  739. return model