keypoint_operators.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # function:
  15. # operators to process sample,
  16. # eg: decode/resize/crop image
  17. from __future__ import absolute_import
  18. try:
  19. from collections.abc import Sequence
  20. except Exception:
  21. from collections import Sequence
  22. import cv2
  23. import numpy as np
  24. import math
  25. import copy
  26. from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform, get_warp_matrix
  27. from paddlex.ppdet.core.workspace import serializable
  28. from paddlex.ppdet.utils.logger import setup_logger
  logger = setup_logger(__name__)

  # Not populated anywhere in the code shown here; presumably kept for
  # external users of the module — TODO confirm before removing.
  registered_ops = list()

  # Public operators exported by this module (order preserved).
  __all__ = [
      # bottom-up (HigherHRNet-style, multi-resolution heatmap) operators
      'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
      'NormalizePermute', 'EvalAffine',
      # top-down operators
      'RandomFlipHalfBodyTransform', 'TopDownAffine', 'ToHeatmapsTopDown',
      'ToHeatmapsTopDown_DARK', 'ToHeatmapsTopDown_UDP', 'TopDownEvalAffine',
      'AugmentationbyInformantionDropping',
  ]
  def register_keypointop(cls):
      """Register ``cls`` as a keypoint operator.

      Registration is delegated entirely to ``serializable``; whatever it
      returns is what the decorated name is bound to.
      """
      decorated = serializable(cls)
      return decorated
  @register_keypointop
  class KeyPointFlip(object):
      """Randomly mirror the image horizontally together with masks and keypoints.

      When a flip happens, keypoints are reordered with ``flip_permutation``
      because a left keypoint becomes a right keypoint in the mirrored image.

      Args:
          flip_permutation (list[17]): keypoint order after the flip,
              indexed like [0, 1, 2, ..., 16]
          hmsize (list[2]): heatmap size of each output scale of HigherHRNet
          flip_prob (float): probability of flipping the sample
          records(dict): the dict containing the image, mask and coords
      Returns:
          records(dict): the image, mask and coords after the transform
      """

      def __init__(self, flip_permutation, hmsize, flip_prob=0.5):
          super(KeyPointFlip, self).__init__()
          assert isinstance(flip_permutation, Sequence)
          self.flip_permutation = flip_permutation
          self.flip_prob = flip_prob
          self.hmsize = hmsize

      def __call__(self, records):
          image = records['image']
          kpts_lst = records['joints']
          mask_lst = records['mask']
          if np.random.random() < self.flip_prob:
              image = image[:, ::-1]
              for scale_idx, size in enumerate(self.hmsize):
                  # The mask list may be shorter than the number of scales.
                  if scale_idx < len(mask_lst):
                      mask_lst[scale_idx] = mask_lst[scale_idx][:, ::-1]
                  kpts = kpts_lst[scale_idx]
                  # The joint axis is 1 for 3-D keypoint arrays, 0 otherwise.
                  if kpts.ndim == 3:
                      kpts = kpts[:, self.flip_permutation]
                  else:
                      kpts = kpts[self.flip_permutation]
                  kpts[..., 0] = size - kpts[..., 0]
                  kpts = kpts.astype(np.int64)
                  # Column 2 is zeroed for joints that left the heatmap.
                  outside = ((kpts[..., 0] >= size) | (kpts[..., 1] >= size) |
                             (kpts[..., 0] < 0) | (kpts[..., 1] < 0))
                  kpts[outside, 2] = 0
                  kpts_lst[scale_idx] = kpts
          records['image'] = image
          records['joints'] = kpts_lst
          records['mask'] = mask_lst
          return records
  @register_keypointop
  class RandomAffine(object):
      """Apply a random rotation / scale / shift to the image, masks and coords.

      The image is warped to ``trainsize`` x ``trainsize``. The mask and the
      keypoints are warped once per entry of ``hmsize``, with the same
      rotation, scale and shift.

      Args:
          max_degree (float): max absolute rotation; the angle is sampled
              uniformly from [-max_degree, max_degree]
          scale (list[2]): range [min, max] the scale factor is sampled from
          max_shift (float): max absolute shift ratio; the crop center is
              shifted by up to max_shift * roi_size pixels on each axis
          hmsize (list[2]): heatmap size of each output scale of HigherHRNet
          trainsize (int): side length of the warped training image; the
              'scale_type' side of [h, w] is mapped onto it
          scale_type (str): which side of [h, w] sets the crop size,
              'short' or 'long'
          records(dict): the dict containing the image, mask and coords
      Returns:
          records(dict): the image, the list of per-scale masks and the list
              of per-scale coords after the transform. If records['mask'] is
              None, the returned mask list is empty.
      """

      def __init__(self,
                   max_degree=30,
                   scale=[0.75, 1.5],
                   max_shift=0.2,
                   hmsize=[128, 256],
                   trainsize=512,
                   scale_type='short'):
          super(RandomAffine, self).__init__()
          self.max_degree = max_degree
          self.min_scale = scale[0]
          self.max_scale = scale[1]
          self.max_shift = max_shift
          self.hmsize = hmsize
          self.trainsize = trainsize
          self.scale_type = scale_type

      def _get_affine_matrix(self, center, scale, res, rot=0):
          """Generate a 3x3 transformation matrix.

          Maps a square of side ``scale`` centered at ``center`` onto an output
          of size ``res``, so that ``center`` lands on the output center. When
          ``rot`` is non-zero (degrees), the result is additionally rotated
          about the output center.
          """
          h = scale
          t = np.zeros((3, 3), dtype=np.float32)
          t[0, 0] = float(res[1]) / h
          t[1, 1] = float(res[0]) / h
          t[0, 2] = res[1] * (-float(center[0]) / h + .5)
          t[1, 2] = res[0] * (-float(center[1]) / h + .5)
          t[2, 2] = 1
          if rot != 0:
              rot = -rot  # To match direction of rotation from cropping
              rot_mat = np.zeros((3, 3), dtype=np.float32)
              rot_rad = rot * np.pi / 180
              sn, cs = np.sin(rot_rad), np.cos(rot_rad)
              rot_mat[0, :2] = [cs, -sn]
              rot_mat[1, :2] = [sn, cs]
              rot_mat[2, 2] = 1
              # Need to rotate around center
              t_mat = np.eye(3)
              t_mat[0, 2] = -res[1] / 2
              t_mat[1, 2] = -res[0] / 2
              t_inv = t_mat.copy()
              t_inv[:2, 2] *= -1
              t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
          return t

      def __call__(self, records):
          image = records['image']
          keypoints = records['joints']
          heatmap_mask = records['mask']

          degree = (np.random.random() * 2 - 1) * self.max_degree
          shape = np.array(image.shape[:2][::-1])  # (w, h)
          center = np.array(shape / 2)
          aug_scale = np.random.random() * (self.max_scale - self.min_scale
                                            ) + self.min_scale
          if self.scale_type == 'long':
              scale = max(shape[0], shape[1]) / 1.0
          elif self.scale_type == 'short':
              scale = min(shape[0], shape[1]) / 1.0
          else:
              raise ValueError('Unknown scale type: {}'.format(self.scale_type))
          roi_size = aug_scale * scale

          if self.max_shift > 0:
              # numpy truncates float bounds with int(); do it explicitly so a
              # range that truncates to zero skips the shift instead of
              # raising ValueError from randint(0, 0).
              max_offset = int(self.max_shift * roi_size)
              if max_offset > 0:
                  dx = np.random.randint(-max_offset, max_offset)
                  dy = np.random.randint(-max_offset, max_offset)
                  center += np.array([dx, dy])

          # Keypoint x/y are multiplied by (w, h); presumably they arrive
          # normalized to [0, 1] — TODO confirm against the dataset.
          keypoints[..., :2] *= shape
          if heatmap_mask is not None:
              heatmap_mask *= 255

          image_affine_mat = self._get_affine_matrix(
              center, roi_size, (self.trainsize, self.trainsize), degree)[:2]
          image = cv2.warpAffine(
              image,
              image_affine_mat, (self.trainsize, self.trainsize),
              flags=cv2.INTER_LINEAR)

          kpts_lst = []
          mask_lst = []
          for hmsize in self.hmsize:
              kpts = copy.deepcopy(keypoints)
              mask_affine_mat = self._get_affine_matrix(
                  center, roi_size, (hmsize, hmsize), degree)[:2]
              if heatmap_mask is not None:
                  mask = cv2.warpAffine(heatmap_mask, mask_affine_mat,
                                        (hmsize, hmsize))
                  # Re-binarize after interpolation.
                  mask = ((mask / 255) > 0.5).astype(np.float32)
                  mask_lst.append(mask)
              kpts[..., 0:2] = warp_affine_joints(kpts[..., 0:2].copy(),
                                                  mask_affine_mat)
              # Zero column 2 for joints whose integer position left the map.
              xs = np.trunc(kpts[..., 0])
              ys = np.trunc(kpts[..., 1])
              kpts[(xs >= hmsize) | (ys >= hmsize) | (xs < 0) | (ys < 0), 2] = 0
              kpts_lst.append(kpts)

          records['image'] = image
          records['joints'] = kpts_lst
          records['mask'] = mask_lst
          return records
  @register_keypointop
  class EvalAffine(object):
      """Warp an evaluation image (and its mask, if any) to a standard size.

      The warp matrix and output size come from
      ``get_affine_mat_kernel(h, w, size, inv=False)``. Per the original
      description, the short side of [h, w] is resized to ``size``.

      Args:
          size (int): the standard length the image is resized to
          stride (int): stored on the instance; not read by this operator
          records(dict): the dict containing the image, mask and coords
      Returns:
          records(dict): the warped image and mask; 'joints' is removed
      """

      def __init__(self, size, stride=64):
          super(EvalAffine, self).__init__()
          self.size = size
          self.stride = stride

      def __call__(self, records):
          img = records['image']
          mask = records.get('mask')
          h, w, _ = img.shape
          trans, size_resized = get_affine_mat_kernel(
              h, w, self.size, inv=False)
          warped = cv2.warpAffine(img, trans, size_resized)
          if mask is not None:
              records['mask'] = cv2.warpAffine(mask, trans, size_resized)
          records.pop('joints', None)
          records['image'] = warped
          return records
  @register_keypointop
  class NormalizePermute(object):
      """Scale, normalize per channel and transpose an HWC image to CHW.

      Args:
          mean (list[float]): per-channel value subtracted from the image
          std (list[float]): per-channel divisor (applied as 1 / std)
          is_scale (bool): divide the image by 255 before normalizing
          records(dict): the dict containing the image
      Returns:
          records(dict): 'image' replaced by a float32 CHW array

      NOTE(review): the default mean/std are ImageNet statistics on a 0-255
      scale (0.485 * 255 == 123.675), yet ``is_scale`` defaults to True,
      which first maps pixels to 0-1. Callers relying on the defaults
      probably want ``is_scale=False`` or 0-1 statistics — TODO confirm.
      """

      def __init__(self,
                   mean=[123.675, 116.28, 103.53],
                   std=[58.395, 57.120, 57.375],
                   is_scale=True):
          super(NormalizePermute, self).__init__()
          self.mean = mean
          self.std = std
          self.is_scale = is_scale

      def __call__(self, records):
          img = records['image'].astype(np.float32)
          if self.is_scale:
              img /= 255.
          img = img.transpose((2, 0, 1))
          means = np.array(self.mean, dtype=np.float32)
          inv_stds = 1. / np.array(self.std, dtype=np.float32)
          # Each channel is a view into img, so these in-place updates
          # normalize img itself; zip stops at the shortest sequence.
          for channel, ch_mean, ch_inv_std in zip(img, means, inv_stds):
              channel -= ch_mean
              channel *= ch_inv_std
          records['image'] = img
          return records
  @register_keypointop
  class TagGenerate(object):
      """Record GT coords so the AE loss can sample values from tagmaps.

      Args:
          num_joints (int): the number of keypoints in the dataset
          max_people (int): max number of people stored per image; people
              beyond the first ``max_people`` are dropped because the tagmap
              has a fixed size
          records(dict): the dict containing the image, mask and coords
      Returns:
          records(dict): with 'tagmap' of shape (max_people, num_joints, 4)
              added and 'joints' removed
      """

      def __init__(self, num_joints, max_people=30):
          super(TagGenerate, self).__init__()
          self.max_people = max_people
          self.num_joints = num_joints

      def __call__(self, records):
          # Only the first scale's keypoints are used.
          kpts = records['joints'][0]
          tagmap = np.zeros(
              (self.max_people, self.num_joints, 4), dtype=np.int64)
          # Indexing the fixed-size tagmap with a person index >= max_people
          # would raise IndexError, so extra people are dropped.
          kpts = kpts[:self.max_people]
          people, joints = np.where(kpts[..., 2] > 0)
          visible = kpts[people, joints]
          # tagmap is [max_people, num_joints, 4]; last dim is (j, y, x, valid)
          tagmap[people, joints, 0] = joints
          tagmap[people, joints, 1] = visible[..., 1]  # y
          tagmap[people, joints, 2] = visible[..., 0]  # x
          tagmap[people, joints, 3] = 1
          records['tagmap'] = tagmap
          del records['joints']
          return records
  @register_keypointop
  class ToHeatmaps(object):
      """Generate Gaussian keypoint heatmaps for the heatmap loss.

      Args:
          num_joints (int): the number of keypoints in the dataset
          hmsize (list[2]): heatmap size of each output scale of HigherHRNet
          sigma (float): std of the Gaussian kernel; defaults to
              hmsize[0] // 64, but never less than 1
          records(dict): the dict containing the image, mask and coords
      Returns:
          records(dict): with 'heatmap_gt{k}x' and 'mask_{k}x' for every
              scale k = 1, 2, ...; 'mask' is removed
      """

      def __init__(self, num_joints, hmsize, sigma=None):
          super(ToHeatmaps, self).__init__()
          self.num_joints = num_joints
          self.hmsize = np.array(hmsize)
          if sigma is None:
              # hmsize[0] < 64 would otherwise give sigma == 0 and a 0/0
              # kernel below.
              sigma = max(hmsize[0] // 64, 1)
          self.sigma = sigma
          r = 6 * sigma + 3
          x = np.arange(0, r, 1, np.float32)
          y = x[:, None]
          x0, y0 = 3 * sigma + 1, 3 * sigma + 1
          self.gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))

      def __call__(self, records):
          kpts_lst = records['joints']
          mask_lst = records['mask']
          for idx, hmsize in enumerate(self.hmsize):
              mask = mask_lst[idx]
              kpts = kpts_lst[idx]
              heatmaps = np.zeros((self.num_joints, hmsize, hmsize))
              inds = np.where(kpts[..., 2] > 0)
              visible = kpts[inds].astype(np.int64)[..., :2]
              # Cast the kernel bounds to int so they remain valid slice
              # indices when sigma is a float.
              ul = np.round(visible - 3 * self.sigma - 1).astype(np.int64)
              br = np.round(visible + 3 * self.sigma + 2).astype(np.int64)
              sul = np.maximum(0, -ul)
              sbr = np.minimum(hmsize, br) - ul
              dul = np.clip(ul, 0, hmsize - 1)
              dbr = np.clip(br, 0, hmsize)
              for i in range(len(visible)):
                  if visible[i][0] < 0 or visible[i][1] < 0 or visible[i][
                          0] >= hmsize or visible[i][1] >= hmsize:
                      continue
                  dx1, dy1 = dul[i]
                  dx2, dy2 = dbr[i]
                  sx1, sy1 = sul[i]
                  sx2, sy2 = sbr[i]
                  # Overlapping people keep the max response per joint.
                  heatmaps[inds[1][i], dy1:dy2, dx1:dx2] = np.maximum(
                      self.gaussian[sy1:sy2, sx1:sx2],
                      heatmaps[inds[1][i], dy1:dy2, dx1:dx2])
              records['heatmap_gt{}x'.format(idx + 1)] = heatmaps
              records['mask_{}x'.format(idx + 1)] = mask
          del records['mask']
          return records
  @register_keypointop
  class RandomFlipHalfBodyTransform(object):
      """Randomly apply half-body crop, scale, rotation and flip augmentation.

      Args:
          trainsize (list): [w, h], image target size
          upper_body_ids (list): ids of the upper-body joints
          flip_pairs (list): left/right joint pairs exchanged on flip
          pixel_std (int): pixel std the box size is divided by to get scale
          scale (float): std of the random scale factor (clipped to +-scale)
          rot (int): std of the random rotation in degrees (clipped to +-2*rot)
          num_joints_half_body (int): visible-joint count that must be
              exceeded before the half-body transform is tried
          prob_half_body (float): probability of trying the half-body transform
          flip (bool): whether random horizontal flipping is enabled
          rot_prob (float): probability of applying a rotation
      Returns:
          records(dict): the image and coords after the transform, plus 'rotate'
      """

      def __init__(self,
                   trainsize,
                   upper_body_ids,
                   flip_pairs,
                   pixel_std,
                   scale=0.35,
                   rot=40,
                   num_joints_half_body=8,
                   prob_half_body=0.3,
                   flip=True,
                   rot_prob=0.6):
          super(RandomFlipHalfBodyTransform, self).__init__()
          self.trainsize = trainsize
          self.upper_body_ids = upper_body_ids
          self.flip_pairs = flip_pairs
          self.pixel_std = pixel_std
          self.scale = scale
          self.rot = rot
          self.num_joints_half_body = num_joints_half_body
          self.prob_half_body = prob_half_body
          self.flip = flip
          self.aspect_ratio = trainsize[0] * 1.0 / trainsize[1]
          self.rot_prob = rot_prob

      def halfbody_transform(self, joints, joints_vis):
          """Return (center, scale) of a box around the upper or lower body.

          Returns (None, None) when fewer than two usable joints remain.
          """
          upper, lower = [], []
          for jid, joint in enumerate(joints):
              if not joints_vis[jid][0] > 0:
                  continue
              if jid in self.upper_body_ids:
                  upper.append(joint)
              else:
                  lower.append(joint)

          # NOTE: randn (not rand) makes the upper body preferred with
          # P(Z < 0.5) ~= 0.69 rather than 0.5; presumably inherited from the
          # reference implementation, kept to preserve the sampling.
          prefer_upper = np.random.randn() < 0.5
          if prefer_upper and len(upper) > 2:
              chosen = upper
          elif len(lower) > 2:
              chosen = lower
          else:
              chosen = upper
          if len(chosen) < 2:
              return None, None

          pts = np.array(chosen, dtype=np.float32)
          center = pts.mean(axis=0)[:2]
          w, h = np.amax(pts, axis=0)[:2] - np.amin(pts, axis=0)[:2]
          # Grow the short side so the box matches the training aspect ratio.
          if w > self.aspect_ratio * h:
              h = w * 1.0 / self.aspect_ratio
          elif w < self.aspect_ratio * h:
              w = h * self.aspect_ratio
          scale = np.array(
              [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
              dtype=np.float32) * 1.5
          return center, scale

      def flip_joints(self, joints, joints_vis, width, matched_parts):
          """Mirror x coords in place and swap left/right joints.

          Returns the joints multiplied by their visibility, and the swapped
          visibility array.
          """
          joints[:, 0] = width - joints[:, 0] - 1
          for pair in matched_parts:
              a, b = pair[0], pair[1]
              # Fancy indexing on the right-hand side copies, so this swaps.
              joints[[a, b], :] = joints[[b, a], :]
              joints_vis[[a, b], :] = joints_vis[[b, a], :]
          return joints * joints_vis, joints_vis

      def __call__(self, records):
          image = records['image']
          joints = records['joints']
          joints_vis = records['joints_vis']
          center = records['center']
          scale = records['scale']

          enough_visible = np.sum(joints_vis[:, 0]) > self.num_joints_half_body
          if enough_visible and np.random.rand() < self.prob_half_body:
              half_center, half_scale = self.halfbody_transform(joints,
                                                                joints_vis)
              if half_center is not None and half_scale is not None:
                  center, scale = half_center, half_scale

          sf, rf = self.scale, self.rot
          scale = scale * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
          rotate = 0
          if np.random.random() <= self.rot_prob:
              rotate = np.clip(np.random.randn() * rf, -rf * 2, rf * 2)

          if self.flip and np.random.random() <= 0.5:
              image = image[:, ::-1, :]
              joints, joints_vis = self.flip_joints(
                  joints, joints_vis, image.shape[1], self.flip_pairs)
              center[0] = image.shape[1] - center[0] - 1

          records.update({
              'image': image,
              'joints': joints,
              'joints_vis': joints_vis,
              'center': center,
              'scale': scale,
              'rotate': rotate,
          })
          return records
  @register_keypointop
  class AugmentationbyInformantionDropping(object):
      """AID: Augmentation by Information Dropping. Please refer
      to https://arxiv.org/abs/2008.07139

      Zeroes ``num_patch`` random discs centered near randomly chosen visible
      joints.

      Args:
          trainsize (list): training size; trainsize[0] scales the disc-center
              offset and radius (presumably [w, h] as in TopDownAffine)
          prob_cutout (float): The probability of the Cutout augmentation.
          offset_factor (float): std of the disc-center offset, in units of
              trainsize[0]
          num_patch (int): Number of patches to be cutout.
          records(dict): the dict containing the image and coords
      Returns:
          records (dict): the image and coords after the transform; the image
              is left unchanged when no joint is visible
      """

      def __init__(self,
                   trainsize,
                   prob_cutout=0.0,
                   offset_factor=0.2,
                   num_patch=1):
          super(AugmentationbyInformantionDropping, self).__init__()
          self.prob_cutout = prob_cutout
          self.offset_factor = offset_factor
          self.num_patch = num_patch
          self.trainsize = trainsize

      def _cutout(self, img, joints, joints_vis):
          # Row indices of visible entries. For a 2-D joints_vis, each visible
          # joint appears once per column.
          vis_idx, _ = np.where(joints_vis > 0)
          if vis_idx.size == 0:
              # Nothing to anchor a patch on; np.random.choice([]) would raise.
              return img

          height, width, _ = img.shape
          img = img.reshape((height * width, -1))
          feat_x_int, feat_y_int = np.meshgrid(
              np.arange(0, width), np.arange(0, height))
          feat_x_int = feat_x_int.reshape((-1, ))
          feat_y_int = feat_y_int.reshape((-1, ))
          for _ in range(self.num_patch):
              occlusion_joint_id = np.random.choice(vis_idx)
              center = joints[occlusion_joint_id, 0:2]
              offset = np.random.randn(2) * self.trainsize[
                  0] * self.offset_factor
              center = center + offset
              radius = np.random.uniform(0.1, 0.2) * self.trainsize[0]
              x_offset = (center[0] - feat_x_int) / radius
              y_offset = (center[1] - feat_y_int) / radius
              dis = x_offset**2 + y_offset**2
              # Pixels inside the unit disc (in radius units) are zeroed.
              keep_pos = np.where(dis <= 1)[0]
              img[keep_pos, :] = 0
          img = img.reshape((height, width, -1))
          return img

      def __call__(self, records):
          img = records['image']
          joints = records['joints']
          joints_vis = records['joints_vis']
          if np.random.rand() < self.prob_cutout:
              img = self._cutout(img, joints, joints_vis)
          records['image'] = img
          return records
  @register_keypointop
  class TopDownAffine(object):
      """Warp the image and keypoints to the training size.

      The warp is built from records['center'], records['scale'] (multiplied
      by 200) and the optional records['rotate'].

      Args:
          trainsize (list): [w, h], the standard size used to train
          use_udp (bool): whether to use Unbiased Data Processing; if so, all
              joints are warped, otherwise only joints with joints_vis[:, 0] > 0
          records(dict): the dict containing the image and coords
      Returns:
          records (dict): the image and coords after the transform
      """

      def __init__(self, trainsize, use_udp=False):
          self.trainsize = trainsize
          self.use_udp = use_udp

      def __call__(self, records):
          image = records['image']
          joints = records['joints']
          joints_vis = records['joints_vis']
          rot = records.get('rotate', 0)
          out_size = (int(self.trainsize[0]), int(self.trainsize[1]))

          if self.use_udp:
              trans = get_warp_matrix(
                  rot, records['center'] * 2.0,
                  [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0],
                  records['scale'] * 200.0)
          else:
              trans = get_affine_transform(records['center'],
                                           records['scale'] * 200, rot,
                                           self.trainsize)
          image = cv2.warpAffine(image, trans, out_size, flags=cv2.INTER_LINEAR)

          if self.use_udp:
              joints[:, 0:2] = warp_affine_joints(joints[:, 0:2].copy(), trans)
          else:
              for jid in np.flatnonzero(joints_vis[:, 0] > 0.0):
                  joints[jid, 0:2] = affine_transform(joints[jid, 0:2], trans)

          records['image'] = image
          records['joints'] = joints
          return records
  @register_keypointop
  class TopDownEvalAffine(object):
      """Warp the whole evaluation image to the training size.

      The center is half of records['im_shape'] reversed and the scale is the
      reversed im_shape itself (presumably im_shape is [h, w], giving [w, h]).
      No rotation is applied.

      Args:
          trainsize (list): [w, h], the standard size used to train
          use_udp (bool): whether to use Unbiased Data Processing.
          records(dict): the dict containing the image and coords
      Returns:
          records (dict): the image after the transform
      """

      def __init__(self, trainsize, use_udp=False):
          self.trainsize = trainsize
          self.use_udp = use_udp

      def __call__(self, records):
          image = records['image']
          wh = records['im_shape'][::-1]
          center, scale = wh / 2., wh
          out_size = (int(self.trainsize[0]), int(self.trainsize[1]))
          if self.use_udp:
              trans = get_warp_matrix(
                  0, center * 2.0,
                  [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
          else:
              trans = get_affine_transform(center, scale, 0, self.trainsize)
          records['image'] = cv2.warpAffine(
              image, trans, out_size, flags=cv2.INTER_LINEAR)
          return records
  @register_keypointop
  class ToHeatmapsTopDown(object):
      """Generate Gaussian keypoint heatmaps for the heatmap loss.

      Args:
          hmsize (list): [w, h] output heatmap's size
          sigma (float): std of the Gaussian kernel
          records(dict): the dict containing the image and coords
      Returns:
          records (dict): with 'target' of shape (num_joints, h, w) and
              'target_weight' of shape (num_joints, 1); 'joints' and
              'joints_vis' are removed
      """

      def __init__(self, hmsize, sigma):
          super(ToHeatmapsTopDown, self).__init__()
          self.hmsize = np.array(hmsize)
          self.sigma = sigma

      def __call__(self, records):
          """refer to
          https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
          Copyright (c) Microsoft, under the MIT License.
          """
          joints = records['joints']
          joints_vis = records['joints_vis']
          num_joints = joints.shape[0]
          image_size = np.array(
              [records['image'].shape[1], records['image'].shape[0]])
          target_weight = np.ones((num_joints, 1), dtype=np.float32)
          target_weight[:, 0] = joints_vis[:, 0]
          target = np.zeros(
              (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
          tmp_size = self.sigma * 3
          feat_stride = image_size / self.hmsize

          # The Gaussian does not depend on the joint, so build it once.
          # It is not normalized; its center value equals 1.
          size = 2 * tmp_size + 1
          x = np.arange(0, size, 1, np.float32)
          y = x[:, np.newaxis]
          x0 = y0 = size // 2
          g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * self.sigma**2))

          for joint_id in range(num_joints):
              # Round in heatmap space, as the referenced HRNet code does.
              # Rounding in image space and then dividing leaves a fractional
              # center that int() below truncates, biasing the target by up
              # to one heatmap pixel.
              mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
              mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
              # Check that any part of the gaussian is in-bounds
              ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
              br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
              if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
                      0] < 0 or br[1] < 0:
                  # If not, just return the image as is
                  target_weight[joint_id] = 0
                  continue
              # Usable gaussian range
              g_x = max(0, -ul[0]), min(br[0], self.hmsize[0]) - ul[0]
              g_y = max(0, -ul[1]), min(br[1], self.hmsize[1]) - ul[1]
              # Image range
              img_x = max(0, ul[0]), min(br[0], self.hmsize[0])
              img_y = max(0, ul[1]), min(br[1], self.hmsize[1])
              v = target_weight[joint_id]
              if v > 0.5:
                  target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[
                      g_y[0]:g_y[1], g_x[0]:g_x[1]]
          records['target'] = target
          records['target_weight'] = target_weight
          del records['joints'], records['joints_vis']
          return records
  @register_keypointop
  class ToHeatmapsTopDown_DARK(object):
      """Generate sub-pixel-accurate Gaussian heatmaps for the heatmap loss.

      Each Gaussian is evaluated over the whole heatmap at the unrounded
      joint position (in heatmap coordinates).

      Args:
          hmsize (list): [w, h] output heatmap's size
          sigma (float): std of the Gaussian kernel
          records(dict): the dict containing the image and coords
      Returns:
          records (dict): with 'target' and 'target_weight'; 'joints' and
              'joints_vis' are removed
      """

      def __init__(self, hmsize, sigma):
          super(ToHeatmapsTopDown_DARK, self).__init__()
          self.hmsize = np.array(hmsize)
          self.sigma = sigma

      def __call__(self, records):
          joints = records['joints']
          joints_vis = records['joints_vis']
          num_joints = joints.shape[0]
          img_h, img_w = records['image'].shape[0], records['image'].shape[1]
          image_size = np.array([img_w, img_h])
          target_weight = np.ones((num_joints, 1), dtype=np.float32)
          target_weight[:, 0] = joints_vis[:, 0]
          hm_w, hm_h = self.hmsize[0], self.hmsize[1]
          target = np.zeros((num_joints, hm_h, hm_w), dtype=np.float32)
          radius = self.sigma * 3
          feat_stride = image_size / self.hmsize

          # Heatmap pixel grids are the same for every joint.
          xs = np.arange(0, hm_w, 1, np.float32)
          ys = np.arange(0, hm_h, 1, np.float32)[:, np.newaxis]

          for joint_id, joint in enumerate(joints):
              mu_x = joint[0] / feat_stride[0]
              mu_y = joint[1] / feat_stride[1]
              left, top = int(mu_x - radius), int(mu_y - radius)
              right, bottom = int(mu_x + radius + 1), int(mu_y + radius + 1)
              # A kernel lying entirely outside the heatmap zeroes the weight.
              if left >= hm_w or top >= hm_h or right < 0 or bottom < 0:
                  target_weight[joint_id] = 0
                  continue
              if target_weight[joint_id] > 0.5:
                  target[joint_id] = np.exp(-(
                      (xs - mu_x)**2 + (ys - mu_y)**2) / (2 * self.sigma**2))

          records['target'] = target
          records['target_weight'] = target_weight
          del records['joints'], records['joints_vis']
          return records
  @register_keypointop
  class ToHeatmapsTopDown_UDP(object):
      """This code is based on:
      https://github.com/HuangJunJie2017/UDP-Pose/blob/master/deep-high-resolution-net.pytorch/lib/dataset/JointsDataset.py
      to generate the gaussian heatmaps of keypoint for heatmap loss.
      ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing
      for Human Pose Estimation (CVPR 2020).

      Args:
          hmsize (list): [w, h] output heatmap's size
          sigma (float): std of the Gaussian kernel
          records(dict): the dict containing the image and coords
      Returns:
          records (dict): with 'target' and 'target_weight'; 'joints' and
              'joints_vis' are removed
      """

      def __init__(self, hmsize, sigma):
          super(ToHeatmapsTopDown_UDP, self).__init__()
          self.hmsize = np.array(hmsize)
          self.sigma = sigma

      def __call__(self, records):
          joints = records['joints']
          joints_vis = records['joints_vis']
          num_joints = joints.shape[0]
          image_size = np.array(
              [records['image'].shape[1], records['image'].shape[0]])
          target_weight = np.ones((num_joints, 1), dtype=np.float32)
          target_weight[:, 0] = joints_vis[:, 0]
          hm_w, hm_h = self.hmsize[0], self.hmsize[1]
          target = np.zeros((num_joints, hm_h, hm_w), dtype=np.float32)

          radius = self.sigma * 3
          kernel = 2 * radius + 1
          grid_x = np.arange(0, kernel, 1, np.float32)
          grid_y = grid_x[:, None]
          # The stride is measured between first and last pixel centers
          # (size - 1), as UDP prescribes.
          feat_stride = (image_size - 1.0) / (self.hmsize - 1.0)

          for joint_id in range(num_joints):
              exact_x = joints[joint_id][0] / feat_stride[0]
              exact_y = joints[joint_id][1] / feat_stride[1]
              mu_x = int(exact_x + 0.5)
              mu_y = int(exact_y + 0.5)
              left, top = int(mu_x - radius), int(mu_y - radius)
              right, bottom = int(mu_x + radius + 1), int(mu_y + radius + 1)
              if left >= hm_w or top >= hm_h or right < 0 or bottom < 0:
                  # The kernel lies entirely outside the heatmap.
                  target_weight[joint_id] = 0
                  continue
              # Shift the kernel center by the sub-pixel remainder.
              cx = kernel // 2 + (exact_x - mu_x)
              cy = kernel // 2 + (exact_y - mu_y)
              g = np.exp(-((grid_x - cx)**2 + (grid_y - cy)**2) /
                         (2 * self.sigma**2))
              gx0, gx1 = max(0, -left), min(right, hm_w) - left
              gy0, gy1 = max(0, -top), min(bottom, hm_h) - top
              ix0, ix1 = max(0, left), min(right, hm_w)
              iy0, iy1 = max(0, top), min(bottom, hm_h)
              if target_weight[joint_id] > 0.5:
                  target[joint_id][iy0:iy1, ix0:ix1] = g[gy0:gy1, gx0:gx1]

          records['target'] = target
          records['target_weight'] = target_weight
          del records['joints'], records['joints_vis']
          return records