keypoint_metrics.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
from collections import defaultdict, OrderedDict

import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from scipy.io import loadmat, savemat

from ..modeling.keypoint_utils import oks_nms
from paddlex.ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)

__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']


class KeyPointTopDownCOCOEval(object):
    """Adapted from
    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
    Copyright (c) Microsoft, under the MIT License.
    """

    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 iou_type='keypoints',
                 in_vis_thre=0.2,
                 oks_thre=0.9,
                 save_prediction_only=False):
        super(KeyPointTopDownCOCOEval, self).__init__()
        self.coco = COCO(anno_file)
        self.num_samples = num_samples
        self.num_joints = num_joints
        self.iou_type = iou_type
        self.in_vis_thre = in_vis_thre
        self.oks_thre = oks_thre
        self.output_eval = output_eval
        self.res_file = os.path.join(output_eval, "keypoints_results.json")
        self.save_prediction_only = save_prediction_only
        self.reset()

    def reset(self):
        self.results = {
            'all_preds': np.zeros(
                (self.num_samples, self.num_joints, 3), dtype=np.float32),
            'all_boxes': np.zeros((self.num_samples, 6)),
            'image_path': []
        }
        self.eval_results = {}
        self.idx = 0

    def update(self, inputs, outputs):
        kpts, _ = outputs['keypoint'][0]

        num_images = inputs['image'].shape[0]
        self.results['all_preds'][self.idx:self.idx + num_images, :, 0:3] = \
            kpts[:, :, 0:3]
        self.results['all_boxes'][self.idx:self.idx + num_images, 0:2] = \
            inputs['center'].numpy()[:, 0:2]
        self.results['all_boxes'][self.idx:self.idx + num_images, 2:4] = \
            inputs['scale'].numpy()[:, 0:2]
        self.results['all_boxes'][self.idx:self.idx + num_images, 4] = \
            np.prod(inputs['scale'].numpy() * 200, 1)
        self.results['all_boxes'][self.idx:self.idx + num_images, 5] = \
            np.squeeze(inputs['score'].numpy())
        self.results['image_path'].extend(inputs['im_id'].numpy())

        self.idx += num_images

    def _write_coco_keypoint_results(self, keypoints):
        data_pack = [{
            'cat_id': 1,
            'cls': 'person',
            'ann_type': 'keypoints',
            'keypoints': keypoints
        }]
        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
        if not os.path.exists(self.output_eval):
            os.makedirs(self.output_eval)
        with open(self.res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
            logger.info(f'The keypoint result is saved to {self.res_file}.')
        try:
            json.load(open(self.res_file))
        except Exception:
            content = []
            with open(self.res_file, 'r') as f:
                for line in f:
                    content.append(line)
            content[-1] = ']'
            with open(self.res_file, 'w') as f:
                for c in content:
                    f.write(c)

    def _coco_keypoint_results_one_category_kernel(self, data_pack):
        cat_id = data_pack['cat_id']
        keypoints = data_pack['keypoints']
        cat_results = []

        for img_kpts in keypoints:
            if len(img_kpts) == 0:
                continue

            _key_points = np.array(
                [img_kpts[k]['keypoints'] for k in range(len(img_kpts))])
            _key_points = _key_points.reshape(_key_points.shape[0], -1)

            result = [{
                'image_id': img_kpts[k]['image'],
                'category_id': cat_id,
                'keypoints': _key_points[k].tolist(),
                'score': img_kpts[k]['score'],
                'center': list(img_kpts[k]['center']),
                'scale': list(img_kpts[k]['scale'])
            } for k in range(len(img_kpts))]
            cat_results.extend(result)

        return cat_results

    def get_final_results(self, preds, all_boxes, img_path):
        _kpts = []
        for idx, kpt in enumerate(preds):
            _kpts.append({
                'keypoints': kpt,
                'center': all_boxes[idx][0:2],
                'scale': all_boxes[idx][2:4],
                'area': all_boxes[idx][4],
                'score': all_boxes[idx][5],
                'image': int(img_path[idx])
            })
        # image x person x (keypoints)
        kpts = defaultdict(list)
        for kpt in _kpts:
            kpts[kpt['image']].append(kpt)

        # rescoring and oks nms
        num_joints = preds.shape[1]
        in_vis_thre = self.in_vis_thre
        oks_thre = self.oks_thre
        oks_nmsed_kpts = []
        for img in kpts.keys():
            img_kpts = kpts[img]
            for n_p in img_kpts:
                box_score = n_p['score']
                kpt_score = 0
                valid_num = 0
                for n_jt in range(0, num_joints):
                    t_s = n_p['keypoints'][n_jt][2]
                    if t_s > in_vis_thre:
                        kpt_score = kpt_score + t_s
                        valid_num = valid_num + 1
                if valid_num != 0:
                    kpt_score = kpt_score / valid_num
                # rescoring
                n_p['score'] = kpt_score * box_score

            keep = oks_nms([img_kpts[i] for i in range(len(img_kpts))],
                           oks_thre)

            if len(keep) == 0:
                oks_nmsed_kpts.append(img_kpts)
            else:
                oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])

        self._write_coco_keypoint_results(oks_nmsed_kpts)

    def accumulate(self):
        self.get_final_results(self.results['all_preds'],
                               self.results['all_boxes'],
                               self.results['image_path'])
        if self.save_prediction_only:
            logger.info(f'The keypoint result is saved to {self.res_file} '
                        'and the mAP is not evaluated.')
            return
        coco_dt = self.coco.loadRes(self.res_file)
        coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
        coco_eval.params.useSegm = None
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        keypoint_stats = []
        for ind in range(len(coco_eval.stats)):
            keypoint_stats.append(coco_eval.stats[ind])
        self.eval_results['keypoint'] = keypoint_stats

    def log(self):
        if self.save_prediction_only:
            return
        stats_names = [
            'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
            'AR .75', 'AR (M)', 'AR (L)'
        ]
        num_values = len(stats_names)
        print(' '.join(['| {}'.format(name) for name in stats_names]) + ' |')
        print('|---' * (num_values + 1) + '|')
        print(' '.join([
            '| {:.3f}'.format(value) for value in self.eval_results['keypoint']
        ]) + ' |')

    def get_results(self):
        return self.eval_results
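

# A minimal usage sketch for KeyPointTopDownCOCOEval, assuming a paddle-style
# dataloader whose batches carry the fields read by update() above ('image',
# 'center', 'scale', 'score', 'im_id') and a model whose output dict carries a
# 'keypoint' entry. The loader, model, and dataset-size call below are
# hypothetical placeholders, not part of the upstream API.
def _demo_coco_keypoint_eval(anno_file, output_dir, loader, model):
    evaluator = KeyPointTopDownCOCOEval(
        anno_file=anno_file,              # COCO-format keypoint annotations
        num_samples=len(loader.dataset),  # hypothetical: total instance count
        num_joints=17,                    # COCO person keypoints
        output_eval=output_dir)
    for inputs in loader:
        outputs = model(inputs)           # must contain outputs['keypoint']
        evaluator.update(inputs, outputs)
    evaluator.accumulate()                # writes the JSON, then runs COCOeval
    evaluator.log()                       # prints the AP/AR table
    return evaluator.get_results()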


class KeyPointTopDownMPIIEval(object):
    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 oks_thre=0.9,
                 save_prediction_only=False):
        super(KeyPointTopDownMPIIEval, self).__init__()
        self.ann_file = anno_file
        self.res_file = os.path.join(output_eval, "keypoints_results.json")
        self.save_prediction_only = save_prediction_only
        self.reset()

    def reset(self):
        self.results = []
        self.eval_results = {}
        self.idx = 0

    def update(self, inputs, outputs):
        kpts, _ = outputs['keypoint'][0]

        num_images = inputs['image'].shape[0]
        results = {}
        results['preds'] = kpts[:, :, 0:3]
        results['boxes'] = np.zeros((num_images, 6))
        results['boxes'][:, 0:2] = inputs['center'].numpy()[:, 0:2]
        results['boxes'][:, 2:4] = inputs['scale'].numpy()[:, 0:2]
        results['boxes'][:, 4] = np.prod(inputs['scale'].numpy() * 200, 1)
        results['boxes'][:, 5] = np.squeeze(inputs['score'].numpy())
        results['image_path'] = inputs['image_file']

        self.results.append(results)

    def accumulate(self):
        self._mpii_keypoint_results_save()
        if self.save_prediction_only:
            logger.info(f'The keypoint result is saved to {self.res_file} '
                        'and the PCKh is not evaluated.')
            return

        self.eval_results = self.evaluate(self.results)

    def _mpii_keypoint_results_save(self):
        results = []
        for res in self.results:
            if len(res) == 0:
                continue
            result = [{
                'preds': res['preds'][k].tolist(),
                'boxes': res['boxes'][k].tolist(),
                'image_path': res['image_path'][k],
            } for k in range(len(res))]
            results.extend(result)
        with open(self.res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
            logger.info(f'The keypoint result is saved to {self.res_file}.')

    def log(self):
        if self.save_prediction_only:
            return
        for item, value in self.eval_results.items():
            print("{} : {}".format(item, value))

    def get_results(self):
        return self.eval_results

    def evaluate(self, outputs, savepath=None):
        """Evaluate PCKh on the MPII dataset. Adapted from
        https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
        Copyright (c) Microsoft, under the MIT License.

        Args:
            outputs (list): collected results, where each element holds
                * preds (np.ndarray[N, K, 3]): x, y coordinates in the
                  first two channels and the keypoint score in the third.
                * boxes (np.ndarray[N, 6]): [center[0], center[1],
                  scale[0], scale[1], area, score]

        Returns:
            dict: PCKh for each joint.
        """
        kpts = []
        for output in outputs:
            preds = output['preds']
            batch_size = preds.shape[0]
            for i in range(batch_size):
                kpts.append({'keypoints': preds[i]})

        preds = np.stack([kpt['keypoints'] for kpt in kpts])

        # convert 0-based index to 1-based index,
        # and get the first two dimensions.
        preds = preds[..., :2] + 1.0

        if savepath is not None:
            pred_file = os.path.join(savepath, 'pred.mat')
            savemat(pred_file, mdict={'preds': preds})

        SC_BIAS = 0.6
        threshold = 0.5

        gt_file = os.path.join(
            os.path.dirname(self.ann_file), 'mpii_gt_val.mat')
        gt_dict = loadmat(gt_file)
        dataset_joints = gt_dict['dataset_joints']
        jnt_missing = gt_dict['jnt_missing']
        pos_gt_src = gt_dict['pos_gt_src']
        headboxes_src = gt_dict['headboxes_src']

        pos_pred_src = np.transpose(preds, [1, 2, 0])

        head = np.where(dataset_joints == 'head')[1][0]
        lsho = np.where(dataset_joints == 'lsho')[1][0]
        lelb = np.where(dataset_joints == 'lelb')[1][0]
        lwri = np.where(dataset_joints == 'lwri')[1][0]
        lhip = np.where(dataset_joints == 'lhip')[1][0]
        lkne = np.where(dataset_joints == 'lkne')[1][0]
        lank = np.where(dataset_joints == 'lank')[1][0]

        rsho = np.where(dataset_joints == 'rsho')[1][0]
        relb = np.where(dataset_joints == 'relb')[1][0]
        rwri = np.where(dataset_joints == 'rwri')[1][0]
        rkne = np.where(dataset_joints == 'rkne')[1][0]
        rank = np.where(dataset_joints == 'rank')[1][0]
        rhip = np.where(dataset_joints == 'rhip')[1][0]

        jnt_visible = 1 - jnt_missing
        uv_error = pos_pred_src - pos_gt_src
        uv_err = np.linalg.norm(uv_error, axis=1)
        headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
        headsizes = np.linalg.norm(headsizes, axis=0)
        headsizes *= SC_BIAS
        scale = headsizes * np.ones((len(uv_err), 1), dtype=np.float32)
        scaled_uv_err = uv_err / scale
        scaled_uv_err = scaled_uv_err * jnt_visible
        jnt_count = np.sum(jnt_visible, axis=1)
        less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
        PCKh = 100. * np.sum(less_than_threshold, axis=1) / jnt_count

        # save
        rng = np.arange(0, 0.5 + 0.01, 0.01)
        pckAll = np.zeros((len(rng), 16), dtype=np.float32)

        for r, threshold in enumerate(rng):
            less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
            pckAll[r, :] = 100. * np.sum(less_than_threshold,
                                         axis=1) / jnt_count

        PCKh = np.ma.array(PCKh, mask=False)
        PCKh.mask[6:8] = True

        jnt_count = np.ma.array(jnt_count, mask=False)
        jnt_count.mask[6:8] = True
        jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)

        name_value = [  # noqa
            ('Head', PCKh[head]),
            ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
            ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
            ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
            ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
            ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
            ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
            ('PCKh', np.sum(PCKh * jnt_ratio)),
            ('PCKh@0.1', np.sum(pckAll[11, :] * jnt_ratio))
        ]
        name_value = OrderedDict(name_value)

        return name_value

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """Sort kpts by `key` and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts
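

# A minimal usage sketch for KeyPointTopDownMPIIEval under the same hypothetical
# loader/model assumptions as the COCO sketch above, except that batches carry
# 'image_file' instead of 'im_id', as read by update(). accumulate() first saves
# the predictions to JSON and then calls evaluate(), which expects the ground
# truth file 'mpii_gt_val.mat' to sit next to anno_file.
def _demo_mpii_keypoint_eval(anno_file, output_dir, loader, model):
    evaluator = KeyPointTopDownMPIIEval(
        anno_file=anno_file,
        num_samples=None,   # accepted but unused by this evaluator
        num_joints=16,      # MPII joint count (also unused here)
        output_eval=output_dir)
    for inputs in loader:
        outputs = model(inputs)            # must contain outputs['keypoint']
        evaluator.update(inputs, outputs)
    evaluator.accumulate()                 # saves JSON, then computes PCKh
    evaluator.log()                        # prints one "name : value" line per metric
    return evaluator.get_results()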