瀏覽代碼

opt processors (#3714)

* opt Normalize op

* perform an in-place operation on 'split_im'

* opt DBPostProcess op

* format text_detection postprocess && opt object_detection Normalize operation

* opt common and object_detection Normalize operation

* add short circuit checks for Normalize and Resize

* add shor circuit check for Pad operation, opt table_structure_recognition TableLabelDecode operation

* fix Pad's short circuit check

* fix Pad's short circuit check

* Optimize rotate_image

* Fix bug

* use element-wise multiplication for bbox scaling

* fix varaiable name

* rgb2bgr contiguous array

* Polish code

* Fix bug

* imporve benchmark

* remove  parameter from NormalizeImage

* opt handling of 'location not found' case

* Fix code style

* Fix imports

---------

Co-authored-by: Bobholamovic <bob1998425@hotmail.com>
zhang-prog 7 月之前
父節點
當前提交
5b71ecdafb

+ 2 - 2
paddlex/inference/common/reader/image_reader.py

@@ -49,7 +49,7 @@ class ReadImage:
     def read(self, img):
         if isinstance(img, np.ndarray):
             if self.format == "RGB":
-                img = img[:, :, ::-1]
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
             return img
         elif isinstance(img, str):
             blob = self._img_reader.read(img)
@@ -60,7 +60,7 @@ class ReadImage:
                 if blob.ndim != 3:
                     raise RuntimeError("Array is not 3-dimensional.")
                 # BGR to RGB
-                blob = blob[..., ::-1]
+                blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB)
             return blob
         else:
             raise TypeError(

+ 7 - 10
paddlex/inference/models/common/vision/funcs.py

@@ -33,6 +33,8 @@ def check_image_size(input_):
 def resize(im, target_size, interp, backend="cv2"):
     """resize image to target size"""
     w, h = target_size
+    if w == im.shape[1] and h == im.shape[0]:
+        return im
     if backend.lower() == "pil":
         resize_function = _pil_resize
     else:
@@ -60,20 +62,12 @@ def _pil_resize(src, size, resample):
 
 def flip_h(im):
     """flip image horizontally"""
-    if len(im.shape) == 3:
-        im = im[:, ::-1, :]
-    elif len(im.shape) == 2:
-        im = im[:, ::-1]
-    return im
+    return cv2.flip(im, 1)
 
 
 def flip_v(im):
     """flip image vertically"""
-    if len(im.shape) == 3:
-        im = im[::-1, :, :]
-    elif len(im.shape) == 2:
-        im = im[::-1, :]
-    return im
+    return cv2.flip(im, 0)
 
 
 def slice(im, coords):
@@ -89,6 +83,9 @@ def pad(im, pad, val):
         pad = [pad] * 4
     if len(pad) != 4:
         raise ValueError
+    if all(x == 0 for x in pad):
+        return im
+
     chns = 1 if im.ndim == 2 else im.shape[2]
     im = cv2.copyMakeBorder(im, *pad, cv2.BORDER_CONSTANT, value=(val,) * chns)
     return im

+ 28 - 21
paddlex/inference/models/common/vision/processors.py

@@ -217,9 +217,9 @@ class ResizeByShort(_BaseResize):
 
 @benchmark.timeit
 class Normalize:
-    """Normalize the image."""
+    """Normalize the three-channel image."""
 
-    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
+    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5):
         """
         Initialize the instance.
 
@@ -228,34 +228,41 @@ class Normalize:
                 applying normalization. Default: 1/255.
             mean (float|tuple|list, optional): Means for each channel of the image.
                 Default: 0.5.
-            std (float|tuple|list, optional): Standard deviations for each channel
+            std (float|tuple|list|np.ndarray, optional): Standard deviations for each channel
                 of the image. Default: 0.5.
-            preserve_dtype (bool, optional): Whether to preserve the original dtype
-                of the image.
         """
         super().__init__()
 
-        self.scale = np.float32(scale)
         if isinstance(mean, float):
-            mean = [mean]
-        self.mean = np.asarray(mean).astype("float32")
+            mean = [mean] * 3
+        elif len(mean) != 3:
+            raise ValueError(
+                f"Expected `mean` to be a tuple or list of length 3, but got {len(mean)} elements."
+            )
         if isinstance(std, float):
-            std = [std]
-        self.std = np.asarray(std).astype("float32")
-        self.preserve_dtype = preserve_dtype
+            std = [std] * 3
+        elif len(std) != 3:
+            raise ValueError(
+                f"Expected `std` to be a tuple or list of length 3, but got {len(std)} elements."
+            )
+
+        self.alpha = [scale / std[i] for i in range(len(std))]
+        self.beta = [-mean[i] / std[i] for i in range(len(std))]
+
+    def norm(self, img):
+        split_im = list(cv2.split(img))
+
+        for c in range(img.shape[2]):
+            split_im[c] = split_im[c].astype(np.float32)
+            split_im[c] *= self.alpha[c]
+            split_im[c] += self.beta[c]
+
+        res = cv2.merge(split_im)
+        return res
 
     def __call__(self, imgs):
         """apply"""
-        old_type = imgs[0].dtype
-        # XXX: If `old_type` has higher precision than float32,
-        # we will lose some precision.
-        imgs = np.array(imgs).astype("float32", copy=False)
-        imgs *= self.scale
-        imgs -= self.mean
-        imgs /= self.std
-        if self.preserve_dtype:
-            imgs = imgs.astype(old_type, copy=False)
-        return list(imgs)
+        return [self.norm(img) for img in imgs]
 
 
 @benchmark.timeit

+ 4 - 19
paddlex/inference/models/object_detection/processors.py

@@ -27,7 +27,7 @@ Boxes = List[dict]
 Number = Union[int, float]
 
 
-@benchmark.timeit
+@benchmark.timeit_with_options(name=None, is_read_operation=True)
 class ReadImage(CommonReadImage):
     """Reads images from a list of raw image data or file paths."""
 
@@ -71,7 +71,7 @@ class ReadImage(CommonReadImage):
         if isinstance(img, np.ndarray):
             ori_img = img
             if self.format == "RGB":
-                img = img[:, :, ::-1]
+                img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
             return img, ori_img
         elif isinstance(img, str):
             blob = self._img_reader.read(img)
@@ -83,7 +83,7 @@ class ReadImage(CommonReadImage):
                 if blob.ndim != 3:
                     raise RuntimeError("Array is not 3-dimensional.")
                 # BGR to RGB
-                blob = blob[..., ::-1]
+                blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB)
             return blob, ori_img
         else:
             raise TypeError(
@@ -127,27 +127,12 @@ class Resize(CommonResize):
 
 @benchmark.timeit
 class Normalize(CommonNormalize):
-    """Normalizes images in a list of dictionaries containing image data"""
-
-    def apply(self, img: ndarray) -> ndarray:
-        """Applies normalization to a single image."""
-        old_type = img.dtype
-        # XXX: If `old_type` has higher precision than float32,
-        # we will lose some precision.
-        img = img.astype("float32", copy=False)
-        img *= self.scale
-        img -= self.mean
-        img /= self.std
-        if self.preserve_dtype:
-            img = img.astype(old_type, copy=False)
-        return img
-
     def __call__(self, datas: List[dict]) -> List[dict]:
         """Normalizes images in a list of dictionaries. Iterates over each dictionary,
         applies normalization to the 'img' key, and returns the modified list.
         """
         for data in datas:
-            data["img"] = self.apply(data["img"])
+            data["img"] = self.norm(data["img"])
         return datas
 
 

+ 18 - 26
paddlex/inference/models/table_structure_recognition/processors.py

@@ -123,16 +123,8 @@ class TableLabelDecode:
 
     def __call__(self, pred, img_size, ori_img_size):
         """apply"""
-        bbox_preds, structure_probs = [], []
-
-        for i in range(len(pred[0][0])):
-            bbox_preds.append(pred[0][0][i])
-            structure_probs.append(pred[1][0][i])
-        bbox_preds = [bbox_preds]
-        structure_probs = [structure_probs]
-
-        bbox_preds = np.array(bbox_preds)
-        structure_probs = np.array(structure_probs)
+        bbox_preds = np.array([list(pred[0][0])])
+        structure_probs = np.array([list(pred[1][0])])
 
         bbox_list, structure_str_list, structure_score = self.decode(
             structure_probs, bbox_preds, img_size, ori_img_size
@@ -161,9 +153,11 @@ class TableLabelDecode:
         structure_batch_list = []
         bbox_batch_list = []
         batch_size = len(structure_idx)
+        bbox_list = []
+        scale_list = []
+        scales = [0] * 8
         for batch_idx in range(batch_size):
             structure_list = []
-            bbox_list = []
             score_list = []
             for idx in range(len(structure_idx[batch_idx])):
                 char_idx = int(structure_idx[batch_idx][idx])
@@ -174,15 +168,21 @@ class TableLabelDecode:
                 text = self.character[char_idx]
                 if text in self.td_token:
                     bbox = bbox_preds[batch_idx, idx]
-                    bbox = self._bbox_decode(
-                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
+                    h_scale, w_scale = self._get_bbox_scales(
+                        padding_size[batch_idx], ori_img_size[batch_idx]
                     )
-                    bbox_list.append(bbox.astype(int))
+                    scales[0::2] = [h_scale] * 4
+                    scales[1::2] = [w_scale] * 4
+                    bbox_list.append(bbox)
+                    scale_list.append(scales)
+
                 structure_list.append(text)
                 score_list.append(structure_probs[batch_idx, idx])
             structure_batch_list.append(structure_list)
             structure_score = np.mean(score_list)
-            bbox_batch_list.append(bbox_list)
+
+        bbox_batch_array = np.multiply(np.array(bbox_list), np.array(scale_list))
+        bbox_batch_list = [bbox_batch_array.astype(int).tolist()]
 
         return bbox_batch_list, structure_batch_list, structure_score
 
@@ -216,22 +216,14 @@ class TableLabelDecode:
             bbox_batch_list.append(bbox_list)
         return bbox_batch_list, structure_batch_list
 
-    def _bbox_decode(self, bbox, padding_shape, ori_shape):
-
+    def _get_bbox_scales(self, padding_shape, ori_shape):
         if self.model_name == "SLANet":
             w, h = ori_shape
-            bbox[0::2] *= w
-            bbox[1::2] *= h
+            return w, h
         else:
             w, h = padding_shape
             ori_w, ori_h = ori_shape
             ratio_w = w / ori_w
             ratio_h = h / ori_h
             ratio = min(ratio_w, ratio_h)
-
-            bbox[0::2] *= w
-            bbox[1::2] *= h
-            bbox[0::2] /= ratio
-            bbox[1::2] /= ratio
-
-        return bbox
+            return w / ratio, h / ratio

+ 1 - 4
paddlex/inference/models/text_detection/predictor.py

@@ -152,11 +152,8 @@ class TextDetPredictor(BasePredictor):
         std=[0.229, 0.224, 0.225],
         scale=1 / 255,
         order="",
-        channel_num=3,
     ):
-        return "Normalize", NormalizeImage(
-            mean=mean, std=std, scale=scale, order=order, channel_num=channel_num
-        )
+        return "Normalize", NormalizeImage(mean=mean, std=std, scale=scale, order=order)
 
     @register("ToCHWImage")
     def build_to_chw(self):

+ 41 - 24
paddlex/inference/models/text_detection/processors.py

@@ -19,7 +19,6 @@ from typing import Union
 import cv2
 import numpy as np
 import pyclipper
-from shapely.geometry import Polygon
 
 from ....utils import logging
 from ...utils.benchmark import benchmark
@@ -197,25 +196,39 @@ class DetResizeForTest:
 class NormalizeImage:
     """normalize image such as substract mean, divide std"""
 
-    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+    def __init__(self, scale=None, mean=None, std=None, order="chw"):
         super().__init__()
         if isinstance(scale, str):
             scale = eval(scale)
-        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        self.order = order
+
+        scale = scale if scale is not None else 1.0 / 255.0
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]
 
-        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
-        self.mean = np.array(mean).reshape(shape).astype("float32")
-        self.std = np.array(std).reshape(shape).astype("float32")
+        self.alpha = [scale / std[i] for i in range(len(std))]
+        self.beta = [-mean[i] / std[i] for i in range(len(std))]
 
     def __call__(self, imgs):
         """apply"""
 
-        def norm(img):
-            return (img.astype("float32") * self.scale - self.mean) / self.std
+        def _norm(img):
+            if self.order == "chw":
+                img = np.transpose(img, (2, 0, 1))
+
+            split_im = list(cv2.split(img))
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= self.alpha[c]
+                split_im[c] += self.beta[c]
+
+            res = cv2.merge(split_im)
 
-        return [norm(img) for img in imgs]
+            if self.order == "chw":
+                res = np.transpose(res, (1, 2, 0))
+            return res
+
+        return [_norm(img) for img in imgs]
 
 
 @benchmark.timeit
@@ -262,7 +275,8 @@ class DBPostProcess:
 
         bitmap = _bitmap
         height, width = bitmap.shape
-
+        width_scale = dest_width / width
+        height_scale = dest_height / height
         boxes = []
         scores = []
 
@@ -297,10 +311,10 @@ class DBPostProcess:
                 continue
 
             box = np.array(box)
-            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
-            box[:, 1] = np.clip(
-                np.round(box[:, 1] / height * dest_height), 0, dest_height
-            )
+            for i in range(box.shape[0]):
+                box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
+                box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
+
             boxes.append(box)
             scores.append(score)
         return boxes, scores
@@ -318,6 +332,8 @@ class DBPostProcess:
 
         bitmap = _bitmap
         height, width = bitmap.shape
+        width_scale = dest_width / width
+        height_scale = dest_height / height
 
         outs = cv2.findContours(
             (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
@@ -348,20 +364,21 @@ class DBPostProcess:
             box, sside = self.get_mini_boxes(box)
             if sside < self.min_size + 2:
                 continue
+
             box = np.array(box)
+            for i in range(box.shape[0]):
+                box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
+                box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
 
-            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
-            box[:, 1] = np.clip(
-                np.round(box[:, 1] / height * dest_height), 0, dest_height
-            )
             boxes.append(box.astype(np.int16))
             scores.append(score)
         return np.array(boxes, dtype=np.int16), scores
 
     def unclip(self, box, unclip_ratio):
         """unclip"""
-        poly = Polygon(box)
-        distance = poly.area * unclip_ratio / poly.length
+        area = cv2.contourArea(box)
+        length = cv2.arcLength(box, True)
+        distance = area * unclip_ratio / length
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         try:
@@ -396,10 +413,10 @@ class DBPostProcess:
         """box_score_fast: use bbox mean score as the mean score"""
         h, w = bitmap.shape[:2]
         box = _box.copy()
-        xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
-        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
-        ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
-        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
+        xmin = max(0, min(math.floor(box[:, 0].min()), w - 1))
+        xmax = max(0, min(math.ceil(box[:, 0].max()), w - 1))
+        ymin = max(0, min(math.floor(box[:, 1].min()), h - 1))
+        ymax = max(0, min(math.ceil(box[:, 1].max()), h - 1))
 
         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         box[:, 0] = box[:, 0] - xmin

+ 1 - 0
paddlex/inference/pipelines/components/__init__.py

@@ -21,6 +21,7 @@ from .common import (
     SortPolyBoxes,
     SortQuadBoxes,
     convert_points_to_boxes,
+    rotate_image,
 )
 from .faisser import FaissBuilder, FaissIndexer, IndexData
 from .prompt_engineering.base import BaseGeneratePrompt

+ 1 - 0
paddlex/inference/pipelines/components/common/__init__.py

@@ -16,3 +16,4 @@ from .base_result import BaseResult, CVResult
 from .convert_points_and_boxes import convert_points_to_boxes
 from .crop_image_regions import CropByBoxes, CropByPolys
 from .sort_boxes import SortPolyBoxes, SortQuadBoxes
+from .warp_image import rotate_image

+ 45 - 0
paddlex/inference/pipelines/components/common/warp_image.py

@@ -0,0 +1,45 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def rotate_image(image, angle):
+    if angle < 0 or angle >= 360:
+        raise ValueError("`angle` should be in range [0, 360)")
+
+    if angle < 1e-7:
+        return image
+
+    # Should we align corners?
+    h, w = image.shape[:2]
+    center = (w / 2, h / 2)
+    scale = 1.0
+    mat = cv2.getRotationMatrix2D(center, angle, scale)
+    cos = np.abs(mat[0, 0])
+    sin = np.abs(mat[0, 1])
+    new_w = int((h * sin) + (w * cos))
+    new_h = int((h * cos) + (w * sin))
+    mat[0, 2] += (new_w - w) / 2
+    mat[1, 2] += (new_h - h) / 2
+    dst_size = (new_w, new_h)
+
+    rotated = cv2.warpAffine(
+        image,
+        mat,
+        dst_size,
+        flags=cv2.INTER_CUBIC,
+    )
+    return rotated

+ 2 - 21
paddlex/inference/pipelines/doc_preprocessor/pipeline.py

@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
-from scipy.ndimage import rotate
 
 from ....utils import logging
 from ...common.batch_sampler import ImageBatchSampler
@@ -23,6 +22,7 @@ from ...common.reader import ReadImage
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
+from ..components import rotate_image
 from .result import DocPreprocessorResult
 
 
@@ -77,25 +77,6 @@ class DocPreprocessorPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
-    def rotate_image(self, image_array: np.ndarray, rotate_angle: float) -> np.ndarray:
-        """
-        Rotate the given image array by the specified angle.
-
-        Args:
-            image_array (np.ndarray): The input image array to be rotated.
-            rotate_angle (float): The angle in degrees by which to rotate the image.
-
-        Returns:
-            np.ndarray: The rotated image array.
-
-        Raises:
-            AssertionError: If rotate_angle is not in the range [0, 360).
-        """
-        assert (
-            rotate_angle >= 0 and rotate_angle < 360
-        ), "rotate_angle must in [0-360), but get {rotate_angle}."
-        return rotate(image_array, rotate_angle, reshape=True)
-
     def check_model_settings_valid(self, model_settings: Dict) -> bool:
         """
         Check if the the input params for model settings are valid based on the initialized models.
@@ -178,7 +159,7 @@ class DocPreprocessorPipeline(BasePipeline):
             if model_settings["use_doc_orientation_classify"]:
                 pred = next(self.doc_ori_classify_model(image_array))
                 angle = int(pred["label_names"][0])
-                rot_img = self.rotate_image(image_array, angle)
+                rot_img = rotate_image(image_array, angle)
             else:
                 angle = -1
                 rot_img = image_array

+ 2 - 2
paddlex/inference/pipelines/ocr/pipeline.py

@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
-from scipy.ndimage import rotate
 
 from ....utils import logging
 from ...common.batch_sampler import ImageBatchSampler
@@ -28,6 +27,7 @@ from ..components import (
     SortPolyBoxes,
     SortQuadBoxes,
     convert_points_to_boxes,
+    rotate_image,
 )
 from .result import OCRResult
 
@@ -163,7 +163,7 @@ class OCRPipeline(BasePipeline):
         for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
             # Convert 0/1 indicator to actual rotation angle
             rotate_angle = rotate_indicator * 180
-            rotated_image = rotate(image_array, rotate_angle, reshape=True)
+            rotated_image = rotate_image(image_array, rotate_angle)
             rotated_images.append(rotated_image)
 
         return rotated_images

+ 32 - 3
paddlex/inference/utils/benchmark.py

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import csv
 import functools
 import inspect
 import time
+import uuid
 from pathlib import Path
 from types import GeneratorType
 
@@ -34,6 +36,8 @@ ENTRY_POINT_NAME = "_entry_point_"
 # XXX: Global mutable state
 _inference_operations = []
 
+_is_measuring_time = False
+
 
 class Benchmark:
     def __init__(self, enabled):
@@ -65,7 +69,7 @@ class Benchmark:
                 source_line = inspect.getsourcelines(func)[1]
                 location = f"{source_file}:{source_line}"
             except (TypeError, OSError) as e:
-                location = "Unknown"
+                location = uuid.uuid4().hex
                 logging.debug(
                     f"Benchmark: failed to get source file and line number: {e}"
                 )
@@ -89,15 +93,27 @@ class Benchmark:
                         for k, v in kwargs.items()
                     }
                     output = func(*args, **kwargs)
+                    output = copy.deepcopy(output)
                     return output
 
             else:
 
                 @functools.wraps(func)
                 def _wrapper(*args, **kwargs):
+                    global _is_measuring_time
                     operation_name = f"{name}@{location}"
+                    if _is_measuring_time:
+                        raise RuntimeError(
+                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
+                        )
+                    if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
+                        _is_measuring_time = True
                     tic = time.perf_counter()
-                    output = func(*args, **kwargs)
+                    try:
+                        output = func(*args, **kwargs)
+                    finally:
+                        if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
+                            _is_measuring_time = False
                     if isinstance(output, GeneratorType):
                         return self.watch_generator(output, operation_name)
                     else:
@@ -118,10 +134,21 @@ class Benchmark:
     def watch_generator(self, generator, name):
         @functools.wraps(generator)
         def wrapper():
+            global _is_measuring_time
             while True:
                 try:
+                    if _is_measuring_time:
+                        raise RuntimeError(
+                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
+                        )
+                    if not name.startswith(f"{ENTRY_POINT_NAME}@"):
+                        _is_measuring_time = True
                     tic = time.perf_counter()
-                    item = next(generator)
+                    try:
+                        item = next(generator)
+                    finally:
+                        if not name.startswith(f"{ENTRY_POINT_NAME}@"):
+                            _is_measuring_time = False
                     self._update(time.perf_counter() - tic, name)
                     yield item
                 except StopIteration:
@@ -187,6 +214,8 @@ class Benchmark:
             avg = np.mean(time_list)
             operation_name = name.split("@")[0]
             location = name.split("@")[1]
+            if ":" not in location:
+                location = "Unknown"
             detail_list.append(
                 (iters, batch_size, instances, operation_name, avg, avg / batch_size)
             )