tnt.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch
  15. import math
  16. import numpy as np
  17. import paddle
  18. import paddle.nn as nn
  19. from paddle.nn.initializer import TruncatedNormal, Constant
  20. from paddlex.ppcls.arch.backbone.base.theseus_layer import Identity
  21. from paddlex.ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
  22. MODEL_URLS = {
  23. "TNT_small":
  24. "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/TNT_small_pretrained.pdparams"
  25. }
  26. __all__ = MODEL_URLS.keys()
  27. trunc_normal_ = TruncatedNormal(std=.02)
  28. zeros_ = Constant(value=0.)
  29. ones_ = Constant(value=1.)
  30. def drop_path(x, drop_prob=0., training=False):
  31. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  32. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  33. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  34. """
  35. if drop_prob == 0. or not training:
  36. return x
  37. keep_prob = paddle.to_tensor(1 - drop_prob)
  38. shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
  39. random_tensor = paddle.add(keep_prob, paddle.rand(shape, dtype=x.dtype))
  40. random_tensor = paddle.floor(random_tensor) # binarize
  41. output = x.divide(keep_prob) * random_tensor
  42. return output
  43. class DropPath(nn.Layer):
  44. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  45. """
  46. def __init__(self, drop_prob=None):
  47. super(DropPath, self).__init__()
  48. self.drop_prob = drop_prob
  49. def forward(self, x):
  50. return drop_path(x, self.drop_prob, self.training)
  51. class Mlp(nn.Layer):
  52. def __init__(self,
  53. in_features,
  54. hidden_features=None,
  55. out_features=None,
  56. act_layer=nn.GELU,
  57. drop=0.):
  58. super().__init__()
  59. out_features = out_features or in_features
  60. hidden_features = hidden_features or in_features
  61. self.fc1 = nn.Linear(in_features, hidden_features)
  62. self.act = act_layer()
  63. self.fc2 = nn.Linear(hidden_features, out_features)
  64. self.drop = nn.Dropout(drop)
  65. def forward(self, x):
  66. x = self.fc1(x)
  67. x = self.act(x)
  68. x = self.drop(x)
  69. x = self.fc2(x)
  70. x = self.drop(x)
  71. return x
  72. class Attention(nn.Layer):
  73. def __init__(self,
  74. dim,
  75. hidden_dim,
  76. num_heads=8,
  77. qkv_bias=False,
  78. attn_drop=0.,
  79. proj_drop=0.):
  80. super().__init__()
  81. self.hidden_dim = hidden_dim
  82. self.num_heads = num_heads
  83. head_dim = hidden_dim // num_heads
  84. self.head_dim = head_dim
  85. self.scale = head_dim**-0.5
  86. self.qk = nn.Linear(dim, hidden_dim * 2, bias_attr=qkv_bias)
  87. self.v = nn.Linear(dim, dim, bias_attr=qkv_bias)
  88. self.attn_drop = nn.Dropout(attn_drop)
  89. self.proj = nn.Linear(dim, dim)
  90. self.proj_drop = nn.Dropout(proj_drop)
  91. def forward(self, x):
  92. B, N, C = x.shape
  93. qk = self.qk(x).reshape(
  94. (B, N, 2, self.num_heads, self.head_dim)).transpose(
  95. (2, 0, 3, 1, 4))
  96. q, k = qk[0], qk[1]
  97. v = self.v(x).reshape(
  98. (B, N, self.num_heads, x.shape[-1] // self.num_heads)).transpose(
  99. (0, 2, 1, 3))
  100. attn = paddle.matmul(q, k.transpose((0, 1, 3, 2))) * self.scale
  101. attn = nn.functional.softmax(attn, axis=-1)
  102. attn = self.attn_drop(attn)
  103. x = paddle.matmul(attn, v)
  104. x = x.transpose((0, 2, 1, 3)).reshape(
  105. (B, N, x.shape[-1] * x.shape[-3]))
  106. x = self.proj(x)
  107. x = self.proj_drop(x)
  108. return x
  109. class Block(nn.Layer):
  110. def __init__(self,
  111. dim,
  112. in_dim,
  113. num_pixel,
  114. num_heads=12,
  115. in_num_head=4,
  116. mlp_ratio=4.,
  117. qkv_bias=False,
  118. drop=0.,
  119. attn_drop=0.,
  120. drop_path=0.,
  121. act_layer=nn.GELU,
  122. norm_layer=nn.LayerNorm):
  123. super().__init__()
  124. # Inner transformer
  125. self.norm_in = norm_layer(in_dim)
  126. self.attn_in = Attention(
  127. in_dim,
  128. in_dim,
  129. num_heads=in_num_head,
  130. qkv_bias=qkv_bias,
  131. attn_drop=attn_drop,
  132. proj_drop=drop)
  133. self.norm_mlp_in = norm_layer(in_dim)
  134. self.mlp_in = Mlp(in_features=in_dim,
  135. hidden_features=int(in_dim * 4),
  136. out_features=in_dim,
  137. act_layer=act_layer,
  138. drop=drop)
  139. self.norm1_proj = norm_layer(in_dim)
  140. self.proj = nn.Linear(in_dim * num_pixel, dim)
  141. # Outer transformer
  142. self.norm_out = norm_layer(dim)
  143. self.attn_out = Attention(
  144. dim,
  145. dim,
  146. num_heads=num_heads,
  147. qkv_bias=qkv_bias,
  148. attn_drop=attn_drop,
  149. proj_drop=drop)
  150. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  151. self.norm_mlp = norm_layer(dim)
  152. self.mlp = Mlp(in_features=dim,
  153. hidden_features=int(dim * mlp_ratio),
  154. out_features=dim,
  155. act_layer=act_layer,
  156. drop=drop)
  157. def forward(self, pixel_embed, patch_embed):
  158. # inner
  159. pixel_embed = paddle.add(
  160. pixel_embed,
  161. self.drop_path(self.attn_in(self.norm_in(pixel_embed))))
  162. pixel_embed = paddle.add(
  163. pixel_embed,
  164. self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))))
  165. # outer
  166. B, N, C = patch_embed.shape
  167. norm1_proj = self.norm1_proj(pixel_embed)
  168. norm1_proj = norm1_proj.reshape(
  169. (B, N - 1, norm1_proj.shape[1] * norm1_proj.shape[2]))
  170. patch_embed[:, 1:] = paddle.add(patch_embed[:, 1:],
  171. self.proj(norm1_proj))
  172. patch_embed = paddle.add(
  173. patch_embed,
  174. self.drop_path(self.attn_out(self.norm_out(patch_embed))))
  175. patch_embed = paddle.add(
  176. patch_embed, self.drop_path(self.mlp(self.norm_mlp(patch_embed))))
  177. return pixel_embed, patch_embed
  178. class PixelEmbed(nn.Layer):
  179. def __init__(self,
  180. img_size=224,
  181. patch_size=16,
  182. in_chans=3,
  183. in_dim=48,
  184. stride=4):
  185. super().__init__()
  186. num_patches = (img_size // patch_size)**2
  187. self.img_size = img_size
  188. self.num_patches = num_patches
  189. self.in_dim = in_dim
  190. new_patch_size = math.ceil(patch_size / stride)
  191. self.new_patch_size = new_patch_size
  192. self.proj = nn.Conv2D(
  193. in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
  194. def forward(self, x, pixel_pos):
  195. B, C, H, W = x.shape
  196. assert H == self.img_size and W == self.img_size, f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
  197. x = self.proj(x)
  198. x = nn.functional.unfold(x, self.new_patch_size, self.new_patch_size)
  199. x = x.transpose((0, 2, 1)).reshape(
  200. (-1, self.in_dim, self.new_patch_size, self.new_patch_size))
  201. x = x + pixel_pos
  202. x = x.reshape((-1, self.in_dim, self.new_patch_size *
  203. self.new_patch_size)).transpose((0, 2, 1))
  204. return x
  205. class TNT(nn.Layer):
  206. def __init__(self,
  207. img_size=224,
  208. patch_size=16,
  209. in_chans=3,
  210. embed_dim=768,
  211. in_dim=48,
  212. depth=12,
  213. num_heads=12,
  214. in_num_head=4,
  215. mlp_ratio=4.,
  216. qkv_bias=False,
  217. drop_rate=0.,
  218. attn_drop_rate=0.,
  219. drop_path_rate=0.,
  220. norm_layer=nn.LayerNorm,
  221. first_stride=4,
  222. class_num=1000):
  223. super().__init__()
  224. self.class_num = class_num
  225. # num_features for consistency with other models
  226. self.num_features = self.embed_dim = embed_dim
  227. self.pixel_embed = PixelEmbed(
  228. img_size=img_size,
  229. patch_size=patch_size,
  230. in_chans=in_chans,
  231. in_dim=in_dim,
  232. stride=first_stride)
  233. num_patches = self.pixel_embed.num_patches
  234. self.num_patches = num_patches
  235. new_patch_size = self.pixel_embed.new_patch_size
  236. num_pixel = new_patch_size**2
  237. self.norm1_proj = norm_layer(num_pixel * in_dim)
  238. self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
  239. self.norm2_proj = norm_layer(embed_dim)
  240. self.cls_token = self.create_parameter(
  241. shape=(1, 1, embed_dim), default_initializer=zeros_)
  242. self.add_parameter("cls_token", self.cls_token)
  243. self.patch_pos = self.create_parameter(
  244. shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_)
  245. self.add_parameter("patch_pos", self.patch_pos)
  246. self.pixel_pos = self.create_parameter(
  247. shape=(1, in_dim, new_patch_size, new_patch_size),
  248. default_initializer=zeros_)
  249. self.add_parameter("pixel_pos", self.pixel_pos)
  250. self.pos_drop = nn.Dropout(p=drop_rate)
  251. # stochastic depth decay rule
  252. dpr = np.linspace(0, drop_path_rate, depth)
  253. blocks = []
  254. for i in range(depth):
  255. blocks.append(
  256. Block(
  257. dim=embed_dim,
  258. in_dim=in_dim,
  259. num_pixel=num_pixel,
  260. num_heads=num_heads,
  261. in_num_head=in_num_head,
  262. mlp_ratio=mlp_ratio,
  263. qkv_bias=qkv_bias,
  264. drop=drop_rate,
  265. attn_drop=attn_drop_rate,
  266. drop_path=dpr[i],
  267. norm_layer=norm_layer))
  268. self.blocks = nn.LayerList(blocks)
  269. self.norm = norm_layer(embed_dim)
  270. if class_num > 0:
  271. self.head = nn.Linear(embed_dim, class_num)
  272. trunc_normal_(self.cls_token)
  273. trunc_normal_(self.patch_pos)
  274. trunc_normal_(self.pixel_pos)
  275. self.apply(self._init_weights)
  276. def _init_weights(self, m):
  277. if isinstance(m, nn.Linear):
  278. trunc_normal_(m.weight)
  279. if isinstance(m, nn.Linear) and m.bias is not None:
  280. zeros_(m.bias)
  281. elif isinstance(m, nn.LayerNorm):
  282. zeros_(m.bias)
  283. ones_(m.weight)
  284. def forward_features(self, x):
  285. B = paddle.shape(x)[0]
  286. pixel_embed = self.pixel_embed(x, self.pixel_pos)
  287. patch_embed = self.norm2_proj(
  288. self.proj(
  289. self.norm1_proj(
  290. pixel_embed.reshape((-1, self.num_patches, pixel_embed.
  291. shape[-1] * pixel_embed.shape[-2])))))
  292. patch_embed = paddle.concat(
  293. (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
  294. patch_embed = patch_embed + self.patch_pos
  295. patch_embed = self.pos_drop(patch_embed)
  296. for blk in self.blocks:
  297. pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
  298. patch_embed = self.norm(patch_embed)
  299. return patch_embed[:, 0]
  300. def forward(self, x):
  301. x = self.forward_features(x)
  302. if self.class_num > 0:
  303. x = self.head(x)
  304. return x
  305. def _load_pretrained(pretrained, model, model_url, use_ssld=False):
  306. if pretrained is False:
  307. pass
  308. elif pretrained is True:
  309. load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
  310. elif isinstance(pretrained, str):
  311. load_dygraph_pretrain(model, pretrained)
  312. else:
  313. raise RuntimeError(
  314. "pretrained type is not available. Please use `string` or `boolean` type."
  315. )
  316. def TNT_small(pretrained=False, **kwargs):
  317. model = TNT(patch_size=16,
  318. embed_dim=384,
  319. in_dim=24,
  320. depth=12,
  321. num_heads=6,
  322. in_num_head=4,
  323. qkv_bias=False,
  324. **kwargs)
  325. _load_pretrained(pretrained, model, MODEL_URLS["TNT_small"])
  326. return model