processors.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import sys
  16. from typing import Union
  17. import cv2
  18. import numpy as np
  19. import pyclipper
  20. from ....utils import logging
  21. from ...utils.benchmark import benchmark
  22. @benchmark.timeit
  23. class DetResizeForTest:
  24. """DetResizeForTest"""
  25. def __init__(self, input_shape=None, **kwargs):
  26. super().__init__()
  27. self.resize_type = 0
  28. self.keep_ratio = False
  29. if input_shape is not None:
  30. self.input_shape = input_shape
  31. self.resize_type = 3
  32. elif "image_shape" in kwargs:
  33. self.image_shape = kwargs["image_shape"]
  34. self.resize_type = 1
  35. if "keep_ratio" in kwargs:
  36. self.keep_ratio = kwargs["keep_ratio"]
  37. elif "limit_side_len" in kwargs:
  38. self.limit_side_len = kwargs["limit_side_len"]
  39. self.limit_type = kwargs.get("limit_type", "min")
  40. elif "resize_long" in kwargs:
  41. self.resize_type = 2
  42. self.resize_long = kwargs.get("resize_long", 960)
  43. else:
  44. self.limit_side_len = 736
  45. self.limit_type = "min"
  46. def __call__(
  47. self,
  48. imgs,
  49. limit_side_len: Union[int, None] = None,
  50. limit_type: Union[str, None] = None,
  51. ):
  52. """apply"""
  53. resize_imgs, img_shapes = [], []
  54. for ori_img in imgs:
  55. img, shape = self.resize(ori_img, limit_side_len, limit_type)
  56. resize_imgs.append(img)
  57. img_shapes.append(shape)
  58. return resize_imgs, img_shapes
  59. def resize(
  60. self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
  61. ):
  62. src_h, src_w, _ = img.shape
  63. if sum([src_h, src_w]) < 64:
  64. img = self.image_padding(img)
  65. if self.resize_type == 0:
  66. # img, shape = self.resize_image_type0(img)
  67. img, [ratio_h, ratio_w] = self.resize_image_type0(
  68. img, limit_side_len, limit_type
  69. )
  70. elif self.resize_type == 2:
  71. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  72. elif self.resize_type == 3:
  73. img, [ratio_h, ratio_w] = self.resize_image_type3(img)
  74. else:
  75. # img, shape = self.resize_image_type1(img)
  76. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  77. return img, np.array([src_h, src_w, ratio_h, ratio_w])
  78. def image_padding(self, im, value=0):
  79. """padding image"""
  80. h, w, c = im.shape
  81. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  82. im_pad[:h, :w, :] = im
  83. return im_pad
  84. def resize_image_type1(self, img):
  85. """resize the image"""
  86. resize_h, resize_w = self.image_shape
  87. ori_h, ori_w = img.shape[:2] # (h, w, c)
  88. if self.keep_ratio is True:
  89. resize_w = ori_w * resize_h / ori_h
  90. N = math.ceil(resize_w / 32)
  91. resize_w = N * 32
  92. ratio_h = float(resize_h) / ori_h
  93. ratio_w = float(resize_w) / ori_w
  94. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  95. # return img, np.array([ori_h, ori_w])
  96. return img, [ratio_h, ratio_w]
  97. def resize_image_type0(
  98. self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
  99. ):
  100. """
  101. resize image to a size multiple of 32 which is required by the network
  102. args:
  103. img(array): array with shape [h, w, c]
  104. return(tuple):
  105. img, (ratio_h, ratio_w)
  106. """
  107. limit_side_len = limit_side_len or self.limit_side_len
  108. limit_type = limit_type or self.limit_type
  109. h, w, c = img.shape
  110. # limit the max side
  111. if limit_type == "max":
  112. if max(h, w) > limit_side_len:
  113. if h > w:
  114. ratio = float(limit_side_len) / h
  115. else:
  116. ratio = float(limit_side_len) / w
  117. else:
  118. ratio = 1.0
  119. elif limit_type == "min":
  120. if min(h, w) < limit_side_len:
  121. if h < w:
  122. ratio = float(limit_side_len) / h
  123. else:
  124. ratio = float(limit_side_len) / w
  125. else:
  126. ratio = 1.0
  127. elif limit_type == "resize_long":
  128. ratio = float(limit_side_len) / max(h, w)
  129. else:
  130. raise Exception("not support limit type, image ")
  131. resize_h = int(h * ratio)
  132. resize_w = int(w * ratio)
  133. resize_h = max(int(round(resize_h / 32) * 32), 32)
  134. resize_w = max(int(round(resize_w / 32) * 32), 32)
  135. try:
  136. if int(resize_w) <= 0 or int(resize_h) <= 0:
  137. return None, (None, None)
  138. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  139. except:
  140. logging.info(img.shape, resize_w, resize_h)
  141. sys.exit(0)
  142. ratio_h = resize_h / float(h)
  143. ratio_w = resize_w / float(w)
  144. return img, [ratio_h, ratio_w]
  145. def resize_image_type2(self, img):
  146. """resize image size"""
  147. h, w, _ = img.shape
  148. resize_w = w
  149. resize_h = h
  150. if resize_h > resize_w:
  151. ratio = float(self.resize_long) / resize_h
  152. else:
  153. ratio = float(self.resize_long) / resize_w
  154. resize_h = int(resize_h * ratio)
  155. resize_w = int(resize_w * ratio)
  156. max_stride = 128
  157. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  158. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  159. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  160. ratio_h = resize_h / float(h)
  161. ratio_w = resize_w / float(w)
  162. return img, [ratio_h, ratio_w]
  163. def resize_image_type3(self, img):
  164. """resize the image"""
  165. resize_c, resize_h, resize_w = self.input_shape # (c, h, w)
  166. ori_h, ori_w = img.shape[:2] # (h, w, c)
  167. ratio_h = float(resize_h) / ori_h
  168. ratio_w = float(resize_w) / ori_w
  169. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  170. return img, [ratio_h, ratio_w]
  171. @benchmark.timeit
  172. class NormalizeImage:
  173. """normalize image such as substract mean, divide std"""
  174. def __init__(self, scale=None, mean=None, std=None, order="chw"):
  175. super().__init__()
  176. if isinstance(scale, str):
  177. scale = eval(scale)
  178. self.order = order
  179. scale = scale if scale is not None else 1.0 / 255.0
  180. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  181. std = std if std is not None else [0.229, 0.224, 0.225]
  182. self.alpha = [scale / std[i] for i in range(len(std))]
  183. self.beta = [-mean[i] / std[i] for i in range(len(std))]
  184. def __call__(self, imgs):
  185. """apply"""
  186. def _norm(img):
  187. if self.order == "chw":
  188. img = np.transpose(img, (2, 0, 1))
  189. split_im = list(cv2.split(img))
  190. for c in range(img.shape[2]):
  191. split_im[c] = split_im[c].astype(np.float32)
  192. split_im[c] *= self.alpha[c]
  193. split_im[c] += self.beta[c]
  194. res = cv2.merge(split_im)
  195. if self.order == "chw":
  196. res = np.transpose(res, (1, 2, 0))
  197. return res
  198. return [_norm(img) for img in imgs]
  199. @benchmark.timeit
  200. class DBPostProcess:
  201. """
  202. The post process for Differentiable Binarization (DB).
  203. """
  204. def __init__(
  205. self,
  206. thresh=0.3,
  207. box_thresh=0.7,
  208. max_candidates=1000,
  209. unclip_ratio=2.0,
  210. use_dilation=False,
  211. score_mode="fast",
  212. box_type="quad",
  213. **kwargs
  214. ):
  215. super().__init__()
  216. self.thresh = thresh
  217. self.box_thresh = box_thresh
  218. self.max_candidates = max_candidates
  219. self.unclip_ratio = unclip_ratio
  220. self.min_size = 3
  221. self.score_mode = score_mode
  222. self.box_type = box_type
  223. assert score_mode in [
  224. "slow",
  225. "fast",
  226. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  227. self.use_dilation = use_dilation
  228. def polygons_from_bitmap(
  229. self,
  230. pred,
  231. _bitmap,
  232. dest_width,
  233. dest_height,
  234. box_thresh,
  235. unclip_ratio,
  236. ):
  237. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  238. bitmap = _bitmap
  239. height, width = bitmap.shape
  240. width_scale = dest_width / width
  241. height_scale = dest_height / height
  242. boxes = []
  243. scores = []
  244. contours, _ = cv2.findContours(
  245. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  246. )
  247. for contour in contours[: self.max_candidates]:
  248. epsilon = 0.002 * cv2.arcLength(contour, True)
  249. approx = cv2.approxPolyDP(contour, epsilon, True)
  250. points = approx.reshape((-1, 2))
  251. if points.shape[0] < 4:
  252. continue
  253. score = self.box_score_fast(pred, points.reshape(-1, 2))
  254. if box_thresh > score:
  255. continue
  256. if points.shape[0] > 2:
  257. box = self.unclip(points, unclip_ratio)
  258. if len(box) > 1:
  259. continue
  260. else:
  261. continue
  262. box = box.reshape(-1, 2)
  263. if len(box) > 0:
  264. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  265. if sside < self.min_size + 2:
  266. continue
  267. else:
  268. continue
  269. box = np.array(box)
  270. for i in range(box.shape[0]):
  271. box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
  272. box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
  273. boxes.append(box)
  274. scores.append(score)
  275. return boxes, scores
  276. def boxes_from_bitmap(
  277. self,
  278. pred,
  279. _bitmap,
  280. dest_width,
  281. dest_height,
  282. box_thresh,
  283. unclip_ratio,
  284. ):
  285. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  286. bitmap = _bitmap
  287. height, width = bitmap.shape
  288. width_scale = dest_width / width
  289. height_scale = dest_height / height
  290. outs = cv2.findContours(
  291. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  292. )
  293. if len(outs) == 3:
  294. img, contours, _ = outs[0], outs[1], outs[2]
  295. elif len(outs) == 2:
  296. contours, _ = outs[0], outs[1]
  297. num_contours = min(len(contours), self.max_candidates)
  298. boxes = []
  299. scores = []
  300. for index in range(num_contours):
  301. contour = contours[index]
  302. points, sside = self.get_mini_boxes(contour)
  303. if sside < self.min_size:
  304. continue
  305. points = np.array(points)
  306. if self.score_mode == "fast":
  307. score = self.box_score_fast(pred, points.reshape(-1, 2))
  308. else:
  309. score = self.box_score_slow(pred, contour)
  310. if box_thresh > score:
  311. continue
  312. box = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
  313. box, sside = self.get_mini_boxes(box)
  314. if sside < self.min_size + 2:
  315. continue
  316. box = np.array(box)
  317. for i in range(box.shape[0]):
  318. box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
  319. box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
  320. boxes.append(box.astype(np.int16))
  321. scores.append(score)
  322. return np.array(boxes, dtype=np.int16), scores
  323. def unclip(self, box, unclip_ratio):
  324. """unclip"""
  325. area = cv2.contourArea(box)
  326. length = cv2.arcLength(box, True)
  327. distance = area * unclip_ratio / length
  328. offset = pyclipper.PyclipperOffset()
  329. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  330. try:
  331. expanded = np.array(offset.Execute(distance))
  332. except ValueError:
  333. expanded = np.array(offset.Execute(distance)[0])
  334. return expanded
  335. def get_mini_boxes(self, contour):
  336. """get mini boxes"""
  337. bounding_box = cv2.minAreaRect(contour)
  338. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  339. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  340. if points[1][1] > points[0][1]:
  341. index_1 = 0
  342. index_4 = 1
  343. else:
  344. index_1 = 1
  345. index_4 = 0
  346. if points[3][1] > points[2][1]:
  347. index_2 = 2
  348. index_3 = 3
  349. else:
  350. index_2 = 3
  351. index_3 = 2
  352. box = [points[index_1], points[index_2], points[index_3], points[index_4]]
  353. return box, min(bounding_box[1])
  354. def box_score_fast(self, bitmap, _box):
  355. """box_score_fast: use bbox mean score as the mean score"""
  356. h, w = bitmap.shape[:2]
  357. box = _box.copy()
  358. xmin = max(0, min(math.floor(box[:, 0].min()), w - 1))
  359. xmax = max(0, min(math.ceil(box[:, 0].max()), w - 1))
  360. ymin = max(0, min(math.floor(box[:, 1].min()), h - 1))
  361. ymax = max(0, min(math.ceil(box[:, 1].max()), h - 1))
  362. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  363. box[:, 0] = box[:, 0] - xmin
  364. box[:, 1] = box[:, 1] - ymin
  365. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  366. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  367. def box_score_slow(self, bitmap, contour):
  368. """box_score_slow: use polygon mean score as the mean score"""
  369. h, w = bitmap.shape[:2]
  370. contour = contour.copy()
  371. contour = np.reshape(contour, (-1, 2))
  372. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  373. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  374. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  375. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  376. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  377. contour[:, 0] = contour[:, 0] - xmin
  378. contour[:, 1] = contour[:, 1] - ymin
  379. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  380. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  381. def __call__(
  382. self,
  383. preds,
  384. img_shapes,
  385. thresh: Union[float, None] = None,
  386. box_thresh: Union[float, None] = None,
  387. unclip_ratio: Union[float, None] = None,
  388. ):
  389. """apply"""
  390. boxes, scores = [], []
  391. for pred, img_shape in zip(preds[0], img_shapes):
  392. box, score = self.process(
  393. pred,
  394. img_shape,
  395. thresh or self.thresh,
  396. box_thresh or self.box_thresh,
  397. unclip_ratio or self.unclip_ratio,
  398. )
  399. boxes.append(box)
  400. scores.append(score)
  401. return boxes, scores
  402. def process(
  403. self,
  404. pred,
  405. img_shape,
  406. thresh,
  407. box_thresh,
  408. unclip_ratio,
  409. ):
  410. pred = pred[0, :, :]
  411. segmentation = pred > thresh
  412. dilation_kernel = None if not self.use_dilation else np.array([[1, 1], [1, 1]])
  413. src_h, src_w, ratio_h, ratio_w = img_shape
  414. if dilation_kernel is not None:
  415. mask = cv2.dilate(
  416. np.array(segmentation).astype(np.uint8),
  417. dilation_kernel,
  418. )
  419. else:
  420. mask = segmentation
  421. if self.box_type == "poly":
  422. boxes, scores = self.polygons_from_bitmap(
  423. pred, mask, src_w, src_h, box_thresh, unclip_ratio
  424. )
  425. elif self.box_type == "quad":
  426. boxes, scores = self.boxes_from_bitmap(
  427. pred, mask, src_w, src_h, box_thresh, unclip_ratio
  428. )
  429. else:
  430. raise ValueError("box_type can only be one of ['quad', 'poly']")
  431. return boxes, scores