deepsort_matching.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/nwojke/deep_sort/tree/master/deep_sort
  16. """
  17. import numpy as np
  18. from scipy.optimize import linear_sum_assignment
  19. from ..motion import kalman_filter
  20. INFTY_COST = 1e+5
  21. __all__ = [
  22. 'iou_1toN',
  23. 'iou_cost',
  24. '_nn_euclidean_distance',
  25. '_nn_cosine_distance',
  26. 'NearestNeighborDistanceMetric',
  27. 'min_cost_matching',
  28. 'matching_cascade',
  29. 'gate_cost_matrix',
  30. ]
  31. def iou_1toN(bbox, candidates):
  32. """
  33. Computer intersection over union (IoU) by one box to N candidates.
  34. Args:
  35. bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
  36. candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
  37. same format as `bbox`.
  38. Returns:
  39. ious (ndarray): The intersection over union in [0, 1] between the `bbox`
  40. and each candidate. A higher score means a larger fraction of the
  41. `bbox` is occluded by the candidate.
  42. """
  43. bbox_tl = bbox[:2]
  44. bbox_br = bbox[:2] + bbox[2:]
  45. candidates_tl = candidates[:, :2]
  46. candidates_br = candidates[:, :2] + candidates[:, 2:]
  47. tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
  48. np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
  49. br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
  50. np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
  51. wh = np.maximum(0., br - tl)
  52. area_intersection = wh.prod(axis=1)
  53. area_bbox = bbox[2:].prod()
  54. area_candidates = candidates[:, 2:].prod(axis=1)
  55. ious = area_intersection / (
  56. area_bbox + area_candidates - area_intersection)
  57. return ious
  58. def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
  59. """
  60. IoU distance metric.
  61. Args:
  62. tracks (list[Track]): A list of tracks.
  63. detections (list[Detection]): A list of detections.
  64. track_indices (Optional[list[int]]): A list of indices to tracks that
  65. should be matched. Defaults to all `tracks`.
  66. detection_indices (Optional[list[int]]): A list of indices to detections
  67. that should be matched. Defaults to all `detections`.
  68. Returns:
  69. cost_matrix (ndarray): A cost matrix of shape len(track_indices),
  70. len(detection_indices) where entry (i, j) is
  71. `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
  72. """
  73. if track_indices is None:
  74. track_indices = np.arange(len(tracks))
  75. if detection_indices is None:
  76. detection_indices = np.arange(len(detections))
  77. cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
  78. for row, track_idx in enumerate(track_indices):
  79. if tracks[track_idx].time_since_update > 1:
  80. cost_matrix[row, :] = 1e+5
  81. continue
  82. bbox = tracks[track_idx].to_tlwh()
  83. candidates = np.asarray(
  84. [detections[i].tlwh for i in detection_indices])
  85. cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
  86. return cost_matrix
  87. def _nn_euclidean_distance(s, q):
  88. """
  89. Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
  90. Args:
  91. s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
  92. q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
  93. Returns:
  94. distances (ndarray): A vector of length M that contains for each entry in `q` the
  95. smallest Euclidean distance to a sample in `s`.
  96. """
  97. s, q = np.asarray(s), np.asarray(q)
  98. if len(s) == 0 or len(q) == 0:
  99. return np.zeros((len(s), len(q)))
  100. s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
  101. distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
  102. distances = np.clip(distances, 0., float(np.inf))
  103. return np.maximum(0.0, distances.min(axis=0))
  104. def _nn_cosine_distance(s, q):
  105. """
  106. Compute pair-wise cosine distance between points in `s` and `q`.
  107. Args:
  108. s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
  109. q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
  110. Returns:
  111. distances (ndarray): A vector of length M that contains for each entry in `q` the
  112. smallest Euclidean distance to a sample in `s`.
  113. """
  114. s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
  115. q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
  116. distances = 1. - np.dot(s, q.T)
  117. return distances.min(axis=0)
  118. class NearestNeighborDistanceMetric(object):
  119. """
  120. A nearest neighbor distance metric that, for each target, returns
  121. the closest distance to any sample that has been observed so far.
  122. Args:
  123. metric (str): Either "euclidean" or "cosine".
  124. matching_threshold (float): The matching threshold. Samples with larger
  125. distance are considered an invalid match.
  126. budget (Optional[int]): If not None, fix samples per class to at most
  127. this number. Removes the oldest samples when the budget is reached.
  128. Attributes:
  129. samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
  130. identities to the list of samples that have been observed so far.
  131. """
  132. def __init__(self, metric, matching_threshold, budget=None):
  133. if metric == "euclidean":
  134. self._metric = _nn_euclidean_distance
  135. elif metric == "cosine":
  136. self._metric = _nn_cosine_distance
  137. else:
  138. raise ValueError(
  139. "Invalid metric; must be either 'euclidean' or 'cosine'")
  140. self.matching_threshold = matching_threshold
  141. self.budget = budget
  142. self.samples = {}
  143. def partial_fit(self, features, targets, active_targets):
  144. """
  145. Update the distance metric with new data.
  146. Args:
  147. features (ndarray): An NxM matrix of N features of dimensionality M.
  148. targets (ndarray): An integer array of associated target identities.
  149. active_targets (List[int]): A list of targets that are currently
  150. present in the scene.
  151. """
  152. for feature, target in zip(features, targets):
  153. self.samples.setdefault(target, []).append(feature)
  154. if self.budget is not None:
  155. self.samples[target] = self.samples[target][-self.budget:]
  156. self.samples = {k: self.samples[k] for k in active_targets}
  157. def distance(self, features, targets):
  158. """
  159. Compute distance between features and targets.
  160. Args:
  161. features (ndarray): An NxM matrix of N features of dimensionality M.
  162. targets (list[int]): A list of targets to match the given `features` against.
  163. Returns:
  164. cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
  165. where element (i, j) contains the closest squared distance between
  166. `targets[i]` and `features[j]`.
  167. """
  168. cost_matrix = np.zeros((len(targets), len(features)))
  169. for i, target in enumerate(targets):
  170. cost_matrix[i, :] = self._metric(self.samples[target], features)
  171. return cost_matrix
  172. def min_cost_matching(distance_metric,
  173. max_distance,
  174. tracks,
  175. detections,
  176. track_indices=None,
  177. detection_indices=None):
  178. """
  179. Solve linear assignment problem.
  180. Args:
  181. distance_metric :
  182. Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
  183. The distance metric is given a list of tracks and detections as
  184. well as a list of N track indices and M detection indices. The
  185. metric should return the NxM dimensional cost matrix, where element
  186. (i, j) is the association cost between the i-th track in the given
  187. track indices and the j-th detection in the given detection_indices.
  188. max_distance (float): Gating threshold. Associations with cost larger
  189. than this value are disregarded.
  190. tracks (list[Track]): A list of predicted tracks at the current time
  191. step.
  192. detections (list[Detection]): A list of detections at the current time
  193. step.
  194. track_indices (list[int]): List of track indices that maps rows in
  195. `cost_matrix` to tracks in `tracks`.
  196. detection_indices (List[int]): List of detection indices that maps
  197. columns in `cost_matrix` to detections in `detections`.
  198. Returns:
  199. A tuple (List[(int, int)], List[int], List[int]) with the following
  200. three entries:
  201. * A list of matched track and detection indices.
  202. * A list of unmatched track indices.
  203. * A list of unmatched detection indices.
  204. """
  205. if track_indices is None:
  206. track_indices = np.arange(len(tracks))
  207. if detection_indices is None:
  208. detection_indices = np.arange(len(detections))
  209. if len(detection_indices) == 0 or len(track_indices) == 0:
  210. return [], track_indices, detection_indices # Nothing to match.
  211. cost_matrix = distance_metric(tracks, detections, track_indices,
  212. detection_indices)
  213. cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
  214. indices = linear_sum_assignment(cost_matrix)
  215. matches, unmatched_tracks, unmatched_detections = [], [], []
  216. for col, detection_idx in enumerate(detection_indices):
  217. if col not in indices[1]:
  218. unmatched_detections.append(detection_idx)
  219. for row, track_idx in enumerate(track_indices):
  220. if row not in indices[0]:
  221. unmatched_tracks.append(track_idx)
  222. for row, col in zip(indices[0], indices[1]):
  223. track_idx = track_indices[row]
  224. detection_idx = detection_indices[col]
  225. if cost_matrix[row, col] > max_distance:
  226. unmatched_tracks.append(track_idx)
  227. unmatched_detections.append(detection_idx)
  228. else:
  229. matches.append((track_idx, detection_idx))
  230. return matches, unmatched_tracks, unmatched_detections
  231. def matching_cascade(distance_metric,
  232. max_distance,
  233. cascade_depth,
  234. tracks,
  235. detections,
  236. track_indices=None,
  237. detection_indices=None):
  238. """
  239. Run matching cascade.
  240. Args:
  241. distance_metric :
  242. Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
  243. The distance metric is given a list of tracks and detections as
  244. well as a list of N track indices and M detection indices. The
  245. metric should return the NxM dimensional cost matrix, where element
  246. (i, j) is the association cost between the i-th track in the given
  247. track indices and the j-th detection in the given detection_indices.
  248. max_distance (float): Gating threshold. Associations with cost larger
  249. than this value are disregarded.
  250. cascade_depth (int): The cascade depth, should be se to the maximum
  251. track age.
  252. tracks (list[Track]): A list of predicted tracks at the current time
  253. step.
  254. detections (list[Detection]): A list of detections at the current time
  255. step.
  256. track_indices (list[int]): List of track indices that maps rows in
  257. `cost_matrix` to tracks in `tracks`.
  258. detection_indices (List[int]): List of detection indices that maps
  259. columns in `cost_matrix` to detections in `detections`.
  260. Returns:
  261. A tuple (List[(int, int)], List[int], List[int]) with the following
  262. three entries:
  263. * A list of matched track and detection indices.
  264. * A list of unmatched track indices.
  265. * A list of unmatched detection indices.
  266. """
  267. if track_indices is None:
  268. track_indices = list(range(len(tracks)))
  269. if detection_indices is None:
  270. detection_indices = list(range(len(detections)))
  271. unmatched_detections = detection_indices
  272. matches = []
  273. for level in range(cascade_depth):
  274. if len(unmatched_detections) == 0: # No detections left
  275. break
  276. track_indices_l = [
  277. k for k in track_indices
  278. if tracks[k].time_since_update == 1 + level
  279. ]
  280. if len(track_indices_l) == 0: # Nothing to match at this level
  281. continue
  282. matches_l, _, unmatched_detections = \
  283. min_cost_matching(
  284. distance_metric, max_distance, tracks, detections,
  285. track_indices_l, unmatched_detections)
  286. matches += matches_l
  287. unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
  288. return matches, unmatched_tracks, unmatched_detections
  289. def gate_cost_matrix(kf,
  290. cost_matrix,
  291. tracks,
  292. detections,
  293. track_indices,
  294. detection_indices,
  295. gated_cost=INFTY_COST,
  296. only_position=False):
  297. """
  298. Invalidate infeasible entries in cost matrix based on the state
  299. distributions obtained by Kalman filtering.
  300. Args:
  301. kf (object): The Kalman filter.
  302. cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
  303. number of track indices and M is the number of detection indices,
  304. such that entry (i, j) is the association cost between
  305. `tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
  306. tracks (list[Track]): A list of predicted tracks at the current time
  307. step.
  308. detections (list[Detection]): A list of detections at the current time
  309. step.
  310. track_indices (List[int]): List of track indices that maps rows in
  311. `cost_matrix` to tracks in `tracks`.
  312. detection_indices (List[int]): List of detection indices that maps
  313. columns in `cost_matrix` to detections in `detections`.
  314. gated_cost (Optional[float]): Entries in the cost matrix corresponding
  315. to infeasible associations are set this value. Defaults to a very
  316. large value.
  317. only_position (Optional[bool]): If True, only the x, y position of the
  318. state distribution is considered during gating. Default False.
  319. """
  320. gating_dim = 2 if only_position else 4
  321. gating_threshold = kalman_filter.chi2inv95[gating_dim]
  322. measurements = np.asarray(
  323. [detections[i].to_xyah() for i in detection_indices])
  324. for row, track_idx in enumerate(track_indices):
  325. track = tracks[track_idx]
  326. gating_distance = kf.gating_distance(track.mean, track.covariance,
  327. measurements, only_position)
  328. cost_matrix[row, gating_distance > gating_threshold] = gated_cost
  329. return cost_matrix