# detr_head.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlex.ppdet.core.workspace import register
import pycocotools.mask as mask_util

from ..initializer import linear_init_, constant_
from ..transformers.utils import inverse_sigmoid

__all__ = ['DETRHead', 'DeformableDETRHead']


class MLP(nn.Layer):
    """This code is based on
    https://github.com/facebookresearch/detr/blob/main/models/detr.py
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.LayerList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self._reset_parameters()

    def _reset_parameters(self):
        for l in self.layers:
            linear_init_(l)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
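
    # Note (illustrative): with the arguments used by the heads below, e.g.
    # MLP(hidden_dim, hidden_dim, output_dim=4, num_layers=3), this stacks two
    # Linear+ReLU layers and a final Linear, mapping [..., hidden_dim]
    # features to [..., 4] box parameters.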


class MultiHeadAttentionMap(nn.Layer):
    """This code is based on
    https://github.com/facebookresearch/detr/blob/main/models/segmentation.py

    This is a 2D attention module, which only returns the attention softmax
    (no multiplication by value).
    """

    def __init__(self,
                 query_dim,
                 hidden_dim,
                 num_heads,
                 dropout=0.0,
                 bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.XavierUniform())
        bias_attr = paddle.framework.ParamAttr(
            initializer=paddle.nn.initializer.Constant()) if bias else False

        self.q_proj = nn.Linear(query_dim, hidden_dim, weight_attr, bias_attr)
        self.k_proj = nn.Conv2D(
            query_dim,
            hidden_dim,
            1,
            weight_attr=weight_attr,
            bias_attr=bias_attr)
        self.normalize_fact = float(hidden_dim / self.num_heads)**-0.5

    def forward(self, q, k, mask=None):
        q = self.q_proj(q)
        k = self.k_proj(k)
        bs, num_queries, n, c, h, w = q.shape[0], q.shape[1], self.num_heads,\
            self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]
        qh = q.reshape([bs, num_queries, n, c])
        kh = k.reshape([bs, n, c, h, w])
        # weights = paddle.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
        qh = qh.transpose([0, 2, 1, 3]).reshape([-1, num_queries, c])
        kh = kh.reshape([-1, c, h * w])
        weights = paddle.bmm(qh * self.normalize_fact, kh).reshape(
            [bs, n, num_queries, h, w]).transpose([0, 2, 1, 3, 4])

        if mask is not None:
            weights += mask
        # fix a potential bug: https://github.com/facebookresearch/detr/issues/247
        weights = F.softmax(weights.flatten(3), axis=-1).reshape(weights.shape)
        weights = self.dropout(weights)
        return weights
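
    # Shape walkthrough: q enters as [bs, num_queries, query_dim] and k as
    # [bs, query_dim, h, w]; after the projections and reshapes, qh is
    # [bs * num_heads, num_queries, c] and kh is [bs * num_heads, c, h * w],
    # so the batched matmul reproduces the commented einsum
    # "bqnc,bnchw->bqnhw" and yields per-query, per-head attention maps of
    # shape [bs, num_queries, num_heads, h, w].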


class MaskHeadFPNConv(nn.Layer):
    """This code is based on
    https://github.com/facebookresearch/detr/blob/main/models/segmentation.py

    Simple convolutional head, using group norm.
    Upsampling is done using a FPN approach.
    """

    def __init__(self, input_dim, fpn_dims, context_dim, num_groups=8):
        super().__init__()
        inter_dims = [input_dim,
                      ] + [context_dim // (2**i) for i in range(1, 5)]
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.KaimingUniform())
        bias_attr = paddle.framework.ParamAttr(
            initializer=paddle.nn.initializer.Constant())

        self.conv0 = self._make_layers(input_dim, input_dim, 3, num_groups,
                                       weight_attr, bias_attr)
        self.conv_inter = nn.LayerList()
        for in_dims, out_dims in zip(inter_dims[:-1], inter_dims[1:]):
            self.conv_inter.append(
                self._make_layers(in_dims, out_dims, 3, num_groups,
                                  weight_attr, bias_attr))

        self.conv_out = nn.Conv2D(
            inter_dims[-1],
            1,
            3,
            padding=1,
            weight_attr=weight_attr,
            bias_attr=bias_attr)

        self.adapter = nn.LayerList()
        for i in range(len(fpn_dims)):
            self.adapter.append(
                nn.Conv2D(
                    fpn_dims[i],
                    inter_dims[i + 1],
                    1,
                    weight_attr=weight_attr,
                    bias_attr=bias_attr))

    def _make_layers(self,
                     in_dims,
                     out_dims,
                     kernel_size,
                     num_groups,
                     weight_attr=None,
                     bias_attr=None):
        return nn.Sequential(
            nn.Conv2D(
                in_dims,
                out_dims,
                kernel_size,
                padding=kernel_size // 2,
                weight_attr=weight_attr,
                bias_attr=bias_attr),
            nn.GroupNorm(num_groups, out_dims),
            nn.ReLU())

    def forward(self, x, bbox_attention_map, fpns):
        x = paddle.concat([
            x.tile([bbox_attention_map.shape[1], 1, 1, 1]),
            bbox_attention_map.flatten(0, 1)
        ], 1)
        x = self.conv0(x)
        for inter_layer, adapter_layer, feat in zip(self.conv_inter[:-1],
                                                    self.adapter, fpns):
            feat = adapter_layer(feat).tile(
                [bbox_attention_map.shape[1], 1, 1, 1])
            x = inter_layer(x)
            x = feat + F.interpolate(x, size=feat.shape[-2:])

        x = self.conv_inter[-1](x)
        x = self.conv_out(x)
        return x
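
    # The projected encoder feature is tiled once per query, concatenated with
    # the per-query attention maps, and then refined top-down: each adapted
    # backbone feature is added to the upsampled intermediate result before
    # the final 3x3 conv produces a single-channel mask logit map per query.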


@register
class DETRHead(nn.Layer):
    __shared__ = ['num_classes', 'hidden_dim', 'use_focal_loss']
    __inject__ = ['loss']

    def __init__(self,
                 num_classes=80,
                 hidden_dim=256,
                 nhead=8,
                 num_mlp_layers=3,
                 loss='DETRLoss',
                 fpn_dims=[1024, 512, 256],
                 with_mask_head=False,
                 use_focal_loss=False):
        super(DETRHead, self).__init__()
        # add background class
        self.num_classes = num_classes if use_focal_loss else num_classes + 1
        self.hidden_dim = hidden_dim
        self.loss = loss
        self.with_mask_head = with_mask_head
        self.use_focal_loss = use_focal_loss

        self.score_head = nn.Linear(hidden_dim, self.num_classes)
        self.bbox_head = MLP(hidden_dim,
                             hidden_dim,
                             output_dim=4,
                             num_layers=num_mlp_layers)
        if self.with_mask_head:
            self.bbox_attention = MultiHeadAttentionMap(hidden_dim, hidden_dim,
                                                        nhead)
            self.mask_head = MaskHeadFPNConv(hidden_dim + nhead, fpn_dims,
                                             hidden_dim)
        self._reset_parameters()

    def _reset_parameters(self):
        linear_init_(self.score_head)

    @classmethod
    def from_config(cls, cfg, hidden_dim, nhead, input_shape):
        return {
            'hidden_dim': hidden_dim,
            'nhead': nhead,
            'fpn_dims': [i.channels for i in input_shape[::-1]][1:]
        }
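
    # `input_shape` is reversed and the deepest level dropped, so `fpn_dims`
    # lists the remaining backbone channels from deep to shallow (the same
    # coarse-to-fine order the mask head consumes in `forward`); the default
    # [1024, 512, 256] corresponds to a ResNet-50-style backbone.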

    @staticmethod
    def get_gt_mask_from_polygons(gt_poly, pad_mask):
        out_gt_mask = []
        for polygons, padding in zip(gt_poly, pad_mask):
            height, width = int(padding[:, 0].sum()), int(padding[0, :].sum())
            masks = []
            for obj_poly in polygons:
                rles = mask_util.frPyObjects(obj_poly, height, width)
                rle = mask_util.merge(rles)
                masks.append(
                    paddle.to_tensor(mask_util.decode(rle)).astype('float32'))
            masks = paddle.stack(masks)
            masks_pad = paddle.zeros(
                [masks.shape[0], pad_mask.shape[1], pad_mask.shape[2]])
            masks_pad[:, :height, :width] = masks
            out_gt_mask.append(masks_pad)
        return out_gt_mask
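
    # Ground-truth masks are rasterized from COCO-style polygons with
    # pycocotools (frPyObjects -> merge -> decode) at the unpadded image size
    # recovered from `pad_mask`, then zero-padded to the padded batch shape so
    # they align with the predicted masks.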

    def forward(self, out_transformer, body_feats, inputs=None):
        r"""
        Args:
            out_transformer (Tuple): (feats: [num_levels, batch_size,
                                                num_queries, hidden_dim],
                            memory: [batch_size, hidden_dim, h, w],
                            src_proj: [batch_size, h*w, hidden_dim],
                            src_mask: [batch_size, 1, 1, h, w])
            body_feats (List(Tensor)): list[[B, C, H, W]]
            inputs (dict): dict(inputs)
        """
        feats, memory, src_proj, src_mask = out_transformer
        outputs_logit = self.score_head(feats)
        outputs_bbox = F.sigmoid(self.bbox_head(feats))
        outputs_seg = None

        if self.with_mask_head:
            bbox_attention_map = self.bbox_attention(feats[-1], memory,
                                                     src_mask)
            fpn_feats = [a for a in body_feats[::-1]][1:]
            outputs_seg = self.mask_head(src_proj, bbox_attention_map,
                                         fpn_feats)
            outputs_seg = outputs_seg.reshape([
                feats.shape[1], feats.shape[2], outputs_seg.shape[-2],
                outputs_seg.shape[-1]
            ])

        if self.training:
            assert inputs is not None
            assert 'gt_bbox' in inputs and 'gt_class' in inputs
            gt_mask = self.get_gt_mask_from_polygons(
                inputs['gt_poly'],
                inputs['pad_mask']) if 'gt_poly' in inputs else None
            return self.loss(
                outputs_bbox,
                outputs_logit,
                inputs['gt_bbox'],
                inputs['gt_class'],
                masks=outputs_seg,
                gt_mask=gt_mask)
        else:
            return (outputs_bbox[-1], outputs_logit[-1], outputs_seg)
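
    # During training, predictions from every decoder level are handed to the
    # loss (which can then supervise intermediate levels as auxiliary
    # outputs); at inference only the last level, outputs_bbox[-1] and
    # outputs_logit[-1], is returned.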


@register
class DeformableDETRHead(nn.Layer):
    __shared__ = ['num_classes', 'hidden_dim']
    __inject__ = ['loss']

    def __init__(self,
                 num_classes=80,
                 hidden_dim=512,
                 nhead=8,
                 num_mlp_layers=3,
                 loss='DETRLoss'):
        super(DeformableDETRHead, self).__init__()
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.nhead = nhead
        self.loss = loss

        self.score_head = nn.Linear(hidden_dim, self.num_classes)
        self.bbox_head = MLP(hidden_dim,
                             hidden_dim,
                             output_dim=4,
                             num_layers=num_mlp_layers)
        self._reset_parameters()

    def _reset_parameters(self):
        linear_init_(self.score_head)
        constant_(self.score_head.bias, -4.595)
        constant_(self.bbox_head.layers[-1].weight)

        with paddle.no_grad():
            bias = paddle.zeros_like(self.bbox_head.layers[-1].bias)
            bias[2:] = -2.0
            self.bbox_head.layers[-1].bias.set_value(bias)
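
    # Note on the constants above: -4.595 is approximately log(0.01 / 0.99),
    # i.e. the classification logits start at a prior probability of roughly
    # 0.01 (the usual focal-loss initialization), and zeroing the last bbox
    # layer's weight while setting bias[2:] = -2.0 keeps the initial predicted
    # widths/heights small after the sigmoid.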

    @classmethod
    def from_config(cls, cfg, hidden_dim, nhead, input_shape):
        return {'hidden_dim': hidden_dim, 'nhead': nhead}

    def forward(self, out_transformer, body_feats, inputs=None):
        r"""
        Args:
            out_transformer (Tuple): (feats: [num_levels, batch_size,
                                                num_queries, hidden_dim],
                            memory: [batch_size,
                                \sum_{l=0}^{L-1} H_l \cdot W_l, hidden_dim],
                            reference_points: [batch_size, num_queries, 2])
            body_feats (List(Tensor)): list[[B, C, H, W]]
            inputs (dict): dict(inputs)
        """
        feats, memory, reference_points = out_transformer
        reference_points = inverse_sigmoid(reference_points.unsqueeze(0))
        outputs_bbox = self.bbox_head(feats)

        # Equivalent to "outputs_bbox[:, :, :, :2] += reference_points", but
        # the in-place slice assignment produces wrong gradients in paddle,
        # so concat is used instead.
        outputs_bbox = paddle.concat(
            [
                outputs_bbox[:, :, :, :2] + reference_points,
                outputs_bbox[:, :, :, 2:]
            ],
            axis=-1)
        outputs_bbox = F.sigmoid(outputs_bbox)
        outputs_logit = self.score_head(feats)

        if self.training:
            assert inputs is not None
            assert 'gt_bbox' in inputs and 'gt_class' in inputs
            return self.loss(outputs_bbox, outputs_logit, inputs['gt_bbox'],
                             inputs['gt_class'])
        else:
            return (outputs_bbox[-1], outputs_logit[-1], None)
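

# Minimal eval-mode smoke test (an illustrative sketch, not part of the
# config-driven pipeline: the `loss` object is normally injected from the YAML
# config and is left as the default string here, since it is never called in
# eval mode).
if __name__ == '__main__':
    head = DeformableDETRHead(num_classes=80, hidden_dim=256)
    head.eval()
    num_levels, bs, num_queries = 6, 2, 300
    feats = paddle.rand([num_levels, bs, num_queries, 256])
    memory = paddle.rand([bs, 1000, 256])  # flattened multi-scale memory
    reference_points = paddle.rand([bs, num_queries, 2])
    boxes, logits, _ = head((feats, memory, reference_points), body_feats=None)
    print(boxes.shape, logits.shape)  # [2, 300, 4] [2, 300, 80]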