gvt.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/Meituan-AutoML/Twins
  15. from functools import partial
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from paddle.regularizer import L2Decay
  20. from .vision_transformer import trunc_normal_, normal_, zeros_, ones_, to_2tuple, DropPath, Identity, Mlp
  21. from .vision_transformer import Block as ViTBlock
  22. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  23. MODEL_URLS = {
  24. "pcpvt_small":
  25. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/pcpvt_small_pretrained.pdparams",
  26. "pcpvt_base":
  27. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/pcpvt_base_pretrained.pdparams",
  28. "pcpvt_large":
  29. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/pcpvt_large_pretrained.pdparams",
  30. "alt_gvt_small":
  31. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/alt_gvt_small_pretrained.pdparams",
  32. "alt_gvt_base":
  33. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/alt_gvt_base_pretrained.pdparams",
  34. "alt_gvt_large":
  35. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/alt_gvt_large_pretrained.pdparams"
  36. }
  37. __all__ = list(MODEL_URLS.keys())
  38. class GroupAttention(nn.Layer):
  39. """LSA: self attention within a group.
  40. """
  41. def __init__(self,
  42. dim,
  43. num_heads=8,
  44. qkv_bias=False,
  45. qk_scale=None,
  46. attn_drop=0.,
  47. proj_drop=0.,
  48. ws=1):
  49. super().__init__()
  50. if ws == 1:
  51. raise Exception("ws {ws} should not be 1")
  52. if dim % num_heads != 0:
  53. raise Exception(
  54. "dim {dim} should be divided by num_heads {num_heads}.")
  55. self.dim = dim
  56. self.num_heads = num_heads
  57. head_dim = dim // num_heads
  58. self.scale = qk_scale or head_dim**-0.5
  59. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  60. self.attn_drop = nn.Dropout(attn_drop)
  61. self.proj = nn.Linear(dim, dim)
  62. self.proj_drop = nn.Dropout(proj_drop)
  63. self.ws = ws
  64. def forward(self, x, H, W):
  65. B, N, C = x.shape
  66. h_group, w_group = H // self.ws, W // self.ws
  67. total_groups = h_group * w_group
  68. x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
  69. [0, 1, 3, 2, 4, 5])
  70. qkv = self.qkv(x).reshape([
  71. B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
  72. ]).transpose([3, 0, 1, 4, 2, 5])
  73. q, k, v = qkv[0], qkv[1], qkv[2]
  74. attn = paddle.matmul(q, k.transpose([0, 1, 2, 4, 3])) * self.scale
  75. attn = nn.Softmax(axis=-1)(attn)
  76. attn = self.attn_drop(attn)
  77. attn = paddle.matmul(attn, v).transpose([0, 1, 3, 2, 4]).reshape(
  78. [B, h_group, w_group, self.ws, self.ws, C])
  79. x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C])
  80. x = self.proj(x)
  81. x = self.proj_drop(x)
  82. return x
  83. class Attention(nn.Layer):
  84. """GSA: using a key to summarize the information for a group to be efficient.
  85. """
  86. def __init__(self,
  87. dim,
  88. num_heads=8,
  89. qkv_bias=False,
  90. qk_scale=None,
  91. attn_drop=0.,
  92. proj_drop=0.,
  93. sr_ratio=1):
  94. super().__init__()
  95. assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
  96. self.dim = dim
  97. self.num_heads = num_heads
  98. head_dim = dim // num_heads
  99. self.scale = qk_scale or head_dim**-0.5
  100. self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
  101. self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
  102. self.attn_drop = nn.Dropout(attn_drop)
  103. self.proj = nn.Linear(dim, dim)
  104. self.proj_drop = nn.Dropout(proj_drop)
  105. self.sr_ratio = sr_ratio
  106. if sr_ratio > 1:
  107. self.sr = nn.Conv2D(
  108. dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
  109. self.norm = nn.LayerNorm(dim)
  110. def forward(self, x, H, W):
  111. B, N, C = x.shape
  112. q = self.q(x).reshape(
  113. [B, N, self.num_heads, C // self.num_heads]).transpose(
  114. [0, 2, 1, 3])
  115. if self.sr_ratio > 1:
  116. x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
  117. tmp_n = H * W // self.sr_ratio**2
  118. x_ = self.sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
  119. x_ = self.norm(x_)
  120. kv = self.kv(x_).reshape(
  121. [B, tmp_n, 2, self.num_heads, C // self.num_heads]).transpose(
  122. [2, 0, 3, 1, 4])
  123. else:
  124. kv = self.kv(x).reshape(
  125. [B, N, 2, self.num_heads, C // self.num_heads]).transpose(
  126. [2, 0, 3, 1, 4])
  127. k, v = kv[0], kv[1]
  128. attn = paddle.matmul(q, k.transpose([0, 1, 3, 2])) * self.scale
  129. attn = nn.Softmax(axis=-1)(attn)
  130. attn = self.attn_drop(attn)
  131. x = paddle.matmul(attn, v).transpose([0, 2, 1, 3]).reshape([B, N, C])
  132. x = self.proj(x)
  133. x = self.proj_drop(x)
  134. return x
  135. class Block(nn.Layer):
  136. def __init__(self,
  137. dim,
  138. num_heads,
  139. mlp_ratio=4.,
  140. qkv_bias=False,
  141. qk_scale=None,
  142. drop=0.,
  143. attn_drop=0.,
  144. drop_path=0.,
  145. act_layer=nn.GELU,
  146. norm_layer=nn.LayerNorm,
  147. sr_ratio=1):
  148. super().__init__()
  149. self.norm1 = norm_layer(dim)
  150. self.attn = Attention(
  151. dim,
  152. num_heads=num_heads,
  153. qkv_bias=qkv_bias,
  154. qk_scale=qk_scale,
  155. attn_drop=attn_drop,
  156. proj_drop=drop,
  157. sr_ratio=sr_ratio)
  158. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  159. self.norm2 = norm_layer(dim)
  160. mlp_hidden_dim = int(dim * mlp_ratio)
  161. self.mlp = Mlp(in_features=dim,
  162. hidden_features=mlp_hidden_dim,
  163. act_layer=act_layer,
  164. drop=drop)
  165. def forward(self, x, H, W):
  166. x = x + self.drop_path(self.attn(self.norm1(x), H, W))
  167. x = x + self.drop_path(self.mlp(self.norm2(x)))
  168. return x
  169. class SBlock(ViTBlock):
  170. def __init__(self,
  171. dim,
  172. num_heads,
  173. mlp_ratio=4.,
  174. qkv_bias=False,
  175. qk_scale=None,
  176. drop=0.,
  177. attn_drop=0.,
  178. drop_path=0.,
  179. act_layer=nn.GELU,
  180. norm_layer=nn.LayerNorm,
  181. sr_ratio=1):
  182. super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
  183. attn_drop, drop_path, act_layer, norm_layer)
  184. def forward(self, x, H, W):
  185. return super().forward(x)
  186. class GroupBlock(ViTBlock):
  187. def __init__(self,
  188. dim,
  189. num_heads,
  190. mlp_ratio=4.,
  191. qkv_bias=False,
  192. qk_scale=None,
  193. drop=0.,
  194. attn_drop=0.,
  195. drop_path=0.,
  196. act_layer=nn.GELU,
  197. norm_layer=nn.LayerNorm,
  198. sr_ratio=1,
  199. ws=1):
  200. super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
  201. attn_drop, drop_path, act_layer, norm_layer)
  202. del self.attn
  203. if ws == 1:
  204. self.attn = Attention(dim, num_heads, qkv_bias, qk_scale,
  205. attn_drop, drop, sr_ratio)
  206. else:
  207. self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale,
  208. attn_drop, drop, ws)
  209. def forward(self, x, H, W):
  210. x = x + self.drop_path(self.attn(self.norm1(x), H, W))
  211. x = x + self.drop_path(self.mlp(self.norm2(x)))
  212. return x
  213. class PatchEmbed(nn.Layer):
  214. """ Image to Patch Embedding.
  215. """
  216. def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
  217. super().__init__()
  218. if img_size % patch_size != 0:
  219. raise Exception(
  220. f"img_size {img_size} should be divided by patch_size {patch_size}."
  221. )
  222. img_size = to_2tuple(img_size)
  223. patch_size = to_2tuple(patch_size)
  224. self.img_size = img_size
  225. self.patch_size = patch_size
  226. self.H, self.W = img_size[0] // patch_size[0], img_size[
  227. 1] // patch_size[1]
  228. self.num_patches = self.H * self.W
  229. self.proj = nn.Conv2D(
  230. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  231. self.norm = nn.LayerNorm(embed_dim)
  232. def forward(self, x):
  233. B, C, H, W = x.shape
  234. x = self.proj(x).flatten(2).transpose([0, 2, 1])
  235. x = self.norm(x)
  236. H, W = H // self.patch_size[0], W // self.patch_size[1]
  237. return x, (H, W)
# Borrowed from PVT: https://github.com/whai362/PVT.git
  239. class PyramidVisionTransformer(nn.Layer):
  240. def __init__(self,
  241. img_size=224,
  242. patch_size=16,
  243. in_chans=3,
  244. class_num=1000,
  245. embed_dims=[64, 128, 256, 512],
  246. num_heads=[1, 2, 4, 8],
  247. mlp_ratios=[4, 4, 4, 4],
  248. qkv_bias=False,
  249. qk_scale=None,
  250. drop_rate=0.,
  251. attn_drop_rate=0.,
  252. drop_path_rate=0.,
  253. norm_layer=nn.LayerNorm,
  254. depths=[3, 4, 6, 3],
  255. sr_ratios=[8, 4, 2, 1],
  256. block_cls=Block):
  257. super().__init__()
  258. self.class_num = class_num
  259. self.depths = depths
  260. # patch_embed
  261. self.patch_embeds = nn.LayerList()
  262. self.pos_embeds = nn.ParameterList()
  263. self.pos_drops = nn.LayerList()
  264. self.blocks = nn.LayerList()
  265. for i in range(len(depths)):
  266. if i == 0:
  267. self.patch_embeds.append(
  268. PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))
  269. else:
  270. self.patch_embeds.append(
  271. PatchEmbed(img_size // patch_size // 2**(i - 1), 2,
  272. embed_dims[i - 1], embed_dims[i]))
  273. patch_num = self.patch_embeds[i].num_patches + 1 if i == len(
  274. embed_dims) - 1 else self.patch_embeds[i].num_patches
  275. self.pos_embeds.append(
  276. self.create_parameter(
  277. shape=[1, patch_num, embed_dims[i]],
  278. default_initializer=zeros_))
  279. self.pos_drops.append(nn.Dropout(p=drop_rate))
  280. dpr = [
  281. x.numpy()[0]
  282. for x in paddle.linspace(0, drop_path_rate, sum(depths))
  283. ] # stochastic depth decay rule
  284. cur = 0
  285. for k in range(len(depths)):
  286. _block = nn.LayerList([
  287. block_cls(
  288. dim=embed_dims[k],
  289. num_heads=num_heads[k],
  290. mlp_ratio=mlp_ratios[k],
  291. qkv_bias=qkv_bias,
  292. qk_scale=qk_scale,
  293. drop=drop_rate,
  294. attn_drop=attn_drop_rate,
  295. drop_path=dpr[cur + i],
  296. norm_layer=norm_layer,
  297. sr_ratio=sr_ratios[k]) for i in range(depths[k])
  298. ])
  299. self.blocks.append(_block)
  300. cur += depths[k]
  301. self.norm = norm_layer(embed_dims[-1])
  302. # cls_token
  303. self.cls_token = self.create_parameter(
  304. shape=[1, 1, embed_dims[-1]],
  305. default_initializer=zeros_,
  306. attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
  307. # classification head
  308. self.head = nn.Linear(embed_dims[-1],
  309. class_num) if class_num > 0 else Identity()
  310. # init weights
  311. for pos_emb in self.pos_embeds:
  312. trunc_normal_(pos_emb)
  313. self.apply(self._init_weights)
  314. def _init_weights(self, m):
  315. if isinstance(m, nn.Linear):
  316. trunc_normal_(m.weight)
  317. if isinstance(m, nn.Linear) and m.bias is not None:
  318. zeros_(m.bias)
  319. elif isinstance(m, nn.LayerNorm):
  320. zeros_(m.bias)
  321. ones_(m.weight)
  322. def forward_features(self, x):
  323. B = x.shape[0]
  324. for i in range(len(self.depths)):
  325. x, (H, W) = self.patch_embeds[i](x)
  326. if i == len(self.depths) - 1:
  327. cls_tokens = self.cls_token.expand([B, -1, -1])
  328. x = paddle.concat([cls_tokens, x], dim=1)
  329. x = x + self.pos_embeds[i]
  330. x = self.pos_drops[i](x)
  331. for blk in self.blocks[i]:
  332. x = blk(x, H, W)
  333. if i < len(self.depths) - 1:
  334. x = x.reshape([B, H, W, -1]).transpose(
  335. [0, 3, 1, 2]).contiguous()
  336. x = self.norm(x)
  337. return x[:, 0]
  338. def forward(self, x):
  339. x = self.forward_features(x)
  340. x = self.head(x)
  341. return x
  342. # PEG from https://arxiv.org/abs/2102.10882
  343. class PosCNN(nn.Layer):
  344. def __init__(self, in_chans, embed_dim=768, s=1):
  345. super().__init__()
  346. self.proj = nn.Sequential(
  347. nn.Conv2D(
  348. in_chans,
  349. embed_dim,
  350. 3,
  351. s,
  352. 1,
  353. bias_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)),
  354. groups=embed_dim,
  355. weight_attr=paddle.ParamAttr(regularizer=L2Decay(0.0)), ))
  356. self.s = s
  357. def forward(self, x, H, W):
  358. B, N, C = x.shape
  359. feat_token = x
  360. cnn_feat = feat_token.transpose([0, 2, 1]).reshape([B, C, H, W])
  361. if self.s == 1:
  362. x = self.proj(cnn_feat) + cnn_feat
  363. else:
  364. x = self.proj(cnn_feat)
  365. x = x.flatten(2).transpose([0, 2, 1])
  366. return x
  367. class CPVTV2(PyramidVisionTransformer):
  368. """
  369. Use useful results from CPVT. PEG and GAP.
  370. Therefore, cls token is no longer required.
  371. PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution
  372. changes during the training (such as segmentation, detection)
  373. """
  374. def __init__(self,
  375. img_size=224,
  376. patch_size=4,
  377. in_chans=3,
  378. class_num=1000,
  379. embed_dims=[64, 128, 256, 512],
  380. num_heads=[1, 2, 4, 8],
  381. mlp_ratios=[4, 4, 4, 4],
  382. qkv_bias=False,
  383. qk_scale=None,
  384. drop_rate=0.,
  385. attn_drop_rate=0.,
  386. drop_path_rate=0.,
  387. norm_layer=nn.LayerNorm,
  388. depths=[3, 4, 6, 3],
  389. sr_ratios=[8, 4, 2, 1],
  390. block_cls=Block):
  391. super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
  392. num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
  393. attn_drop_rate, drop_path_rate, norm_layer, depths,
  394. sr_ratios, block_cls)
  395. del self.pos_embeds
  396. del self.cls_token
  397. self.pos_block = nn.LayerList(
  398. [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims])
  399. self.apply(self._init_weights)
  400. def _init_weights(self, m):
  401. import math
  402. if isinstance(m, nn.Linear):
  403. trunc_normal_(m.weight)
  404. if isinstance(m, nn.Linear) and m.bias is not None:
  405. zeros_(m.bias)
  406. elif isinstance(m, nn.LayerNorm):
  407. zeros_(m.bias)
  408. ones_(m.weight)
  409. elif isinstance(m, nn.Conv2D):
  410. fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  411. fan_out //= m._groups
  412. normal_(0, math.sqrt(2.0 / fan_out))(m.weight)
  413. if m.bias is not None:
  414. zeros_(m.bias)
  415. elif isinstance(m, nn.BatchNorm2D):
  416. m.weight.data.fill_(1.0)
  417. m.bias.data.zero_()
  418. def forward_features(self, x):
  419. B = x.shape[0]
  420. for i in range(len(self.depths)):
  421. x, (H, W) = self.patch_embeds[i](x)
  422. x = self.pos_drops[i](x)
  423. for j, blk in enumerate(self.blocks[i]):
  424. x = blk(x, H, W)
  425. if j == 0:
  426. x = self.pos_block[i](x, H, W) # PEG here
  427. if i < len(self.depths) - 1:
  428. x = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])
  429. x = self.norm(x)
  430. return x.mean(axis=1) # GAP here
  431. class PCPVT(CPVTV2):
  432. def __init__(self,
  433. img_size=224,
  434. patch_size=4,
  435. in_chans=3,
  436. class_num=1000,
  437. embed_dims=[64, 128, 256],
  438. num_heads=[1, 2, 4],
  439. mlp_ratios=[4, 4, 4],
  440. qkv_bias=False,
  441. qk_scale=None,
  442. drop_rate=0.,
  443. attn_drop_rate=0.,
  444. drop_path_rate=0.,
  445. norm_layer=nn.LayerNorm,
  446. depths=[4, 4, 4],
  447. sr_ratios=[4, 2, 1],
  448. block_cls=SBlock):
  449. super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
  450. num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
  451. attn_drop_rate, drop_path_rate, norm_layer, depths,
  452. sr_ratios, block_cls)
  453. class ALTGVT(PCPVT):
  454. """
  455. alias Twins-SVT
  456. """
  457. def __init__(self,
  458. img_size=224,
  459. patch_size=4,
  460. in_chans=3,
  461. class_num=1000,
  462. embed_dims=[64, 128, 256],
  463. num_heads=[1, 2, 4],
  464. mlp_ratios=[4, 4, 4],
  465. qkv_bias=False,
  466. qk_scale=None,
  467. drop_rate=0.,
  468. attn_drop_rate=0.,
  469. drop_path_rate=0.,
  470. norm_layer=nn.LayerNorm,
  471. depths=[4, 4, 4],
  472. sr_ratios=[4, 2, 1],
  473. block_cls=GroupBlock,
  474. wss=[7, 7, 7]):
  475. super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
  476. num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
  477. attn_drop_rate, drop_path_rate, norm_layer, depths,
  478. sr_ratios, block_cls)
  479. del self.blocks
  480. self.wss = wss
  481. # transformer encoder
  482. dpr = [
  483. x.numpy()[0]
  484. for x in paddle.linspace(0, drop_path_rate, sum(depths))
  485. ] # stochastic depth decay rule
  486. cur = 0
  487. self.blocks = nn.LayerList()
  488. for k in range(len(depths)):
  489. _block = nn.LayerList([
  490. block_cls(
  491. dim=embed_dims[k],
  492. num_heads=num_heads[k],
  493. mlp_ratio=mlp_ratios[k],
  494. qkv_bias=qkv_bias,
  495. qk_scale=qk_scale,
  496. drop=drop_rate,
  497. attn_drop=attn_drop_rate,
  498. drop_path=dpr[cur + i],
  499. norm_layer=norm_layer,
  500. sr_ratio=sr_ratios[k],
  501. ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])
  502. ])
  503. self.blocks.append(_block)
  504. cur += depths[k]
  505. self.apply(self._init_weights)
  506. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  507. if pretrained is False:
  508. pass
  509. elif pretrained is True:
  510. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  511. elif isinstance(pretrained, str):
  512. load_dygraph_pretrain(model, pretrained)
  513. else:
  514. raise RuntimeError(
  515. "pretrained type is not available. Please use `string` or `boolean` type."
  516. )
  517. def pcpvt_small(pretrained=False, use_ssld=False, **kwargs):
  518. model = CPVTV2(
  519. patch_size=4,
  520. embed_dims=[64, 128, 320, 512],
  521. num_heads=[1, 2, 5, 8],
  522. mlp_ratios=[8, 8, 4, 4],
  523. qkv_bias=True,
  524. norm_layer=partial(
  525. nn.LayerNorm, epsilon=1e-6),
  526. depths=[3, 4, 6, 3],
  527. sr_ratios=[8, 4, 2, 1],
  528. **kwargs)
  529. _load_pretrained(
  530. pretrained, model, MODEL_URLS["pcpvt_small"], use_ssld=use_ssld)
  531. return model
  532. def pcpvt_base(pretrained=False, use_ssld=False, **kwargs):
  533. model = CPVTV2(
  534. patch_size=4,
  535. embed_dims=[64, 128, 320, 512],
  536. num_heads=[1, 2, 5, 8],
  537. mlp_ratios=[8, 8, 4, 4],
  538. qkv_bias=True,
  539. norm_layer=partial(
  540. nn.LayerNorm, epsilon=1e-6),
  541. depths=[3, 4, 18, 3],
  542. sr_ratios=[8, 4, 2, 1],
  543. **kwargs)
  544. _load_pretrained(
  545. pretrained, model, MODEL_URLS["pcpvt_base"], use_ssld=use_ssld)
  546. return model
  547. def pcpvt_large(pretrained=False, use_ssld=False, **kwargs):
  548. model = CPVTV2(
  549. patch_size=4,
  550. embed_dims=[64, 128, 320, 512],
  551. num_heads=[1, 2, 5, 8],
  552. mlp_ratios=[8, 8, 4, 4],
  553. qkv_bias=True,
  554. norm_layer=partial(
  555. nn.LayerNorm, epsilon=1e-6),
  556. depths=[3, 8, 27, 3],
  557. sr_ratios=[8, 4, 2, 1],
  558. **kwargs)
  559. _load_pretrained(
  560. pretrained, model, MODEL_URLS["pcpvt_large"], use_ssld=use_ssld)
  561. return model
  562. def alt_gvt_small(pretrained=False, use_ssld=False, **kwargs):
  563. model = ALTGVT(
  564. patch_size=4,
  565. embed_dims=[64, 128, 256, 512],
  566. num_heads=[2, 4, 8, 16],
  567. mlp_ratios=[4, 4, 4, 4],
  568. qkv_bias=True,
  569. norm_layer=partial(
  570. nn.LayerNorm, epsilon=1e-6),
  571. depths=[2, 2, 10, 4],
  572. wss=[7, 7, 7, 7],
  573. sr_ratios=[8, 4, 2, 1],
  574. **kwargs)
  575. _load_pretrained(
  576. pretrained, model, MODEL_URLS["alt_gvt_small"], use_ssld=use_ssld)
  577. return model
  578. def alt_gvt_base(pretrained=False, use_ssld=False, **kwargs):
  579. model = ALTGVT(
  580. patch_size=4,
  581. embed_dims=[96, 192, 384, 768],
  582. num_heads=[3, 6, 12, 24],
  583. mlp_ratios=[4, 4, 4, 4],
  584. qkv_bias=True,
  585. norm_layer=partial(
  586. nn.LayerNorm, epsilon=1e-6),
  587. depths=[2, 2, 18, 2],
  588. wss=[7, 7, 7, 7],
  589. sr_ratios=[8, 4, 2, 1],
  590. **kwargs)
  591. _load_pretrained(
  592. pretrained, model, MODEL_URLS["alt_gvt_base"], use_ssld=use_ssld)
  593. return model
  594. def alt_gvt_large(pretrained=False, use_ssld=False, **kwargs):
  595. model = ALTGVT(
  596. patch_size=4,
  597. embed_dims=[128, 256, 512, 1024],
  598. num_heads=[4, 8, 16, 32],
  599. mlp_ratios=[4, 4, 4, 4],
  600. qkv_bias=True,
  601. norm_layer=partial(
  602. nn.LayerNorm, epsilon=1e-6),
  603. depths=[2, 2, 18, 2],
  604. wss=[7, 7, 7, 7],
  605. sr_ratios=[8, 4, 2, 1],
  606. **kwargs)
  607. _load_pretrained(
  608. pretrained, model, MODEL_URLS["alt_gvt_large"], use_ssld=use_ssld)
  609. return model