transforms.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import sys
  13. import cv2
  14. import copy
  15. import math
  16. import pyclipper
  17. import numpy as np
  18. from PIL import Image
  19. from shapely.geometry import Polygon
  20. from ....utils import logging
  21. from ...base.predictor.io.writers import ImageWriter
  22. from ...base.predictor.io.readers import ImageReader
  23. from ...base.predictor import BaseTransform
  24. from .keys import TextDetKeys as K
  25. __all__ = [
  26. 'DetResizeForTest', 'NormalizeImage', 'DBPostProcess', 'SaveTextDetResults',
  27. 'PrintResult'
  28. ]
  29. class DetResizeForTest(BaseTransform):
  30. """ DetResizeForTest """
  31. def __init__(self, **kwargs):
  32. super(DetResizeForTest, self).__init__()
  33. self.resize_type = 0
  34. self.keep_ratio = False
  35. if 'image_shape' in kwargs:
  36. self.image_shape = kwargs['image_shape']
  37. self.resize_type = 1
  38. if 'keep_ratio' in kwargs:
  39. self.keep_ratio = kwargs['keep_ratio']
  40. elif 'limit_side_len' in kwargs:
  41. self.limit_side_len = kwargs['limit_side_len']
  42. self.limit_type = kwargs.get('limit_type', 'min')
  43. elif 'resize_long' in kwargs:
  44. self.resize_type = 2
  45. self.resize_long = kwargs.get('resize_long', 960)
  46. else:
  47. self.limit_side_len = 736
  48. self.limit_type = 'min'
  49. def apply(self, data):
  50. """ apply """
  51. img = data[K.IMAGE]
  52. src_h, src_w, _ = img.shape
  53. if sum([src_h, src_w]) < 64:
  54. img = self.image_padding(img)
  55. if self.resize_type == 0:
  56. # img, shape = self.resize_image_type0(img)
  57. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  58. elif self.resize_type == 2:
  59. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  60. else:
  61. # img, shape = self.resize_image_type1(img)
  62. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  63. data[K.IMAGE] = img
  64. data[K.SHAPE] = np.array([src_h, src_w, ratio_h, ratio_w])
  65. return data
  66. @classmethod
  67. def get_input_keys(cls):
  68. """ get input keys """
  69. return [K.IMAGE]
  70. @classmethod
  71. def get_output_keys(cls):
  72. """ get output keys """
  73. return [K.IMAGE, K.SHAPE]
  74. def image_padding(self, im, value=0):
  75. """ padding image """
  76. h, w, c = im.shape
  77. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  78. im_pad[:h, :w, :] = im
  79. return im_pad
  80. def resize_image_type1(self, img):
  81. """ resize the image """
  82. resize_h, resize_w = self.image_shape
  83. ori_h, ori_w = img.shape[:2] # (h, w, c)
  84. if self.keep_ratio is True:
  85. resize_w = ori_w * resize_h / ori_h
  86. N = math.ceil(resize_w / 32)
  87. resize_w = N * 32
  88. ratio_h = float(resize_h) / ori_h
  89. ratio_w = float(resize_w) / ori_w
  90. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  91. # return img, np.array([ori_h, ori_w])
  92. return img, [ratio_h, ratio_w]
  93. def resize_image_type0(self, img):
  94. """
  95. resize image to a size multiple of 32 which is required by the network
  96. args:
  97. img(array): array with shape [h, w, c]
  98. return(tuple):
  99. img, (ratio_h, ratio_w)
  100. """
  101. limit_side_len = self.limit_side_len
  102. h, w, c = img.shape
  103. # limit the max side
  104. if self.limit_type == 'max':
  105. if max(h, w) > limit_side_len:
  106. if h > w:
  107. ratio = float(limit_side_len) / h
  108. else:
  109. ratio = float(limit_side_len) / w
  110. else:
  111. ratio = 1.
  112. elif self.limit_type == 'min':
  113. if min(h, w) < limit_side_len:
  114. if h < w:
  115. ratio = float(limit_side_len) / h
  116. else:
  117. ratio = float(limit_side_len) / w
  118. else:
  119. ratio = 1.
  120. elif self.limit_type == 'resize_long':
  121. ratio = float(limit_side_len) / max(h, w)
  122. else:
  123. raise Exception('not support limit type, image ')
  124. resize_h = int(h * ratio)
  125. resize_w = int(w * ratio)
  126. resize_h = max(int(round(resize_h / 32) * 32), 32)
  127. resize_w = max(int(round(resize_w / 32) * 32), 32)
  128. try:
  129. if int(resize_w) <= 0 or int(resize_h) <= 0:
  130. return None, (None, None)
  131. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  132. except:
  133. logging.info(img.shape, resize_w, resize_h)
  134. sys.exit(0)
  135. ratio_h = resize_h / float(h)
  136. ratio_w = resize_w / float(w)
  137. return img, [ratio_h, ratio_w]
  138. def resize_image_type2(self, img):
  139. """ resize image size """
  140. h, w, _ = img.shape
  141. resize_w = w
  142. resize_h = h
  143. if resize_h > resize_w:
  144. ratio = float(self.resize_long) / resize_h
  145. else:
  146. ratio = float(self.resize_long) / resize_w
  147. resize_h = int(resize_h * ratio)
  148. resize_w = int(resize_w * ratio)
  149. max_stride = 128
  150. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  151. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  152. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  153. ratio_h = resize_h / float(h)
  154. ratio_w = resize_w / float(w)
  155. return img, [ratio_h, ratio_w]
  156. class NormalizeImage(BaseTransform):
  157. """ normalize image such as substract mean, divide std
  158. """
  159. def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
  160. if isinstance(scale, str):
  161. scale = eval(scale)
  162. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  163. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  164. std = std if std is not None else [0.229, 0.224, 0.225]
  165. shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
  166. self.mean = np.array(mean).reshape(shape).astype('float32')
  167. self.std = np.array(std).reshape(shape).astype('float32')
  168. def apply(self, data):
  169. """ apply """
  170. img = data[K.IMAGE]
  171. from PIL import Image
  172. if isinstance(img, Image.Image):
  173. img = np.array(img)
  174. assert isinstance(img,
  175. np.ndarray), "invalid input 'img' in NormalizeImage"
  176. data[K.IMAGE] = (
  177. img.astype('float32') * self.scale - self.mean) / self.std
  178. return data
  179. @classmethod
  180. def get_input_keys(cls):
  181. """ get input keys """
  182. return [K.IMAGE]
  183. @classmethod
  184. def get_output_keys(cls):
  185. """ get output keys """
  186. return [K.IMAGE]
  187. class DBPostProcess(BaseTransform):
  188. """
  189. The post process for Differentiable Binarization (DB).
  190. """
  191. def __init__(self,
  192. thresh=0.3,
  193. box_thresh=0.7,
  194. max_candidates=1000,
  195. unclip_ratio=2.0,
  196. use_dilation=False,
  197. score_mode="fast",
  198. box_type='quad',
  199. **kwargs):
  200. self.thresh = thresh
  201. self.box_thresh = box_thresh
  202. self.max_candidates = max_candidates
  203. self.unclip_ratio = unclip_ratio
  204. self.min_size = 3
  205. self.score_mode = score_mode
  206. self.box_type = box_type
  207. assert score_mode in [
  208. "slow", "fast"
  209. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  210. self.dilation_kernel = None if not use_dilation else np.array([[1, 1],
  211. [1, 1]])
  212. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  213. """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} """
  214. bitmap = _bitmap
  215. height, width = bitmap.shape
  216. boxes = []
  217. scores = []
  218. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
  219. cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  220. for contour in contours[:self.max_candidates]:
  221. epsilon = 0.002 * cv2.arcLength(contour, True)
  222. approx = cv2.approxPolyDP(contour, epsilon, True)
  223. points = approx.reshape((-1, 2))
  224. if points.shape[0] < 4:
  225. continue
  226. score = self.box_score_fast(pred, points.reshape(-1, 2))
  227. if self.box_thresh > score:
  228. continue
  229. if points.shape[0] > 2:
  230. box = self.unclip(points, self.unclip_ratio)
  231. if len(box) > 1:
  232. continue
  233. else:
  234. continue
  235. box = box.reshape(-1, 2)
  236. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  237. if sside < self.min_size + 2:
  238. continue
  239. box = np.array(box)
  240. box[:, 0] = np.clip(
  241. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  242. box[:, 1] = np.clip(
  243. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  244. boxes.append(box.tolist())
  245. scores.append(score)
  246. return boxes, scores
  247. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  248. """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} """
  249. bitmap = _bitmap
  250. height, width = bitmap.shape
  251. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  252. cv2.CHAIN_APPROX_SIMPLE)
  253. if len(outs) == 3:
  254. img, contours, _ = outs[0], outs[1], outs[2]
  255. elif len(outs) == 2:
  256. contours, _ = outs[0], outs[1]
  257. num_contours = min(len(contours), self.max_candidates)
  258. boxes = []
  259. scores = []
  260. for index in range(num_contours):
  261. contour = contours[index]
  262. points, sside = self.get_mini_boxes(contour)
  263. if sside < self.min_size:
  264. continue
  265. points = np.array(points)
  266. if self.score_mode == "fast":
  267. score = self.box_score_fast(pred, points.reshape(-1, 2))
  268. else:
  269. score = self.box_score_slow(pred, contour)
  270. if self.box_thresh > score:
  271. continue
  272. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  273. box, sside = self.get_mini_boxes(box)
  274. if sside < self.min_size + 2:
  275. continue
  276. box = np.array(box)
  277. box[:, 0] = np.clip(
  278. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  279. box[:, 1] = np.clip(
  280. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  281. boxes.append(box.astype(np.int16))
  282. scores.append(score)
  283. return np.array(boxes, dtype=np.int16), scores
  284. def unclip(self, box, unclip_ratio):
  285. """ unclip """
  286. poly = Polygon(box)
  287. distance = poly.area * unclip_ratio / poly.length
  288. offset = pyclipper.PyclipperOffset()
  289. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  290. expanded = np.array(offset.Execute(distance))
  291. return expanded
  292. def get_mini_boxes(self, contour):
  293. """ get mini boxes """
  294. bounding_box = cv2.minAreaRect(contour)
  295. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  296. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  297. if points[1][1] > points[0][1]:
  298. index_1 = 0
  299. index_4 = 1
  300. else:
  301. index_1 = 1
  302. index_4 = 0
  303. if points[3][1] > points[2][1]:
  304. index_2 = 2
  305. index_3 = 3
  306. else:
  307. index_2 = 3
  308. index_3 = 2
  309. box = [
  310. points[index_1], points[index_2], points[index_3], points[index_4]
  311. ]
  312. return box, min(bounding_box[1])
  313. def box_score_fast(self, bitmap, _box):
  314. """ box_score_fast: use bbox mean score as the mean score """
  315. h, w = bitmap.shape[:2]
  316. box = _box.copy()
  317. xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
  318. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
  319. ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
  320. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
  321. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  322. box[:, 0] = box[:, 0] - xmin
  323. box[:, 1] = box[:, 1] - ymin
  324. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  325. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  326. def box_score_slow(self, bitmap, contour):
  327. """ box_score_slow: use polyon mean score as the mean score """
  328. h, w = bitmap.shape[:2]
  329. contour = contour.copy()
  330. contour = np.reshape(contour, (-1, 2))
  331. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  332. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  333. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  334. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  335. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  336. contour[:, 0] = contour[:, 0] - xmin
  337. contour[:, 1] = contour[:, 1] - ymin
  338. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  339. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  340. def apply(self, data):
  341. """ apply """
  342. pred = data[K.PROB_MAP]
  343. shape_list = [data[K.SHAPE]]
  344. pred = pred[0][:, 0, :, :]
  345. segmentation = pred > self.thresh
  346. boxes_batch = []
  347. for batch_index in range(pred.shape[0]):
  348. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  349. if self.dilation_kernel is not None:
  350. mask = cv2.dilate(
  351. np.array(segmentation[batch_index]).astype(np.uint8),
  352. self.dilation_kernel)
  353. else:
  354. mask = segmentation[batch_index]
  355. if self.box_type == 'poly':
  356. boxes, scores = self.polygons_from_bitmap(pred[batch_index],
  357. mask, src_w, src_h)
  358. elif self.box_type == 'quad':
  359. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
  360. src_w, src_h)
  361. else:
  362. raise ValueError("box_type can only be one of ['quad', 'poly']")
  363. data[K.DT_POLYS] = boxes
  364. data[K.DT_SCORES] = scores
  365. return data
  366. @classmethod
  367. def get_input_keys(cls):
  368. """ get input keys """
  369. return [K.PROB_MAP]
  370. @classmethod
  371. def get_output_keys(cls):
  372. """ get output keys """
  373. return [K.DT_POLYS, K.DT_SCORES]
  374. class CropByPolys(BaseTransform):
  375. """Crop Image by Polys
  376. """
  377. def __init__(self, det_box_type='quad'):
  378. super().__init__()
  379. self.det_box_type = det_box_type
  380. def apply(self, data):
  381. """ apply """
  382. ori_im = data[K.ORI_IM]
  383. # TODO
  384. # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
  385. dt_boxes = np.array(data[K.DT_POLYS])
  386. img_crop_list = []
  387. for bno in range(len(dt_boxes)):
  388. tmp_box = copy.deepcopy(dt_boxes[bno])
  389. if self.det_box_type == "quad":
  390. img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
  391. else:
  392. img_crop = self.get_minarea_rect_crop(ori_im, tmp_box)
  393. img_crop_list.append(img_crop)
  394. data[K.SUB_IMGS] = img_crop_list
  395. return data
  396. @classmethod
  397. def get_input_keys(cls):
  398. """ get input keys """
  399. return [K.IM_PATH, K.DT_POLYS]
  400. @classmethod
  401. def get_output_keys(cls):
  402. """ get output keys """
  403. return [K.SUB_IMGS]
  404. def sorted_boxes(self, dt_boxes):
  405. """
  406. Sort text boxes in order from top to bottom, left to right
  407. args:
  408. dt_boxes(array):detected text boxes with shape [4, 2]
  409. return:
  410. sorted boxes(array) with shape [4, 2]
  411. """
  412. dt_boxes = np.array(dt_boxes)
  413. num_boxes = dt_boxes.shape[0]
  414. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  415. _boxes = list(sorted_boxes)
  416. for i in range(num_boxes - 1):
  417. for j in range(i, -1, -1):
  418. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
  419. _boxes[j + 1][0][0] < _boxes[j][0][0]):
  420. tmp = _boxes[j]
  421. _boxes[j] = _boxes[j + 1]
  422. _boxes[j + 1] = tmp
  423. else:
  424. break
  425. return _boxes
  426. def get_minarea_rect_crop(self, img, points):
  427. """get_minarea_rect_crop
  428. """
  429. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  430. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  431. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  432. if points[1][1] > points[0][1]:
  433. index_a = 0
  434. index_d = 1
  435. else:
  436. index_a = 1
  437. index_d = 0
  438. if points[3][1] > points[2][1]:
  439. index_b = 2
  440. index_c = 3
  441. else:
  442. index_b = 3
  443. index_c = 2
  444. box = [
  445. points[index_a], points[index_b], points[index_c], points[index_d]
  446. ]
  447. crop_img = self.get_rotate_crop_image(img, np.array(box))
  448. return crop_img
  449. def get_rotate_crop_image(self, img, points):
  450. """
  451. img_height, img_width = img.shape[0:2]
  452. left = int(np.min(points[:, 0]))
  453. right = int(np.max(points[:, 0]))
  454. top = int(np.min(points[:, 1]))
  455. bottom = int(np.max(points[:, 1]))
  456. img_crop = img[top:bottom, left:right, :].copy()
  457. points[:, 0] = points[:, 0] - left
  458. points[:, 1] = points[:, 1] - top
  459. """
  460. assert len(points) == 4, "shape of points must be 4*2"
  461. img_crop_width = int(
  462. max(
  463. np.linalg.norm(points[0] - points[1]),
  464. np.linalg.norm(points[2] - points[3])))
  465. img_crop_height = int(
  466. max(
  467. np.linalg.norm(points[0] - points[3]),
  468. np.linalg.norm(points[1] - points[2])))
  469. pts_std = np.float32([
  470. [0, 0],
  471. [img_crop_width, 0],
  472. [img_crop_width, img_crop_height],
  473. [0, img_crop_height],
  474. ])
  475. M = cv2.getPerspectiveTransform(points, pts_std)
  476. dst_img = cv2.warpPerspective(
  477. img,
  478. M,
  479. (img_crop_width, img_crop_height),
  480. borderMode=cv2.BORDER_REPLICATE,
  481. flags=cv2.INTER_CUBIC, )
  482. dst_img_height, dst_img_width = dst_img.shape[0:2]
  483. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  484. dst_img = np.rot90(dst_img)
  485. return dst_img
  486. class SaveTextDetResults(BaseTransform):
  487. """ Save Text Det Results """
  488. def __init__(self, save_dir):
  489. super().__init__()
  490. self.save_dir = save_dir
  491. # We use pillow backend to save both numpy arrays and PIL Image objects
  492. self._writer = ImageWriter(backend='opencv')
  493. def apply(self, data):
  494. """ apply """
  495. if self.save_dir is None:
  496. logging.warning(
  497. "The `save_dir` has been set to None, so the text detection result won't to be saved."
  498. )
  499. return data
  500. fn = os.path.basename(data['input_path'])
  501. save_path = os.path.join(self.save_dir, fn)
  502. bbox_res = data[K.DT_POLYS]
  503. vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
  504. self._writer.write(save_path, vis_img)
  505. return data
  506. @classmethod
  507. def get_input_keys(cls):
  508. """ get input keys """
  509. return [K.IM_PATH, K.DT_POLYS, K.DT_SCORES]
  510. @classmethod
  511. def get_output_keys(cls):
  512. """ get output keys """
  513. return []
  514. def draw_rectangle(self, img_path, boxes):
  515. """ draw rectangle """
  516. boxes = np.array(boxes)
  517. img = cv2.imread(img_path)
  518. img_show = img.copy()
  519. for box in boxes.astype(int):
  520. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  521. cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
  522. return img_show
  523. class PrintResult(BaseTransform):
  524. """ Print Result Transform """
  525. def apply(self, data):
  526. """ apply """
  527. logging.info("The prediction result is:")
  528. logging.info(data[K.DT_POLYS])
  529. return data
  530. @classmethod
  531. def get_input_keys(cls):
  532. """ get input keys """
  533. return [K.DT_SCORES]
  534. @classmethod
  535. def get_output_keys(cls):
  536. """ get output keys """
  537. return []
  538. # DT_SCORES = 'dt_scores'
  539. # DT_POLYS = 'dt_polys'