浏览代码

opt processors (#3714)

* opt Normalize op

* perform an in-place operation on 'split_im'

* opt DBPostProcess op

* format text_detection postprocess && opt object_detection Normalize operation

* opt common and object_detection Normalize operation

* add short circuit checks for Normalize and Resize

* add shor circuit check for Pad operation, opt table_structure_recognition TableLabelDecode operation

* fix Pad's short circuit check

* fix Pad's short circuit check

* Optimize rotate_image

* Fix bug

* use element-wise multiplication for bbox scaling

* fix varaiable name

* rgb2bgr contiguous array

* Polish code

* Fix bug

* imporve benchmark

* remove  parameter from NormalizeImage

* opt handling of 'location not found' case

* Fix code style

* Fix imports

---------

Co-authored-by: Bobholamovic <bob1998425@hotmail.com>
zhang-prog 7 月之前
父节点
当前提交
5b71ecdafb

+ 2 - 2
paddlex/inference/common/reader/image_reader.py

@@ -49,7 +49,7 @@ class ReadImage:
     def read(self, img):
     def read(self, img):
         if isinstance(img, np.ndarray):
         if isinstance(img, np.ndarray):
             if self.format == "RGB":
             if self.format == "RGB":
-                img = img[:, :, ::-1]
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
             return img
             return img
         elif isinstance(img, str):
         elif isinstance(img, str):
             blob = self._img_reader.read(img)
             blob = self._img_reader.read(img)
@@ -60,7 +60,7 @@ class ReadImage:
                 if blob.ndim != 3:
                 if blob.ndim != 3:
                     raise RuntimeError("Array is not 3-dimensional.")
                     raise RuntimeError("Array is not 3-dimensional.")
                 # BGR to RGB
                 # BGR to RGB
-                blob = blob[..., ::-1]
+                blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB)
             return blob
             return blob
         else:
         else:
             raise TypeError(
             raise TypeError(

+ 7 - 10
paddlex/inference/models/common/vision/funcs.py

@@ -33,6 +33,8 @@ def check_image_size(input_):
 def resize(im, target_size, interp, backend="cv2"):
 def resize(im, target_size, interp, backend="cv2"):
     """resize image to target size"""
     """resize image to target size"""
     w, h = target_size
     w, h = target_size
+    if w == im.shape[1] and h == im.shape[0]:
+        return im
     if backend.lower() == "pil":
     if backend.lower() == "pil":
         resize_function = _pil_resize
         resize_function = _pil_resize
     else:
     else:
@@ -60,20 +62,12 @@ def _pil_resize(src, size, resample):
 
 
 def flip_h(im):
 def flip_h(im):
     """flip image horizontally"""
     """flip image horizontally"""
-    if len(im.shape) == 3:
-        im = im[:, ::-1, :]
-    elif len(im.shape) == 2:
-        im = im[:, ::-1]
-    return im
+    return cv2.flip(im, 1)
 
 
 
 
 def flip_v(im):
 def flip_v(im):
     """flip image vertically"""
     """flip image vertically"""
-    if len(im.shape) == 3:
-        im = im[::-1, :, :]
-    elif len(im.shape) == 2:
-        im = im[::-1, :]
-    return im
+    return cv2.flip(im, 0)
 
 
 
 
 def slice(im, coords):
 def slice(im, coords):
@@ -89,6 +83,9 @@ def pad(im, pad, val):
         pad = [pad] * 4
         pad = [pad] * 4
     if len(pad) != 4:
     if len(pad) != 4:
         raise ValueError
         raise ValueError
+    if all(x == 0 for x in pad):
+        return im
+
     chns = 1 if im.ndim == 2 else im.shape[2]
     chns = 1 if im.ndim == 2 else im.shape[2]
     im = cv2.copyMakeBorder(im, *pad, cv2.BORDER_CONSTANT, value=(val,) * chns)
     im = cv2.copyMakeBorder(im, *pad, cv2.BORDER_CONSTANT, value=(val,) * chns)
     return im
     return im

+ 28 - 21
paddlex/inference/models/common/vision/processors.py

@@ -217,9 +217,9 @@ class ResizeByShort(_BaseResize):
 
 
 @benchmark.timeit
 @benchmark.timeit
 class Normalize:
 class Normalize:
-    """Normalize the image."""
+    """Normalize the three-channel image."""
 
 
-    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
+    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5):
         """
         """
         Initialize the instance.
         Initialize the instance.
 
 
@@ -228,34 +228,41 @@ class Normalize:
                 applying normalization. Default: 1/255.
                 applying normalization. Default: 1/255.
             mean (float|tuple|list, optional): Means for each channel of the image.
             mean (float|tuple|list, optional): Means for each channel of the image.
                 Default: 0.5.
                 Default: 0.5.
-            std (float|tuple|list, optional): Standard deviations for each channel
+            std (float|tuple|list|np.ndarray, optional): Standard deviations for each channel
                 of the image. Default: 0.5.
                 of the image. Default: 0.5.
-            preserve_dtype (bool, optional): Whether to preserve the original dtype
-                of the image.
         """
         """
         super().__init__()
         super().__init__()
 
 
-        self.scale = np.float32(scale)
         if isinstance(mean, float):
         if isinstance(mean, float):
-            mean = [mean]
-        self.mean = np.asarray(mean).astype("float32")
+            mean = [mean] * 3
+        elif len(mean) != 3:
+            raise ValueError(
+                f"Expected `mean` to be a tuple or list of length 3, but got {len(mean)} elements."
+            )
         if isinstance(std, float):
         if isinstance(std, float):
-            std = [std]
-        self.std = np.asarray(std).astype("float32")
-        self.preserve_dtype = preserve_dtype
+            std = [std] * 3
+        elif len(std) != 3:
+            raise ValueError(
+                f"Expected `std` to be a tuple or list of length 3, but got {len(std)} elements."
+            )
+
+        self.alpha = [scale / std[i] for i in range(len(std))]
+        self.beta = [-mean[i] / std[i] for i in range(len(std))]
+
+    def norm(self, img):
+        split_im = list(cv2.split(img))
+
+        for c in range(img.shape[2]):
+            split_im[c] = split_im[c].astype(np.float32)
+            split_im[c] *= self.alpha[c]
+            split_im[c] += self.beta[c]
+
+        res = cv2.merge(split_im)
+        return res
 
 
     def __call__(self, imgs):
     def __call__(self, imgs):
         """apply"""
         """apply"""
-        old_type = imgs[0].dtype
-        # XXX: If `old_type` has higher precision than float32,
-        # we will lose some precision.
-        imgs = np.array(imgs).astype("float32", copy=False)
-        imgs *= self.scale
-        imgs -= self.mean
-        imgs /= self.std
-        if self.preserve_dtype:
-            imgs = imgs.astype(old_type, copy=False)
-        return list(imgs)
+        return [self.norm(img) for img in imgs]
 
 
 
 
 @benchmark.timeit
 @benchmark.timeit

+ 4 - 19
paddlex/inference/models/object_detection/processors.py

@@ -27,7 +27,7 @@ Boxes = List[dict]
 Number = Union[int, float]
 Number = Union[int, float]
 
 
 
 
-@benchmark.timeit
+@benchmark.timeit_with_options(name=None, is_read_operation=True)
 class ReadImage(CommonReadImage):
 class ReadImage(CommonReadImage):
     """Reads images from a list of raw image data or file paths."""
     """Reads images from a list of raw image data or file paths."""
 
 
@@ -71,7 +71,7 @@ class ReadImage(CommonReadImage):
         if isinstance(img, np.ndarray):
         if isinstance(img, np.ndarray):
             ori_img = img
             ori_img = img
             if self.format == "RGB":
             if self.format == "RGB":
-                img = img[:, :, ::-1]
+                img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
             return img, ori_img
             return img, ori_img
         elif isinstance(img, str):
         elif isinstance(img, str):
             blob = self._img_reader.read(img)
             blob = self._img_reader.read(img)
@@ -83,7 +83,7 @@ class ReadImage(CommonReadImage):
                 if blob.ndim != 3:
                 if blob.ndim != 3:
                     raise RuntimeError("Array is not 3-dimensional.")
                     raise RuntimeError("Array is not 3-dimensional.")
                 # BGR to RGB
                 # BGR to RGB
-                blob = blob[..., ::-1]
+                blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB)
             return blob, ori_img
             return blob, ori_img
         else:
         else:
             raise TypeError(
             raise TypeError(
@@ -127,27 +127,12 @@ class Resize(CommonResize):
 
 
 @benchmark.timeit
 @benchmark.timeit
 class Normalize(CommonNormalize):
 class Normalize(CommonNormalize):
-    """Normalizes images in a list of dictionaries containing image data"""
-
-    def apply(self, img: ndarray) -> ndarray:
-        """Applies normalization to a single image."""
-        old_type = img.dtype
-        # XXX: If `old_type` has higher precision than float32,
-        # we will lose some precision.
-        img = img.astype("float32", copy=False)
-        img *= self.scale
-        img -= self.mean
-        img /= self.std
-        if self.preserve_dtype:
-            img = img.astype(old_type, copy=False)
-        return img
-
     def __call__(self, datas: List[dict]) -> List[dict]:
     def __call__(self, datas: List[dict]) -> List[dict]:
         """Normalizes images in a list of dictionaries. Iterates over each dictionary,
         """Normalizes images in a list of dictionaries. Iterates over each dictionary,
         applies normalization to the 'img' key, and returns the modified list.
         applies normalization to the 'img' key, and returns the modified list.
         """
         """
         for data in datas:
         for data in datas:
-            data["img"] = self.apply(data["img"])
+            data["img"] = self.norm(data["img"])
         return datas
         return datas
 
 
 
 

+ 18 - 26
paddlex/inference/models/table_structure_recognition/processors.py

@@ -123,16 +123,8 @@ class TableLabelDecode:
 
 
     def __call__(self, pred, img_size, ori_img_size):
     def __call__(self, pred, img_size, ori_img_size):
         """apply"""
         """apply"""
-        bbox_preds, structure_probs = [], []
-
-        for i in range(len(pred[0][0])):
-            bbox_preds.append(pred[0][0][i])
-            structure_probs.append(pred[1][0][i])
-        bbox_preds = [bbox_preds]
-        structure_probs = [structure_probs]
-
-        bbox_preds = np.array(bbox_preds)
-        structure_probs = np.array(structure_probs)
+        bbox_preds = np.array([list(pred[0][0])])
+        structure_probs = np.array([list(pred[1][0])])
 
 
         bbox_list, structure_str_list, structure_score = self.decode(
         bbox_list, structure_str_list, structure_score = self.decode(
             structure_probs, bbox_preds, img_size, ori_img_size
             structure_probs, bbox_preds, img_size, ori_img_size
@@ -161,9 +153,11 @@ class TableLabelDecode:
         structure_batch_list = []
         structure_batch_list = []
         bbox_batch_list = []
         bbox_batch_list = []
         batch_size = len(structure_idx)
         batch_size = len(structure_idx)
+        bbox_list = []
+        scale_list = []
+        scales = [0] * 8
         for batch_idx in range(batch_size):
         for batch_idx in range(batch_size):
             structure_list = []
             structure_list = []
-            bbox_list = []
             score_list = []
             score_list = []
             for idx in range(len(structure_idx[batch_idx])):
             for idx in range(len(structure_idx[batch_idx])):
                 char_idx = int(structure_idx[batch_idx][idx])
                 char_idx = int(structure_idx[batch_idx][idx])
@@ -174,15 +168,21 @@ class TableLabelDecode:
                 text = self.character[char_idx]
                 text = self.character[char_idx]
                 if text in self.td_token:
                 if text in self.td_token:
                     bbox = bbox_preds[batch_idx, idx]
                     bbox = bbox_preds[batch_idx, idx]
-                    bbox = self._bbox_decode(
-                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
+                    h_scale, w_scale = self._get_bbox_scales(
+                        padding_size[batch_idx], ori_img_size[batch_idx]
                     )
                     )
-                    bbox_list.append(bbox.astype(int))
+                    scales[0::2] = [h_scale] * 4
+                    scales[1::2] = [w_scale] * 4
+                    bbox_list.append(bbox)
+                    scale_list.append(scales)
+
                 structure_list.append(text)
                 structure_list.append(text)
                 score_list.append(structure_probs[batch_idx, idx])
                 score_list.append(structure_probs[batch_idx, idx])
             structure_batch_list.append(structure_list)
             structure_batch_list.append(structure_list)
             structure_score = np.mean(score_list)
             structure_score = np.mean(score_list)
-            bbox_batch_list.append(bbox_list)
+
+        bbox_batch_array = np.multiply(np.array(bbox_list), np.array(scale_list))
+        bbox_batch_list = [bbox_batch_array.astype(int).tolist()]
 
 
         return bbox_batch_list, structure_batch_list, structure_score
         return bbox_batch_list, structure_batch_list, structure_score
 
 
@@ -216,22 +216,14 @@ class TableLabelDecode:
             bbox_batch_list.append(bbox_list)
             bbox_batch_list.append(bbox_list)
         return bbox_batch_list, structure_batch_list
         return bbox_batch_list, structure_batch_list
 
 
-    def _bbox_decode(self, bbox, padding_shape, ori_shape):
-
+    def _get_bbox_scales(self, padding_shape, ori_shape):
         if self.model_name == "SLANet":
         if self.model_name == "SLANet":
             w, h = ori_shape
             w, h = ori_shape
-            bbox[0::2] *= w
-            bbox[1::2] *= h
+            return w, h
         else:
         else:
             w, h = padding_shape
             w, h = padding_shape
             ori_w, ori_h = ori_shape
             ori_w, ori_h = ori_shape
             ratio_w = w / ori_w
             ratio_w = w / ori_w
             ratio_h = h / ori_h
             ratio_h = h / ori_h
             ratio = min(ratio_w, ratio_h)
             ratio = min(ratio_w, ratio_h)
-
-            bbox[0::2] *= w
-            bbox[1::2] *= h
-            bbox[0::2] /= ratio
-            bbox[1::2] /= ratio
-
-        return bbox
+            return w / ratio, h / ratio

+ 1 - 4
paddlex/inference/models/text_detection/predictor.py

@@ -152,11 +152,8 @@ class TextDetPredictor(BasePredictor):
         std=[0.229, 0.224, 0.225],
         std=[0.229, 0.224, 0.225],
         scale=1 / 255,
         scale=1 / 255,
         order="",
         order="",
-        channel_num=3,
     ):
     ):
-        return "Normalize", NormalizeImage(
-            mean=mean, std=std, scale=scale, order=order, channel_num=channel_num
-        )
+        return "Normalize", NormalizeImage(mean=mean, std=std, scale=scale, order=order)
 
 
     @register("ToCHWImage")
     @register("ToCHWImage")
     def build_to_chw(self):
     def build_to_chw(self):

+ 41 - 24
paddlex/inference/models/text_detection/processors.py

@@ -19,7 +19,6 @@ from typing import Union
 import cv2
 import cv2
 import numpy as np
 import numpy as np
 import pyclipper
 import pyclipper
-from shapely.geometry import Polygon
 
 
 from ....utils import logging
 from ....utils import logging
 from ...utils.benchmark import benchmark
 from ...utils.benchmark import benchmark
@@ -197,25 +196,39 @@ class DetResizeForTest:
 class NormalizeImage:
 class NormalizeImage:
     """normalize image such as substract mean, divide std"""
     """normalize image such as substract mean, divide std"""
 
 
-    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+    def __init__(self, scale=None, mean=None, std=None, order="chw"):
         super().__init__()
         super().__init__()
         if isinstance(scale, str):
         if isinstance(scale, str):
             scale = eval(scale)
             scale = eval(scale)
-        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        self.order = order
+
+        scale = scale if scale is not None else 1.0 / 255.0
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]
         std = std if std is not None else [0.229, 0.224, 0.225]
 
 
-        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
-        self.mean = np.array(mean).reshape(shape).astype("float32")
-        self.std = np.array(std).reshape(shape).astype("float32")
+        self.alpha = [scale / std[i] for i in range(len(std))]
+        self.beta = [-mean[i] / std[i] for i in range(len(std))]
 
 
     def __call__(self, imgs):
     def __call__(self, imgs):
         """apply"""
         """apply"""
 
 
-        def norm(img):
-            return (img.astype("float32") * self.scale - self.mean) / self.std
+        def _norm(img):
+            if self.order == "chw":
+                img = np.transpose(img, (2, 0, 1))
+
+            split_im = list(cv2.split(img))
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= self.alpha[c]
+                split_im[c] += self.beta[c]
+
+            res = cv2.merge(split_im)
 
 
-        return [norm(img) for img in imgs]
+            if self.order == "chw":
+                res = np.transpose(res, (1, 2, 0))
+            return res
+
+        return [_norm(img) for img in imgs]
 
 
 
 
 @benchmark.timeit
 @benchmark.timeit
@@ -262,7 +275,8 @@ class DBPostProcess:
 
 
         bitmap = _bitmap
         bitmap = _bitmap
         height, width = bitmap.shape
         height, width = bitmap.shape
-
+        width_scale = dest_width / width
+        height_scale = dest_height / height
         boxes = []
         boxes = []
         scores = []
         scores = []
 
 
@@ -297,10 +311,10 @@ class DBPostProcess:
                 continue
                 continue
 
 
             box = np.array(box)
             box = np.array(box)
-            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
-            box[:, 1] = np.clip(
-                np.round(box[:, 1] / height * dest_height), 0, dest_height
-            )
+            for i in range(box.shape[0]):
+                box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
+                box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
+
             boxes.append(box)
             boxes.append(box)
             scores.append(score)
             scores.append(score)
         return boxes, scores
         return boxes, scores
@@ -318,6 +332,8 @@ class DBPostProcess:
 
 
         bitmap = _bitmap
         bitmap = _bitmap
         height, width = bitmap.shape
         height, width = bitmap.shape
+        width_scale = dest_width / width
+        height_scale = dest_height / height
 
 
         outs = cv2.findContours(
         outs = cv2.findContours(
             (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
             (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
@@ -348,20 +364,21 @@ class DBPostProcess:
             box, sside = self.get_mini_boxes(box)
             box, sside = self.get_mini_boxes(box)
             if sside < self.min_size + 2:
             if sside < self.min_size + 2:
                 continue
                 continue
+
             box = np.array(box)
             box = np.array(box)
+            for i in range(box.shape[0]):
+                box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
+                box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
 
 
-            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
-            box[:, 1] = np.clip(
-                np.round(box[:, 1] / height * dest_height), 0, dest_height
-            )
             boxes.append(box.astype(np.int16))
             boxes.append(box.astype(np.int16))
             scores.append(score)
             scores.append(score)
         return np.array(boxes, dtype=np.int16), scores
         return np.array(boxes, dtype=np.int16), scores
 
 
     def unclip(self, box, unclip_ratio):
     def unclip(self, box, unclip_ratio):
         """unclip"""
         """unclip"""
-        poly = Polygon(box)
-        distance = poly.area * unclip_ratio / poly.length
+        area = cv2.contourArea(box)
+        length = cv2.arcLength(box, True)
+        distance = area * unclip_ratio / length
         offset = pyclipper.PyclipperOffset()
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         try:
         try:
@@ -396,10 +413,10 @@ class DBPostProcess:
         """box_score_fast: use bbox mean score as the mean score"""
         """box_score_fast: use bbox mean score as the mean score"""
         h, w = bitmap.shape[:2]
         h, w = bitmap.shape[:2]
         box = _box.copy()
         box = _box.copy()
-        xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
-        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
-        ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
-        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
+        xmin = max(0, min(math.floor(box[:, 0].min()), w - 1))
+        xmax = max(0, min(math.ceil(box[:, 0].max()), w - 1))
+        ymin = max(0, min(math.floor(box[:, 1].min()), h - 1))
+        ymax = max(0, min(math.ceil(box[:, 1].max()), h - 1))
 
 
         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         box[:, 0] = box[:, 0] - xmin
         box[:, 0] = box[:, 0] - xmin

+ 1 - 0
paddlex/inference/pipelines/components/__init__.py

@@ -21,6 +21,7 @@ from .common import (
     SortPolyBoxes,
     SortPolyBoxes,
     SortQuadBoxes,
     SortQuadBoxes,
     convert_points_to_boxes,
     convert_points_to_boxes,
+    rotate_image,
 )
 )
 from .faisser import FaissBuilder, FaissIndexer, IndexData
 from .faisser import FaissBuilder, FaissIndexer, IndexData
 from .prompt_engineering.base import BaseGeneratePrompt
 from .prompt_engineering.base import BaseGeneratePrompt

+ 1 - 0
paddlex/inference/pipelines/components/common/__init__.py

@@ -16,3 +16,4 @@ from .base_result import BaseResult, CVResult
 from .convert_points_and_boxes import convert_points_to_boxes
 from .convert_points_and_boxes import convert_points_to_boxes
 from .crop_image_regions import CropByBoxes, CropByPolys
 from .crop_image_regions import CropByBoxes, CropByPolys
 from .sort_boxes import SortPolyBoxes, SortQuadBoxes
 from .sort_boxes import SortPolyBoxes, SortQuadBoxes
+from .warp_image import rotate_image

+ 45 - 0
paddlex/inference/pipelines/components/common/warp_image.py

@@ -0,0 +1,45 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def rotate_image(image, angle):
+    if angle < 0 or angle >= 360:
+        raise ValueError("`angle` should be in range [0, 360)")
+
+    if angle < 1e-7:
+        return image
+
+    # Should we align corners?
+    h, w = image.shape[:2]
+    center = (w / 2, h / 2)
+    scale = 1.0
+    mat = cv2.getRotationMatrix2D(center, angle, scale)
+    cos = np.abs(mat[0, 0])
+    sin = np.abs(mat[0, 1])
+    new_w = int((h * sin) + (w * cos))
+    new_h = int((h * cos) + (w * sin))
+    mat[0, 2] += (new_w - w) / 2
+    mat[1, 2] += (new_h - h) / 2
+    dst_size = (new_w, new_h)
+
+    rotated = cv2.warpAffine(
+        image,
+        mat,
+        dst_size,
+        flags=cv2.INTER_CUBIC,
+    )
+    return rotated

+ 2 - 21
paddlex/inference/pipelines/doc_preprocessor/pipeline.py

@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Union
 from typing import Any, Dict, List, Optional, Union
 
 
 import numpy as np
 import numpy as np
-from scipy.ndimage import rotate
 
 
 from ....utils import logging
 from ....utils import logging
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.batch_sampler import ImageBatchSampler
@@ -23,6 +22,7 @@ from ...common.reader import ReadImage
 from ...utils.hpi import HPIConfig
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 from ..base import BasePipeline
+from ..components import rotate_image
 from .result import DocPreprocessorResult
 from .result import DocPreprocessorResult
 
 
 
 
@@ -77,25 +77,6 @@ class DocPreprocessorPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
         self.img_reader = ReadImage(format="BGR")
 
 
-    def rotate_image(self, image_array: np.ndarray, rotate_angle: float) -> np.ndarray:
-        """
-        Rotate the given image array by the specified angle.
-
-        Args:
-            image_array (np.ndarray): The input image array to be rotated.
-            rotate_angle (float): The angle in degrees by which to rotate the image.
-
-        Returns:
-            np.ndarray: The rotated image array.
-
-        Raises:
-            AssertionError: If rotate_angle is not in the range [0, 360).
-        """
-        assert (
-            rotate_angle >= 0 and rotate_angle < 360
-        ), "rotate_angle must in [0-360), but get {rotate_angle}."
-        return rotate(image_array, rotate_angle, reshape=True)
-
     def check_model_settings_valid(self, model_settings: Dict) -> bool:
     def check_model_settings_valid(self, model_settings: Dict) -> bool:
         """
         """
         Check if the the input params for model settings are valid based on the initialized models.
         Check if the the input params for model settings are valid based on the initialized models.
@@ -178,7 +159,7 @@ class DocPreprocessorPipeline(BasePipeline):
             if model_settings["use_doc_orientation_classify"]:
             if model_settings["use_doc_orientation_classify"]:
                 pred = next(self.doc_ori_classify_model(image_array))
                 pred = next(self.doc_ori_classify_model(image_array))
                 angle = int(pred["label_names"][0])
                 angle = int(pred["label_names"][0])
-                rot_img = self.rotate_image(image_array, angle)
+                rot_img = rotate_image(image_array, angle)
             else:
             else:
                 angle = -1
                 angle = -1
                 rot_img = image_array
                 rot_img = image_array

+ 2 - 2
paddlex/inference/pipelines/ocr/pipeline.py

@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Union
 from typing import Any, Dict, List, Optional, Union
 
 
 import numpy as np
 import numpy as np
-from scipy.ndimage import rotate
 
 
 from ....utils import logging
 from ....utils import logging
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.batch_sampler import ImageBatchSampler
@@ -28,6 +27,7 @@ from ..components import (
     SortPolyBoxes,
     SortPolyBoxes,
     SortQuadBoxes,
     SortQuadBoxes,
     convert_points_to_boxes,
     convert_points_to_boxes,
+    rotate_image,
 )
 )
 from .result import OCRResult
 from .result import OCRResult
 
 
@@ -163,7 +163,7 @@ class OCRPipeline(BasePipeline):
         for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
         for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
             # Convert 0/1 indicator to actual rotation angle
             # Convert 0/1 indicator to actual rotation angle
             rotate_angle = rotate_indicator * 180
             rotate_angle = rotate_indicator * 180
-            rotated_image = rotate(image_array, rotate_angle, reshape=True)
+            rotated_image = rotate_image(image_array, rotate_angle)
             rotated_images.append(rotated_image)
             rotated_images.append(rotated_image)
 
 
         return rotated_images
         return rotated_images

+ 32 - 3
paddlex/inference/utils/benchmark.py

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
+import copy
 import csv
 import csv
 import functools
 import functools
 import inspect
 import inspect
 import time
 import time
+import uuid
 from pathlib import Path
 from pathlib import Path
 from types import GeneratorType
 from types import GeneratorType
 
 
@@ -34,6 +36,8 @@ ENTRY_POINT_NAME = "_entry_point_"
 # XXX: Global mutable state
 # XXX: Global mutable state
 _inference_operations = []
 _inference_operations = []
 
 
+_is_measuring_time = False
+
 
 
 class Benchmark:
 class Benchmark:
     def __init__(self, enabled):
     def __init__(self, enabled):
@@ -65,7 +69,7 @@ class Benchmark:
                 source_line = inspect.getsourcelines(func)[1]
                 source_line = inspect.getsourcelines(func)[1]
                 location = f"{source_file}:{source_line}"
                 location = f"{source_file}:{source_line}"
             except (TypeError, OSError) as e:
             except (TypeError, OSError) as e:
-                location = "Unknown"
+                location = uuid.uuid4().hex
                 logging.debug(
                 logging.debug(
                     f"Benchmark: failed to get source file and line number: {e}"
                     f"Benchmark: failed to get source file and line number: {e}"
                 )
                 )
@@ -89,15 +93,27 @@ class Benchmark:
                         for k, v in kwargs.items()
                         for k, v in kwargs.items()
                     }
                     }
                     output = func(*args, **kwargs)
                     output = func(*args, **kwargs)
+                    output = copy.deepcopy(output)
                     return output
                     return output
 
 
             else:
             else:
 
 
                 @functools.wraps(func)
                 @functools.wraps(func)
                 def _wrapper(*args, **kwargs):
                 def _wrapper(*args, **kwargs):
+                    global _is_measuring_time
                     operation_name = f"{name}@{location}"
                     operation_name = f"{name}@{location}"
+                    if _is_measuring_time:
+                        raise RuntimeError(
+                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
+                        )
+                    if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
+                        _is_measuring_time = True
                     tic = time.perf_counter()
                     tic = time.perf_counter()
-                    output = func(*args, **kwargs)
+                    try:
+                        output = func(*args, **kwargs)
+                    finally:
+                        if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
+                            _is_measuring_time = False
                     if isinstance(output, GeneratorType):
                     if isinstance(output, GeneratorType):
                         return self.watch_generator(output, operation_name)
                         return self.watch_generator(output, operation_name)
                     else:
                     else:
@@ -118,10 +134,21 @@ class Benchmark:
     def watch_generator(self, generator, name):
     def watch_generator(self, generator, name):
         @functools.wraps(generator)
         @functools.wraps(generator)
         def wrapper():
         def wrapper():
+            global _is_measuring_time
             while True:
             while True:
                 try:
                 try:
+                    if _is_measuring_time:
+                        raise RuntimeError(
+                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
+                        )
+                    if not name.startswith(f"{ENTRY_POINT_NAME}@"):
+                        _is_measuring_time = True
                     tic = time.perf_counter()
                     tic = time.perf_counter()
-                    item = next(generator)
+                    try:
+                        item = next(generator)
+                    finally:
+                        if not name.startswith(f"{ENTRY_POINT_NAME}@"):
+                            _is_measuring_time = False
                     self._update(time.perf_counter() - tic, name)
                     self._update(time.perf_counter() - tic, name)
                     yield item
                     yield item
                 except StopIteration:
                 except StopIteration:
@@ -187,6 +214,8 @@ class Benchmark:
             avg = np.mean(time_list)
             avg = np.mean(time_list)
             operation_name = name.split("@")[0]
             operation_name = name.split("@")[0]
             location = name.split("@")[1]
             location = name.split("@")[1]
+            if ":" not in location:
+                location = "Unknown"
             detail_list.append(
             detail_list.append(
                 (iters, batch_size, instances, operation_name, avg, avg / batch_size)
                 (iters, batch_size, instances, operation_name, avg, avg / batch_size)
             )
             )