fairmot_embedding_head.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import math
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from paddle.nn.initializer import KaimingUniform, Uniform
  20. from paddlex.ppdet.core.workspace import register
  21. from paddlex.ppdet.modeling.heads.centernet_head import ConvLayer
  22. __all__ = ['FairMOTEmbeddingHead']
  23. @register
  24. class FairMOTEmbeddingHead(nn.Layer):
  25. __shared__ = ['num_classes']
  26. """
  27. Args:
  28. in_channels (int): the channel number of input to FairMOTEmbeddingHead.
  29. ch_head (int): the channel of features before fed into embedding, 256 by default.
  30. ch_emb (int): the channel of the embedding feature, 128 by default.
  31. num_identities_dict (dict): the number of identities of each category,
  32. support single class and multi-calss, {0: 14455} as default.
  33. """
  34. def __init__(self,
  35. in_channels,
  36. ch_head=256,
  37. ch_emb=128,
  38. num_classes=1,
  39. num_identities_dict={0: 14455}):
  40. super(FairMOTEmbeddingHead, self).__init__()
  41. assert num_classes >= 1
  42. self.num_classes = num_classes
  43. self.ch_emb = ch_emb
  44. self.num_identities_dict = num_identities_dict
  45. self.reid = nn.Sequential(
  46. ConvLayer(
  47. in_channels, ch_head, kernel_size=3, padding=1, bias=True),
  48. nn.ReLU(),
  49. ConvLayer(
  50. ch_head, ch_emb, kernel_size=1, stride=1, padding=0,
  51. bias=True))
  52. param_attr = paddle.ParamAttr(initializer=KaimingUniform())
  53. bound = 1 / math.sqrt(ch_emb)
  54. bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
  55. self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
  56. if num_classes == 1:
  57. nID = self.num_identities_dict[0] # single class
  58. self.classifier = nn.Linear(
  59. ch_emb, nID, weight_attr=param_attr, bias_attr=bias_attr)
  60. # When num_identities(nID) is 1, emb_scale is set as 1
  61. self.emb_scale = math.sqrt(2) * math.log(nID - 1) if nID > 1 else 1
  62. else:
  63. self.classifiers = dict()
  64. self.emb_scale_dict = dict()
  65. for cls_id, nID in self.num_identities_dict.items():
  66. self.classifiers[str(cls_id)] = nn.Linear(
  67. ch_emb, nID, weight_attr=param_attr, bias_attr=bias_attr)
  68. # When num_identities(nID) is 1, emb_scale is set as 1
  69. self.emb_scale_dict[str(cls_id)] = math.sqrt(2) * math.log(
  70. nID - 1) if nID > 1 else 1
  71. @classmethod
  72. def from_config(cls, cfg, input_shape):
  73. if isinstance(input_shape, (list, tuple)):
  74. input_shape = input_shape[0]
  75. return {'in_channels': input_shape.channels}
  76. def process_by_class(self, bboxes, embedding, bbox_inds, topk_clses):
  77. pred_dets, pred_embs = [], []
  78. for cls_id in range(self.num_classes):
  79. inds_masks = topk_clses == cls_id
  80. inds_masks = paddle.cast(inds_masks, 'float32')
  81. pos_num = inds_masks.sum().numpy()
  82. if pos_num == 0:
  83. continue
  84. cls_inds_mask = inds_masks > 0
  85. bbox_mask = paddle.nonzero(cls_inds_mask)
  86. cls_bboxes = paddle.gather_nd(bboxes, bbox_mask)
  87. pred_dets.append(cls_bboxes)
  88. cls_inds = paddle.masked_select(bbox_inds, cls_inds_mask)
  89. cls_inds = cls_inds.unsqueeze(-1)
  90. cls_embedding = paddle.gather_nd(embedding, cls_inds)
  91. pred_embs.append(cls_embedding)
  92. return paddle.concat(pred_dets), paddle.concat(pred_embs)
  93. def forward(self,
  94. neck_feat,
  95. inputs,
  96. bboxes=None,
  97. bbox_inds=None,
  98. topk_clses=None):
  99. reid_feat = self.reid(neck_feat)
  100. if self.training:
  101. if self.num_classes == 1:
  102. loss = self.get_loss(reid_feat, inputs)
  103. else:
  104. loss = self.get_mc_loss(reid_feat, inputs)
  105. return loss
  106. else:
  107. assert bboxes is not None and bbox_inds is not None
  108. reid_feat = F.normalize(reid_feat)
  109. embedding = paddle.transpose(reid_feat, [0, 2, 3, 1])
  110. embedding = paddle.reshape(embedding, [-1, self.ch_emb])
  111. # embedding shape: [bs * h * w, ch_emb]
  112. if self.num_classes == 1:
  113. pred_dets = bboxes
  114. pred_embs = paddle.gather(embedding, bbox_inds)
  115. else:
  116. pred_dets, pred_embs = self.process_by_class(
  117. bboxes, embedding, bbox_inds, topk_clses)
  118. return pred_dets, pred_embs
  119. def get_loss(self, feat, inputs):
  120. index = inputs['index']
  121. mask = inputs['index_mask']
  122. target = inputs['reid']
  123. target = paddle.masked_select(target, mask > 0)
  124. target = paddle.unsqueeze(target, 1)
  125. feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
  126. feat_n, feat_h, feat_w, feat_c = feat.shape
  127. feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
  128. index = paddle.unsqueeze(index, 2)
  129. batch_inds = list()
  130. for i in range(feat_n):
  131. batch_ind = paddle.full(
  132. shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
  133. batch_inds.append(batch_ind)
  134. batch_inds = paddle.concat(batch_inds, axis=0)
  135. index = paddle.concat(x=[batch_inds, index], axis=2)
  136. feat = paddle.gather_nd(feat, index=index)
  137. mask = paddle.unsqueeze(mask, axis=2)
  138. mask = paddle.expand_as(mask, feat)
  139. mask.stop_gradient = True
  140. feat = paddle.masked_select(feat, mask > 0)
  141. feat = paddle.reshape(feat, shape=[-1, feat_c])
  142. feat = F.normalize(feat)
  143. feat = self.emb_scale * feat
  144. logit = self.classifier(feat)
  145. target.stop_gradient = True
  146. loss = self.reid_loss(logit, target)
  147. valid = (target != self.reid_loss.ignore_index)
  148. valid.stop_gradient = True
  149. count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
  150. count.stop_gradient = True
  151. if count > 0:
  152. loss = loss / count
  153. return loss
  154. def get_mc_loss(self, feat, inputs):
  155. # feat.shape = [bs, ch_emb, h, w]
  156. assert 'cls_id_map' in inputs and 'cls_tr_ids' in inputs
  157. index = inputs['index']
  158. mask = inputs['index_mask']
  159. cls_id_map = inputs['cls_id_map'] # [bs, h, w]
  160. cls_tr_ids = inputs['cls_tr_ids'] # [bs, num_classes, h, w]
  161. feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
  162. feat_n, feat_h, feat_w, feat_c = feat.shape
  163. feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
  164. index = paddle.unsqueeze(index, 2)
  165. batch_inds = list()
  166. for i in range(feat_n):
  167. batch_ind = paddle.full(
  168. shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
  169. batch_inds.append(batch_ind)
  170. batch_inds = paddle.concat(batch_inds, axis=0)
  171. index = paddle.concat(x=[batch_inds, index], axis=2)
  172. feat = paddle.gather_nd(feat, index=index)
  173. mask = paddle.unsqueeze(mask, axis=2)
  174. mask = paddle.expand_as(mask, feat)
  175. mask.stop_gradient = True
  176. feat = paddle.masked_select(feat, mask > 0)
  177. feat = paddle.reshape(feat, shape=[-1, feat_c])
  178. reid_losses = 0
  179. for cls_id, id_num in self.num_identities_dict.items():
  180. # target
  181. cur_cls_tr_ids = paddle.reshape(
  182. cls_tr_ids[:, cls_id, :, :], shape=[feat_n, -1]) # [bs, h*w]
  183. cls_id_target = paddle.gather_nd(cur_cls_tr_ids, index=index)
  184. mask = inputs['index_mask']
  185. cls_id_target = paddle.masked_select(cls_id_target, mask > 0)
  186. cls_id_target.stop_gradient = True
  187. # feat
  188. cls_id_feat = self.emb_scale_dict[str(cls_id)] * F.normalize(feat)
  189. cls_id_pred = self.classifiers[str(cls_id)](cls_id_feat)
  190. loss = self.reid_loss(cls_id_pred, cls_id_target)
  191. valid = (cls_id_target != self.reid_loss.ignore_index)
  192. valid.stop_gradient = True
  193. count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
  194. count.stop_gradient = True
  195. if count > 0:
  196. loss = loss / count
  197. reid_losses += loss
  198. return reid_losses