jde_embedding_head.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddlex.ppdet.core.workspace import register
from paddle.nn.initializer import Normal, Constant

__all__ = ['JDEEmbeddingHead']

class LossParam(nn.Layer):
    def __init__(self, init_value=0., use_uncertainty=True):
        super(LossParam, self).__init__()
        self.loss_param = self.create_parameter(
            shape=[1],
            attr=ParamAttr(initializer=Constant(value=init_value)),
            dtype="float32")

    def forward(self, inputs):
        out = paddle.exp(-self.loss_param) * inputs + self.loss_param
        return out * 0.5
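
# Note (editor's addition): LossParam appears to implement the learnable
# task-uncertainty loss weighting used by JDE, rescaling a task loss L as
# 0.5 * (exp(-s) * L + s) with a learnable scalar s. With the init values
# used further below, s = -4.15 weights the confidence loss by roughly
# exp(4.15) / 2 ~= 31.7, s = -4.85 weights the box loss by ~63.9, and
# s = -2.3 weights the identity loss by ~5.0 at the start of training.
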
@register
class JDEEmbeddingHead(nn.Layer):
    __shared__ = ['num_classes']
    __inject__ = ['emb_loss', 'jde_loss']
    """
    JDEEmbeddingHead

    Args:
        num_classes (int): Number of classes. Only single-class tracking is
            supported.
        num_identities (int): Number of identities in the dataset.
        anchor_levels (int): Number of anchor levels, equal to the number of
            FPN levels.
        anchor_scales (int): Number of anchor scales on each FPN level.
        embedding_dim (int): Embedding dimension. Default: 512.
        emb_loss (object): Instance of 'JDEEmbeddingLoss'.
        jde_loss (object): Instance of 'JDELoss'.
    """
    def __init__(
            self,
            num_classes=1,
            num_identities=14455,  # dataset.num_identities_dict[0]
            anchor_levels=3,
            anchor_scales=4,
            embedding_dim=512,
            emb_loss='JDEEmbeddingLoss',
            jde_loss='JDELoss'):
        super(JDEEmbeddingHead, self).__init__()
        self.num_classes = num_classes
        self.num_identities = num_identities
        self.anchor_levels = anchor_levels
        self.anchor_scales = anchor_scales
        self.embedding_dim = embedding_dim
        self.emb_loss = emb_loss
        self.jde_loss = jde_loss

        self.emb_scale = math.sqrt(2) * math.log(
            self.num_identities - 1) if self.num_identities > 1 else 1

        self.identify_outputs = []
        self.loss_params_cls = []
        self.loss_params_reg = []
        self.loss_params_ide = []
        for i in range(self.anchor_levels):
            name = 'identify_output.{}'.format(i)
            identify_output = self.add_sublayer(
                name,
                nn.Conv2D(
                    in_channels=64 * (2**self.anchor_levels) // (2**i),
                    out_channels=self.embedding_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias_attr=ParamAttr(regularizer=L2Decay(0.))))
            self.identify_outputs.append(identify_output)

            loss_p_cls = self.add_sublayer('cls.{}'.format(i),
                                           LossParam(-4.15))
            self.loss_params_cls.append(loss_p_cls)
            loss_p_reg = self.add_sublayer('reg.{}'.format(i),
                                           LossParam(-4.85))
            self.loss_params_reg.append(loss_p_reg)
            loss_p_ide = self.add_sublayer('ide.{}'.format(i), LossParam(-2.3))
            self.loss_params_ide.append(loss_p_ide)

        self.classifier = self.add_sublayer(
            'classifier',
            nn.Linear(
                self.embedding_dim,
                self.num_identities,
                weight_attr=ParamAttr(
                    learning_rate=1., initializer=Normal(
                        mean=0.0, std=0.01)),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.))))
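
    # Note (editor's addition): the per-level 3x3 convs above produce dense
    # embedding maps. During training they are supervised through `emb_loss`
    # together with `self.classifier`; at inference the classifier is unused
    # and per-anchor embeddings are gathered for the downstream tracker.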
    def forward(self,
                identify_feats,
                targets,
                loss_confs=None,
                loss_boxes=None,
                bboxes=None,
                boxes_idx=None,
                nms_keep_idx=None):
        assert self.num_classes == 1, 'JDE only supports single-class MOT.'
        assert len(identify_feats) == self.anchor_levels

        ide_outs = []
        for feat, ide_head in zip(identify_feats, self.identify_outputs):
            ide_outs.append(ide_head(feat))

        if self.training:
            assert len(loss_confs) == len(loss_boxes) == self.anchor_levels
            loss_ides = self.emb_loss(ide_outs, targets, self.emb_scale,
                                      self.classifier)
            jde_losses = self.jde_loss(
                loss_confs, loss_boxes, loss_ides, self.loss_params_cls,
                self.loss_params_reg, self.loss_params_ide, targets)
            return jde_losses
        else:
            assert bboxes is not None
            assert boxes_idx is not None
            assert nms_keep_idx is not None

            emb_outs = self.get_emb_outs(ide_outs)
            emb_valid = paddle.gather_nd(emb_outs, boxes_idx)
            pred_embs = paddle.gather_nd(emb_valid, nms_keep_idx)

            input_shape = targets['image'].shape[2:]
            # input_shape: [h, w], before data transforms, set in model config
            im_shape = targets['im_shape'][0].numpy()
            # im_shape: [new_h, new_w], after data transforms
            scale_factor = targets['scale_factor'][0].numpy()

            bboxes[:, 2:] = self.scale_coords(bboxes[:, 2:], input_shape,
                                              im_shape, scale_factor)
            # tlwhs, scores, cls_ids
            pred_dets = paddle.concat(
                (bboxes[:, 2:], bboxes[:, 1:2], bboxes[:, 0:1]), axis=1)
            return pred_dets, pred_embs
    def scale_coords(self, coords, input_shape, im_shape, scale_factor):
        ratio = scale_factor[0]
        pad_w = (input_shape[1] - int(im_shape[1])) / 2
        pad_h = (input_shape[0] - int(im_shape[0])) / 2
        coords = paddle.cast(coords, 'float32')
        coords[:, 0::2] -= pad_w
        coords[:, 1::2] -= pad_h
        coords[:, 0:4] /= ratio
        coords[:, :4] = paddle.clip(
            coords[:, :4], min=0, max=coords[:, :4].max())
        return coords.round()
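
    # Worked example for scale_coords (editor's sketch, assuming a
    # letterbox-style resize with symmetric padding): for an original frame of
    # 1080x1920 and scale_factor ~0.5, im_shape is [540, 960]; padding to a
    # hypothetical input_shape of [608, 960] gives pad_h = (608 - 540) / 2 = 34
    # and pad_w = 0. A predicted corner at (480, 304) in network coordinates
    # becomes (480 - 0, 304 - 34) = (480, 270), then / 0.5 -> (960, 540) in the
    # original frame.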
    def get_emb_and_gt_outs(self, ide_outs, targets):
        emb_and_gts = []
        for i, p_ide in enumerate(ide_outs):
            t_conf = targets['tconf{}'.format(i)]
            t_ide = targets['tide{}'.format(i)]

            p_ide = p_ide.transpose((0, 2, 3, 1))
            p_ide_flatten = paddle.reshape(p_ide, [-1, self.embedding_dim])

            mask = t_conf > 0
            mask = paddle.cast(mask, dtype="int64")
            emb_mask = mask.max(1).flatten()
            emb_mask_inds = paddle.nonzero(emb_mask > 0).flatten()
            if len(emb_mask_inds) > 0:
                t_ide_flatten = paddle.reshape(t_ide.max(1), [-1, 1])
                tids = paddle.gather(t_ide_flatten, emb_mask_inds)

                embedding = paddle.gather(p_ide_flatten, emb_mask_inds)
                embedding = self.emb_scale * F.normalize(embedding)
                emb_and_gt = paddle.concat([embedding, tids], axis=1)
                emb_and_gts.append(emb_and_gt)

        if len(emb_and_gts) > 0:
            return paddle.concat(emb_and_gts, axis=0)
        else:
            return paddle.zeros((1, self.embedding_dim + 1))
    def get_emb_outs(self, ide_outs):
        emb_outs = []
        for i, p_ide in enumerate(ide_outs):
            p_ide = p_ide.transpose((0, 2, 3, 1))

            p_ide_repeat = paddle.tile(p_ide, [self.anchor_scales, 1, 1, 1])
            embedding = F.normalize(p_ide_repeat, axis=-1)
            emb = paddle.reshape(embedding, [-1, self.embedding_dim])
            emb_outs.append(emb)

        if len(emb_outs) > 0:
            return paddle.concat(emb_outs, axis=0)
        else:
            return paddle.zeros((1, self.embedding_dim))
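

if __name__ == '__main__':
    # Minimal smoke-test sketch (editor's addition, not part of the original
    # module). It assumes paddle and paddlex are importable and only checks
    # tensor shapes; the feature-map sizes below are illustrative, not taken
    # from a real config.
    head = JDEEmbeddingHead(num_identities=1000)

    # One dummy FPN feature map per anchor level. Channel counts follow the
    # in_channels formula used in __init__: 64 * 2**3 // 2**i -> 512, 256, 128.
    feats = [
        paddle.rand([
            1, 64 * (2**head.anchor_levels) // (2**i), 19 * (2**i), 34 * (2**i)
        ]) for i in range(head.anchor_levels)
    ]
    ide_outs = [conv(feat) for conv, feat in zip(head.identify_outputs, feats)]

    embs = head.get_emb_outs(ide_outs)
    # Rows = anchor_scales * sum(h_i * w_i), columns = embedding_dim (512).
    print(embs.shape)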