Seal text det (#2043)

* add seal det

* add seal det

* add seal det

* add seal det

* add seal det

* Update PP-OCRv4_mobile_seal_det.yaml

* Update PP-OCRv4_server_seal_det.yaml

* Update support_model_list.md

* Update ocr.py

* Update text_det.py

* Update transforms.py

* Update text_det.py

* Update transforms.py

* Update transforms.py

* Update text_det.py

* ocr pipeline
Sunflower7788 1 year ago
Parent
Commit
984425190e

+ 6 - 2
README.md

@@ -117,12 +117,16 @@ PaddleX 3.0 覆盖了 16 条产业级模型产线,其中 9 条基础产线可
    <summary><b>more</b></summary><br/>Mask-RT-DETR-L<br/>Mask-RT-DETR-X<br/>Mask-RT-DETR-H<br/>SOLOv2<br/>MaskRCNN-ResNet50<br/>MaskRCNN-ResNet50-FPN<br/>MaskRCNN-ResNet50-vd-FPN<br/>MaskRCNN-ResNet50-vd-SSLDv2-FPN<br/>MaskRCNN-ResNet101-FPN<br/>MaskRCNN-ResNet101-vd-FPN<br/>MaskRCNN-ResNeXt101-vd-FPN<br/>Cascade-MaskRCNN-ResNet50-FPN<br/>Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN<br/>PP-YOLOE_seg-S</td>
  </tr>
  <tr>
-    <td rowspan="3">基础产线</td>
-    <td rowspan="3">通用OCR</td>
+    <td rowspan="4">基础产线</td>
+    <td rowspan="4">通用OCR</td>
    <td>文本检测</td>
    <td>PP-OCRv4_mobile_det<br/>PP-OCRv4_server_det</td>
  </tr>
  <tr>
+    <td>印章文本检测</td>
+    <td>PP-OCRv4_mobile_seal_det<br/>PP-OCRv4_server_seal_det</td>
+  </tr>
+  <tr>
    <td>文本识别</td>
    <td>PP-OCRv4_mobile_rec<br/>PP-OCRv4_server_rec</td>
  </tr>

+ 7 - 1
docs/tutorials/models/support_model_list.md

@@ -264,11 +264,17 @@
 | :--- | :---: |
 | SLANet | [SLANet.yaml](../../../paddlex/configs/table_recognition/SLANet.yaml)|
 ## 六、文本检测
-### 1.PP-OCRv4 系列
+### 1.PP-OCRv4常规文本检测系列
 | 模型名称 | config |
 | :--- | :---: |
 | PP-OCRv4_server_det | [PP-OCRv4_server_det.yaml](../../../paddlex/configs/text_detection/PP-OCRv4_server_det.yaml)|
 | PP-OCRv4_mobile_det | [PP-OCRv4_mobile_det.yaml](../../../paddlex/configs/text_detection/PP-OCRv4_mobile_det.yaml)|
+
+### 2.PP-OCRv4 印章文本检测系列
+| 模型名称 | config |
+| :--- | :---: |
+| PP-OCRv4_server_seal_det | [PP-OCRv4_server_seal_det.yaml](../../../paddlex/configs/text_detection_seal/PP-OCRv4_server_seal_det.yaml)|
+| PP-OCRv4_mobile_seal_det | [PP-OCRv4_mobile_seal_det.yaml](../../../paddlex/configs/text_detection_seal/PP-OCRv4_mobile_seal_det.yaml)|
 ## 七、文本识别
 ### 1.PP-OCRv4 系列
 | 模型名称 | config |

+ 40 - 0
paddlex/configs/text_detection_seal/PP-OCRv4_mobile_seal_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-OCRv4_mobile_seal_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_curve_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_mobile_det_curve_trained.pdparams
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/seal_text_det.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1
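
Both seal-det configs share this layout; a quick way to see how the Global section drives the check_dataset/train/evaluate/predict modes is to load the file directly. A minimal sketch, assuming PyYAML is installed and the path exists in your checkout:

```python
import yaml  # assumption: PyYAML is available

with open("paddlex/configs/text_detection_seal/PP-OCRv4_mobile_seal_det.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Global"]["model"])  # PP-OCRv4_mobile_seal_det
print(cfg["Global"]["mode"])   # check_dataset
print(cfg["Train"]["epochs_iters"], cfg["Train"]["learning_rate"])  # 100 0.001
```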

+ 40 - 0
paddlex/configs/text_detection_seal/PP-OCRv4_server_seal_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-OCRv4_server_seal_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_curve_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_det_curve_trained.pdparams
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/seal_text_det.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 699 - 0
paddlex/inference/components/task_related/seal_det_warp.py

@@ -0,0 +1,699 @@
+import os, sys
+import numpy as np
+from numpy import cos, sin, arctan, sqrt
+import cv2
+import copy
+import time
+
+
+def Homography(image, img_points, world_width, world_height,
+               interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+    """
+    将图像透视变换到新的视角,返回变换后的图像。
+    
+    Args:
+        image (np.ndarray): 输入的图像,应为numpy数组类型。
+        img_points (List[Tuple[int, int]]): 图像上的四个点的坐标,顺序为左上角、右上角、右下角、左下角。
+        world_width (int): 变换后图像在世界坐标系中的宽度。
+        world_height (int): 变换后图像在世界坐标系中的高度。
+        interpolation (int, optional): 插值方式,默认为cv2.INTER_CUBIC。
+        ratio_width (float, optional): 变换后图像在x轴上的缩放比例,默认为1.0。
+        ratio_height (float, optional): 变换后图像在y轴上的缩放比例,默认为1.0。
+    
+    Returns:
+        np.ndarray: 变换后的图像,为numpy数组类型。
+    
+    """
+    _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+    expand_x = int(0.5 * world_width * (ratio_width - 1))
+    expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+    pt_lefttop = [expand_x, expand_y]
+    pt_righttop = [expand_x + world_width, expand_y]
+    pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+    pt_rightbottom = [expand_x, expand_y + world_height]
+
+    pts_std = np.float32([pt_lefttop, pt_righttop,
+                          pt_leftbottom, pt_rightbottom])
+
+    img_crop_width = int(world_width * ratio_width)
+    img_crop_height = int(world_height * ratio_height)
+
+    M = cv2.getPerspectiveTransform(_points, pts_std)
+
+    dst_img = cv2.warpPerspective(
+        image,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+        flags=interpolation)
+
+    return dst_img
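
A toy usage sketch for the helper above; the image and quadrilateral are made up, and the point order follows the docstring (top-left, top-right, bottom-right, bottom-left):

```python
import cv2
import numpy as np

img = np.full((400, 600, 3), 255, dtype=np.uint8)        # synthetic white image
quad = [(100, 120), (300, 100), (310, 150), (110, 170)]  # TL, TR, BR, BL

crop = Homography(img, quad, world_width=200, world_height=50)
print(crop.shape)  # (50, 200, 3)
```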
+
+
+class CurveTextRectifier:
+    """
+    spatial transformer via monocular vision
+    """
+    def __init__(self):
+        self.get_virtual_camera_parameter()
+
+
+    def get_virtual_camera_parameter(self):
+        vcam_thz = 0
+        vcam_thx1 = 180
+        vcam_thy = 180
+        vcam_thx2 = 0
+
+        vcam_x = 0
+        vcam_y = 0
+        vcam_z = 100
+
+        radian = np.pi / 180
+
+        angle_z = radian * vcam_thz
+        angle_x1 = radian * vcam_thx1
+        angle_y = radian * vcam_thy
+        angle_x2 = radian * vcam_thx2
+
+        optic_x = vcam_x
+        optic_y = vcam_y
+        optic_z = vcam_z
+
+        fu = 100
+        fv = 100
+
+        matT = np.zeros((4, 4))
+        matT[0, 0] = cos(angle_z) * cos(angle_y) - sin(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[0, 1] = cos(angle_z) * sin(angle_y) * sin(angle_x2) - sin(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[0, 2] = cos(angle_z) * sin(angle_y) * cos(angle_x2) + sin(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[0, 3] = optic_x
+        matT[1, 0] = sin(angle_z) * cos(angle_y) + cos(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[1, 1] = sin(angle_z) * sin(angle_y) * sin(angle_x2) + cos(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[1, 2] = sin(angle_z) * sin(angle_y) * cos(angle_x2) - cos(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[1, 3] = optic_y
+        matT[2, 0] = -cos(angle_x1) * sin(angle_y)
+        matT[2, 1] = cos(angle_x1) * cos(angle_y) * sin(angle_x2) + sin(angle_x1) * cos(angle_x2)
+        matT[2, 2] = cos(angle_x1) * cos(angle_y) * cos(angle_x2) - sin(angle_x1) * sin(angle_x2)
+        matT[2, 3] = optic_z
+        matT[3, 0] = 0
+        matT[3, 1] = 0
+        matT[3, 2] = 0
+        matT[3, 3] = 1
+
+        matS = np.zeros((4, 4))
+        matS[2, 3] = 0.5
+        matS[3, 2] = 0.5
+
+        self.ifu = 1 / fu
+        self.ifv = 1 / fv
+
+        self.matT = matT
+        self.matS = matS
+        self.K = np.dot(matT.T, matS)
+        self.K = np.dot(self.K, matT)
+
+
+    def vertical_text_process(self, points, org_size):
+        """
+        reorder the point sequence and process
+        :param points:
+        :param org_size:
+        :return:
+        """
+        org_w, org_h = org_size
+        _points = np.array(points).reshape(-1).tolist()
+        _points = np.array(_points[2:] + _points[:2]).reshape(-1, 2)
+
+        # convert to horizontal points
+        adjusted_points = np.zeros(_points.shape, dtype=np.float32)
+        adjusted_points[:, 0] = _points[:, 1]
+        adjusted_points[:, 1] = org_h - _points[:, 0] - 1
+
+        _image_coord, _world_coord, _new_image_size = self.horizontal_text_process(adjusted_points)
+
+        # convert back to vertical points
+        image_coord = _points.reshape(1, -1, 2)
+        world_coord = np.zeros(_world_coord.shape, dtype=np.float32)
+        world_coord[:, :, 0] = 0 - _world_coord[:, :, 1]
+        world_coord[:, :, 1] = _world_coord[:, :, 0]
+        world_coord[:, :, 2] = _world_coord[:, :, 2]
+        new_image_size = (_new_image_size[1], _new_image_size[0])
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_process(self, points):
+        """
+        get image coordinate and world coordinate
+        :param points:
+        :return:
+        """
+        poly = np.array(points).reshape(-1)
+
+        dx_list = []
+        dy_list = []
+        for i in range(1, len(poly) // 2):
+            xdx = poly[i * 2] - poly[(i - 1) * 2]
+            xdy = poly[i * 2 + 1] - poly[(i - 1) * 2 + 1]
+            d = sqrt(xdx ** 2 + xdy ** 2)
+            dx_list.append(d)
+
+        for i in range(0, len(poly) // 4):
+            ydx = poly[i * 2] - poly[len(poly) - 1 - (i * 2 + 1)]
+            ydy = poly[i * 2 + 1] - poly[len(poly) - 1 - (i * 2)]
+            d = sqrt(ydx ** 2 + ydy ** 2)
+            dy_list.append(d)
+
+        dx_list = [(dx_list[i] + dx_list[len(dx_list) - 1 - i]) / 2 for i in range(len(dx_list) // 2)]
+
+        height = np.around(np.mean(dy_list))
+
+        rect_coord = [0, 0]
+        for i in range(0, len(poly) // 4 - 1):
+            x = rect_coord[-2]
+            x += dx_list[i]
+            y = 0
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        rect_coord_half = copy.deepcopy(rect_coord)
+        for i in range(0, len(poly) // 4):
+            x = rect_coord_half[len(rect_coord_half) - 2 * i - 2]
+            y = height
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        np_rect_coord = np.array(rect_coord).reshape(-1, 2)
+        x_min = np.min(np_rect_coord[:, 0])
+        y_min = np.min(np_rect_coord[:, 1])
+        x_max = np.max(np_rect_coord[:, 0])
+        y_max = np.max(np_rect_coord[:, 1])
+        new_image_size = (int(x_max - x_min + 0.5), int(y_max - y_min + 0.5))
+        x_mean = (x_max - x_min) / 2
+        y_mean = (y_max - y_min) / 2
+        np_rect_coord[:, 0] -= x_mean
+        np_rect_coord[:, 1] -= y_mean
+        rect_coord = np_rect_coord.reshape(-1).tolist()
+
+        rect_coord = np.array(rect_coord).reshape(-1, 2)
+        world_coord = np.ones((len(rect_coord), 3)) * 0
+
+        world_coord[:, :2] = rect_coord
+
+        image_coord = np.array(poly).reshape(1, -1, 2)
+        world_coord = world_coord.reshape(1, -1, 3)
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_estimate(self, points):
+        """
+        horizontal or vertical text
+        :param points:
+        :return:
+        """
+        pts = np.array(points).reshape(-1, 2)
+        x_min = int(np.min(pts[:, 0]))
+        y_min = int(np.min(pts[:, 1]))
+        x_max = int(np.max(pts[:, 0]))
+        y_max = int(np.max(pts[:, 1]))
+        x = x_max - x_min
+        y = y_max - y_min
+        is_horizontal_text = True
+        if y / x > 1.5: # vertical text condition
+            is_horizontal_text = False
+        return is_horizontal_text
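
A quick check of this heuristic: a polygon counts as vertical text once its bounding box is more than 1.5x taller than it is wide (sample polygons made up):

```python
rectifier = CurveTextRectifier()
wide = [10, 10, 110, 10, 110, 40, 10, 40]   # 100 wide, 30 tall
tall = [10, 10, 40, 10, 40, 110, 10, 110]   # 30 wide, 100 tall
print(rectifier.horizontal_text_estimate(wide))  # True
print(rectifier.horizontal_text_estimate(tall))  # False
```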
+
+
+    def virtual_camera_to_world(self, size):
+        ifu, ifv = self.ifu, self.ifv
+        K, matT = self.K, self.matT
+
+        ppu = size[0] / 2 + 1e-6
+        ppv = size[1] / 2 + 1e-6
+
+        P = np.zeros((size[1], size[0], 3))
+
+        lu = np.array([i for i in range(size[0])])
+        lv = np.array([i for i in range(size[1])])
+        u, v = np.meshgrid(lu, lv)
+
+        yp = (v - ppv) * ifv
+        xp = (u - ppu) * ifu
+        angle_a = arctan(sqrt(xp * xp + yp * yp))
+        angle_b = arctan(yp / xp)
+
+        D0 = sin(angle_a) * cos(angle_b)
+        D1 = sin(angle_a) * sin(angle_b)
+        D2 = cos(angle_a)
+
+        D0[xp <= 0] = -D0[xp <= 0]
+        D1[xp <= 0] = -D1[xp <= 0]
+
+        ratio_a = K[0, 0] * D0 * D0 + K[1, 1] * D1 * D1 + K[2, 2] * D2 * D2 + \
+                  (K[0, 1] + K[1, 0]) * D0 * D1 + (K[0, 2] + K[2, 0]) * D0 * D2 + (K[1, 2] + K[2, 1]) * D1 * D2
+        ratio_b = (K[0, 3] + K[3, 0]) * D0 + (K[1, 3] + K[3, 1]) * D1 + (K[2, 3] + K[3, 2]) * D2
+        ratio_c = K[3, 3] * np.ones(ratio_b.shape)
+
+        delta = ratio_b * ratio_b - 4 * ratio_a * ratio_c
+        t = np.zeros(delta.shape)
+        t[ratio_a == 0] = -ratio_c[ratio_a == 0] / ratio_b[ratio_a == 0]
+        t[ratio_a != 0] = (-ratio_b[ratio_a != 0] + sqrt(delta[ratio_a != 0])) / (2 * ratio_a[ratio_a != 0])
+        t[delta < 0] = 0
+
+        P[:, :, 0] = matT[0, 3] + t * (matT[0, 0] * D0 + matT[0, 1] * D1 + matT[0, 2] * D2)
+        P[:, :, 1] = matT[1, 3] + t * (matT[1, 0] * D0 + matT[1, 1] * D1 + matT[1, 2] * D2)
+        P[:, :, 2] = matT[2, 3] + t * (matT[2, 0] * D0 + matT[2, 1] * D1 + matT[2, 2] * D2)
+
+        return P
+
+
+    def world_to_image(self, image_size, world, intrinsic, distCoeffs, rotation, tvec):
+        r11 = rotation[0, 0]
+        r12 = rotation[0, 1]
+        r13 = rotation[0, 2]
+        r21 = rotation[1, 0]
+        r22 = rotation[1, 1]
+        r23 = rotation[1, 2]
+        r31 = rotation[2, 0]
+        r32 = rotation[2, 1]
+        r33 = rotation[2, 2]
+
+        t1 = tvec[0]
+        t2 = tvec[1]
+        t3 = tvec[2]
+
+        k1 = distCoeffs[0]
+        k2 = distCoeffs[1]
+        p1 = distCoeffs[2]
+        p2 = distCoeffs[3]
+        k3 = distCoeffs[4]
+        k4 = distCoeffs[5]
+        k5 = distCoeffs[6]
+        k6 = distCoeffs[7]
+
+        if len(distCoeffs) > 8:
+            s1 = distCoeffs[8]
+            s2 = distCoeffs[9]
+            s3 = distCoeffs[10]
+            s4 = distCoeffs[11]
+        else:
+            s1 = s2 = s3 = s4 = 0
+
+        if len(distCoeffs) > 12:
+            tx = distCoeffs[12]
+            ty = distCoeffs[13]
+        else:
+            tx = ty = 0
+
+        fu = intrinsic[0, 0]
+        fv = intrinsic[1, 1]
+        ppu = intrinsic[0, 2]
+        ppv = intrinsic[1, 2]
+
+        cos_tx = cos(tx)
+        cos_ty = cos(ty)
+        sin_tx = sin(tx)
+        sin_ty = sin(ty)
+
+        tao11 = cos_ty * cos_tx * cos_ty + sin_ty * cos_tx * sin_ty
+        tao12 = cos_ty * cos_tx * sin_ty * sin_tx - sin_ty * cos_tx * cos_ty * sin_tx
+        tao13 = -cos_ty * cos_tx * sin_ty * cos_tx + sin_ty * cos_tx * cos_ty * cos_tx
+        tao21 = -sin_tx * sin_ty
+        tao22 = cos_ty * cos_tx * cos_tx + sin_tx * cos_ty * sin_tx
+        tao23 = cos_ty * cos_tx * sin_tx - sin_tx * cos_ty * cos_tx
+
+        P = np.zeros((image_size[1], image_size[0], 2))
+
+        c3 = r31 * world[:, :, 0] + r32 * world[:, :, 1] + r33 * world[:, :, 2] + t3
+        c1 = r11 * world[:, :, 0] + r12 * world[:, :, 1] + r13 * world[:, :, 2] + t1
+        c2 = r21 * world[:, :, 0] + r22 * world[:, :, 1] + r23 * world[:, :, 2] + t2
+
+        x1 = c1 / c3
+        y1 = c2 / c3
+        x12 = x1 * x1
+        y12 = y1 * y1
+        x1y1 = 2 * x1 * y1
+        r2 = x12 + y12
+        r4 = r2 * r2
+        r6 = r2 * r4
+
+        radial_distortion = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (1 + k4 * r2 + k5 * r4 + k6 * r6)
+        x2 = x1 * radial_distortion + p1 * x1y1 + p2 * (r2 + 2 * x12) + s1 * r2 + s2 * r4
+        y2 = y1 * radial_distortion + p2 * x1y1 + p1 * (r2 + 2 * y12) + s3 * r2 + s4 * r4
+
+        x3 = tao11 * x2 + tao12 * y2 + tao13
+        y3 = tao21 * x2 + tao22 * y2 + tao23
+
+        P[:, :, 0] = fu * x3 + ppu
+        P[:, :, 1] = fv * y3 + ppv
+        P[c3 <= 0] = 0
+
+        return P
+
+
+    def spatial_transform(self, image_data, new_image_size, mtx, dist, rvecs, tvecs, interpolation):
+        rotation, _ = cv2.Rodrigues(rvecs)
+        world_map = self.virtual_camera_to_world(new_image_size)
+        image_map = self.world_to_image(new_image_size, world_map, mtx, dist, rotation, tvecs)
+        image_map = image_map.astype(np.float32)
+        dst = cv2.remap(image_data, image_map[:, :, 0], image_map[:, :, 1], interpolation)
+        return dst
+
+
+    def calibrate(self, org_size, image_coord, world_coord):
+        """
+        calibration
+        :param org_size:
+        :param image_coord:
+        :param world_coord:
+        :return:
+        """
+        # flag = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL  | cv2.CALIB_THIN_PRISM_MODEL
+        flag = cv2.CALIB_RATIONAL_MODEL
+        flag2 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL
+        flag3 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_THIN_PRISM_MODEL
+        flag4 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_ZERO_TANGENT_DIST | cv2.CALIB_FIX_ASPECT_RATIO
+        flag5 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL | cv2.CALIB_ZERO_TANGENT_DIST
+        flag6 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_FIX_ASPECT_RATIO
+        flag_list = [flag2, flag3, flag4, flag5, flag6]
+
+        ret, mtx, dist, rvecs, tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                image_coord.astype(np.float32),
+                                                                org_size,
+                                                                None,
+                                                                None,
+                                                                flags=flag)
+        if ret > 2:
+            # strategies
+            min_ret = ret
+            for i, flag in enumerate(flag_list):
+                _ret, _mtx, _dist, _rvecs, _tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                   image_coord.astype(np.float32),
+                                                                   org_size,
+                                                                   None,
+                                                                   None,
+                                                                   flags=flag)
+                if _ret < min_ret:
+                    min_ret = _ret
+                    ret, mtx, dist, rvecs, tvecs = _ret, _mtx, _dist, _rvecs, _tvecs
+
+        return ret, mtx, dist, rvecs, tvecs
+
+
+    def dc_homo(self, img, img_points, obj_points, is_horizontal_text, interpolation=cv2.INTER_LINEAR,
+                ratio_width=1.0, ratio_height=1.0):
+        """
+        divide and conquer: homography
+        # ratio_width and ratio_height must be 1.0 here
+        """
+        _img_points = img_points.reshape(-1, 2)
+        _obj_points = obj_points.reshape(-1, 3)
+
+        homo_img_list = []
+        width_list = []
+        height_list = []
+        # divide and conquer
+        for i in range(len(_img_points) // 2 - 1):
+            new_img_points = np.zeros((4, 2)).astype(np.float32)
+            new_obj_points = np.zeros((4, 2)).astype(np.float32)
+
+            new_img_points[0:2, :] = _img_points[i:(i + 2), :2]
+            new_img_points[2:4, :] = _img_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            new_obj_points[0:2, :] = _obj_points[i:(i + 2), :2]
+            new_obj_points[2:4, :] = _obj_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            if is_horizontal_text:
+                world_width = np.abs(new_obj_points[1, 0] - new_obj_points[0, 0])
+                world_height = np.abs(new_obj_points[3, 1] - new_obj_points[0, 1])
+            else:
+                world_width = np.abs(new_obj_points[1, 1] - new_obj_points[0, 1])
+                world_height = np.abs(new_obj_points[3, 0] - new_obj_points[0, 0])
+
+            homo_img = Homography(img, new_img_points, world_width, world_height,
+                                              interpolation=interpolation,
+                                              ratio_width=ratio_width, ratio_height=ratio_height)
+
+            homo_img_list.append(homo_img)
+            _h, _w = homo_img.shape[:2]
+            width_list.append(_w)
+            height_list.append(_h)
+
+        # stitching
+        rectified_image = np.zeros((np.max(height_list), sum(width_list), 3)).astype(np.uint8)
+
+        st = 0
+        for (homo_img, w, h) in zip(homo_img_list, width_list, height_list):
+            rectified_image[:h, st:st + w, :] = homo_img
+            st += w
+
+        if not is_horizontal_text:
+            # vertical rotation
+            rectified_image = np.rot90(rectified_image, 3)
+
+        return rectified_image
+
+    def Homography(self, image, img_points, world_width, world_height,
+                interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        将图像透视变换到新的视角,返回变换后的图像。
+        
+        Args:
+            image (np.ndarray): 输入的图像,应为numpy数组类型。
+            img_points (List[Tuple[int, int]]): 图像上的四个点的坐标,顺序为左上角、右上角、右下角、左下角。
+            world_width (int): 变换后图像在世界坐标系中的宽度。
+            world_height (int): 变换后图像在世界坐标系中的高度。
+            interpolation (int, optional): 插值方式,默认为cv2.INTER_CUBIC。
+            ratio_width (float, optional): 变换后图像在x轴上的缩放比例,默认为1.0。
+            ratio_height (float, optional): 变换后图像在y轴上的缩放比例,默认为1.0。
+        
+        Returns:
+            np.ndarray: 变换后的图像,为numpy数组类型。
+        
+        """
+        _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+        expand_x = int(0.5 * world_width * (ratio_width - 1))
+        expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+        pt_lefttop = [expand_x, expand_y]
+        pt_righttop = [expand_x + world_width, expand_y]
+        pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+        pt_rightbottom = [expand_x, expand_y + world_height]
+
+        pts_std = np.float32([pt_lefttop, pt_righttop,
+                            pt_leftbottom, pt_rightbottom])
+
+        img_crop_width = int(world_width * ratio_width)
+        img_crop_height = int(world_height * ratio_height)
+
+        M = cv2.getPerspectiveTransform(_points, pts_std)
+
+        dst_img = cv2.warpPerspective(
+            image,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+            flags=interpolation)
+
+        return dst_img
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR, ratio_width=1.0, ratio_height=1.0, mode='calibration'):
+        """
+        spatial transform for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        org_h, org_w = image_data.shape[:2]
+        org_size = (org_w, org_h)
+        self.image = image_data
+
+        is_horizontal_text = self.horizontal_text_estimate(points)
+        if is_horizontal_text:
+            image_coord, world_coord, new_image_size = self.horizontal_text_process(points)
+        else:
+            image_coord, world_coord, new_image_size = self.vertical_text_process(points, org_size)
+
+        if mode.lower() == 'calibration':
+            ret, mtx, dist, rvecs, tvecs = self.calibrate(org_size, image_coord, world_coord)
+
+            st_size = (int(new_image_size[0]*ratio_width), int(new_image_size[1]*ratio_height))
+            dst = self.spatial_transform(image_data, st_size, mtx, dist[0], rvecs[0], tvecs[0], interpolation)
+        elif mode.lower() == 'homography':
+            # ratio_width and ratio_height must be 1.0 here; ret is set to 0.01 manually since homography reports no calibration loss
+            ret = 0.01
+            dst = self.dc_homo(image_data, image_coord, world_coord, is_horizontal_text,
+                               interpolation=interpolation, ratio_width=1.0, ratio_height=1.0)
+        else:
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        return dst, ret
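
A minimal sketch of rectifying one curved polygon with the class above; the image is synthetic, and 'homography' mode sidesteps camera calibration and reports the fixed 0.01 loss from the branch above:

```python
import cv2
import numpy as np

image = np.zeros((200, 300, 3), dtype=np.uint8)
# 6-point curved box: top sideline then bottom sideline, clockwise
poly = [30, 60, 90, 40, 150, 60, 150, 100, 90, 80, 30, 100]

rectifier = CurveTextRectifier()
dst, ret = rectifier(image, poly, interpolation=cv2.INTER_LINEAR, mode='homography')
print(dst.shape, ret)  # rectified crop, 0.01
```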
+
+
+class AutoRectifier:
+    def __init__(self):
+        self.npoints = 10
+        self.curveTextRectifier = CurveTextRectifier()
+
+    @staticmethod
+    def get_rotate_crop_image(img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        crop or homography
+        :param img:
+        :param points:
+        :param interpolation:
+        :param ratio_width:
+        :param ratio_height:
+        :return:
+        """
+        h, w = img.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+
+        if len(_points) != 4:
+            x_min = int(np.min(_points[:, 0]))
+            y_min = int(np.min(_points[:, 1]))
+            x_max = int(np.max(_points[:, 0]))
+            y_max = int(np.max(_points[:, 1]))
+            dx = x_max - x_min
+            dy = y_max - y_min
+            expand_x = int(0.5 * dx * (ratio_width - 1))
+            expand_y = int(0.5 * dy * (ratio_height - 1))
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+            dst_img = img[y_min:y_max, x_min:x_max, :].copy()
+        else:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[1]),
+                    np.linalg.norm(_points[2] - _points[3])))
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[3]),
+                    np.linalg.norm(_points[1] - _points[2])))
+
+            dst_img = Homography(img, _points, img_crop_width, img_crop_height, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def visualize(self, image_data, points_list):
+        visualization = image_data.copy()
+
+        for box in points_list:
+            box = np.array(box).reshape(-1, 2).astype(np.int32)
+            cv2.drawContours(visualization, [np.array(box).reshape((-1, 1, 2))], -1, (0, 0, 255), 2)
+            for i, p in enumerate(box):
+                if i != 0:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 0, 0), thickness=2)
+                else:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 255, 0), thickness=2)
+        return visualization
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR,
+                 ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        rectification in strategies for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        _points = np.array(points).reshape(-1,2)
+        if len(_points) >= self.npoints and len(_points) % 2 == 0:
+            try:
+                curveTextRectifier = CurveTextRectifier()
+
+                dst_img, loss = curveTextRectifier(image_data, points, interpolation, ratio_width, ratio_height, mode)
+                if loss >= 2:
+                    # for robustness:
+                    # a large loss means the reconstruction is unreliable, so try alternative strategies
+                    img_list, loss_list = [dst_img], [loss]
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh,
+                                              square=True)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh, square=False)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    min_loss = min(loss_list)
+                    dst_img = img_list[loss_list.index(min_loss)]
+
+                    if min_loss >= loss_thresh:
+                        print('calibration loss: {} is too large for the spatial transformer; falling back to get_rotate_crop_image'.format(loss))
+                        dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+            except Exception as e:
+                print(e)
+                dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+        else:
+            dst_img = self.get_rotate_crop_image(image_data, _points, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def run(self, image_data, points_list, interpolation=cv2.INTER_LINEAR,
+            ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        run for texts in an image
+        :param image_data: numpy.ndarray. The shape is [h, w, 3]
+        :param points_list: [[x1,y1,x2,y2,x3,y3,...], [x1,y1,x2,y2,x3,y3,...], ...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return: res: roi-image list, visualized_image: draw polys in original image
+        """
+        if image_data is None:
+            raise ValueError('image_data cannot be None')
+        if not isinstance(points_list, list):
+            raise ValueError('points_list must be a list')
+        for points in points_list:
+            if not isinstance(points, list):
+                raise ValueError('each element of points_list must be a list')
+
+        if ratio_width < 1.0 or ratio_height < 1.0:
+            raise ValueError('ratio_width and ratio_height cannot be smaller than 1, but got {}'.format((ratio_width, ratio_height)))
+
+        if mode.lower() != 'calibration' and mode.lower() != 'homography':
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        if mode.lower() == 'homography' and (ratio_width != 1.0 or ratio_height != 1.0):
+            raise ValueError('ratio_width and ratio_height must be 1.0 when mode is homography, but got mode:{}, ratio:({},{})'.format(mode, ratio_width, ratio_height))
+
+        res = []
+        for points in points_list:
+            rectified_img = self(image_data, points, interpolation, ratio_width, ratio_height,
+                                 loss_thresh=loss_thresh, mode=mode)
+            res.append(rectified_img)
+
+        # visualize
+        visualized_image = self.visualize(image_data, points_list)
+
+        return res, visualized_image
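
The batch entry point on the same synthetic data; a 6-point polygon falls below the npoints=10 threshold, so __call__ falls back to get_rotate_crop_image internally, but run still returns one crop per polygon plus a visualization:

```python
import numpy as np

image = np.zeros((200, 300, 3), dtype=np.uint8)
polys = [[30.0, 60.0, 90.0, 40.0, 150.0, 60.0, 150.0, 100.0, 90.0, 80.0, 30.0, 100.0]]

crops, vis = AutoRectifier().run(image, polys, mode='homography')
print(len(crops), vis.shape)  # 1 (200, 300, 3)
```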
+

+ 338 - 8
paddlex/inference/components/task_related/text_det.py

@@ -20,12 +20,14 @@ import copy
 import math
 import pyclipper
 import numpy as np
+from numpy.linalg import norm
 from PIL import Image
 from shapely.geometry import Polygon
 
 from ...utils.io import ImageReader
 from ....utils import logging
 from ..base import BaseComponent
+from .seal_det_warp import AutoRectifier
 
 
 __all__ = ["DetResizeForTest", "NormalizeImage", "DBPostProcess", "CropByPolys"]
@@ -430,19 +432,31 @@ class CropByPolys(BaseComponent):
     def apply(self, img_path, dt_polys):
         """apply"""
         img = self._reader.read(img_path)
-        dt_boxes = np.array(dt_polys)
+        
         # TODO
         # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
-        output_list = []
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.det_box_type == "quad":
-                img_crop = self.get_rotate_crop_image(img, tmp_box)
-            else:
+        if self.det_box_type == "quad":
+            dt_boxes = self.sorted_boxes(dt_polys)
+            dt_boxes = np.array(dt_boxes)
+            output_list = []
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
                 img_crop = self.get_minarea_rect_crop(img, tmp_box)
-            output_list.append(
+                output_list.append(
                 {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
                 {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
             )
             )
+        elif self.det_box_type == "poly":
+            output_list = []
+            dt_boxes = dt_polys
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
+                output_list.append(
+                {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
+                )
+        else:
+            raise NotImplementedError
+
         return output_list
 
     def sorted_boxes(self, dt_boxes):
@@ -537,3 +551,319 @@ class CropByPolys(BaseComponent):
         if dst_img_height * 1.0 / dst_img_width >= 1.5:
             dst_img = np.rot90(dst_img)
         return dst_img
+
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        orientation_thr = 2.0  # an empirical hyperparameter
+
+        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+        return head_edge, tail_edge, sideline1, sideline2
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8)) 
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(
+                    self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+                                                        adjacent_edge_vec[1])
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(
+                    pad_points[1:] - poly_center, axis=-1),
+                norm(
+                    pad_points[:-1] - poly_center, axis=-1))
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+                (x - 0.5) / 0.5, 2.) / 2)
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
+                    score) - 1)] * gaussian * 0.3
+
+            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
+                                                            score_matrix.shape)
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+                    points[3] - points[2]) < self.vector_slope(
+                        points[2] - points[1]) + self.vector_slope(points[0] - points[3]):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+                    0]] - points[vertical_edge_inds[1][1]])
+            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
+                                                            [1]])
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+
+    def get_minarea_rect(self, img, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img, box
+
+    def sample_points_on_bbox_bp(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        # sanity-check the input arguments (norm is already imported at module level)
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+            while current_edge_ind + 1 < len(
+                    length_cumsum) and current_line_len >= length_cumsum[
+                        current_edge_ind + 1]:
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
+                                                    - line[current_edge_ind]
+                                                    ) * end_shift_ratio
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+        return resampled_line
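
A toy check of this resampler: n=5 yields n+1 roughly equally spaced points along the polyline, endpoints included (crop_by_polys stands for any CropByPolys instance):

```python
import numpy as np

line = np.array([[0.0, 0.0], [4.0, 0.0], [10.0, 0.0]])
pts = crop_by_polys.sample_points_on_bbox_bp(line, n=5)
print(pts[:, 0])  # [ 0.  2.  4.  6.  8. 10.]
```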
+
+    def sample_points_on_bbox(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        mean_length = total_length / (len(length_list) + 1e-8)
+        group = [[0]]
+        for i in range(len(length_list)):
+            point_id = i+1
+            if length_list[i] < 0.9 * mean_length:
+                for g in group:
+                    if i in g:
+                        g.append(point_id)
+                        break
+            else:
+                g = [point_id]
+                group.append(g)
+
+        top_tail_len = norm(line[0] - line[-1])
+        if top_tail_len < 0.9 * mean_length:
+            group[0].extend(g)
+            group.remove(g)
+        mean_positions = []  
+        for indices in group:  
+            x_sum = 0  
+            y_sum = 0  
+            for index in indices:  
+                x, y = line[index]  
+                x_sum += x  
+                y_sum += y  
+            num_points = len(indices)  
+            mean_x = x_sum / num_points  
+            mean_y = y_sum / num_points  
+            mean_positions.append((mean_x, mean_y)) 
+        resampled_line = np.array(mean_positions)
+        return resampled_line
+
+    def get_poly_rect_crop(self, img, points):
+        '''
+        Rectify and crop an irregular or curved text region using its polygon.
+        args: img: image as an ndarray
+              points: polygon coordinates with shape N*2, as an ndarray
+        return: the rectified crop as an ndarray
+        '''
+        points = np.array(points).astype(np.int32).reshape(-1, 2)
+        temp_crop_img, temp_box = self.get_minarea_rect(img, points)
+        # compute the IoU between the min-area rectangle and the polygon
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / (get_union(pD, pG)+ 1e-10)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        cal_IoU = get_intersection_over_union(points, temp_box)
+
+        if cal_IoU >= 0.7:
+            points = self.sample_points_on_bbox_bp(points, 31)
+            return temp_crop_img
+
+        points_sample = self.sample_points_on_bbox(points)
+        points_sample = points_sample.astype(np.int32)
+        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
+
+        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
+        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
+
+        sideline_mean_shift = np.mean(
+            resample_top_line, axis=0) - np.mean(
+                resample_bot_line, axis=0)
+        if sideline_mean_shift[1] > 0:
+            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
+        rectifier = AutoRectifier()
+        new_points = np.concatenate([resample_top_line, resample_bot_line])
+        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
+
+        if len(img.shape) == 2:
+            img = np.stack((img,)*3, axis=-1)
+        img_crop, image = rectifier.run(img, new_points_list, mode='homography')
+        return img_crop[0]
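
Putting the poly branch together, a hedged sketch (the image path and dt_polys are placeholders): with det_box_type="poly", apply() routes every box through get_poly_rect_crop, which keeps the min-area-rect crop when its IoU with the polygon is already >= 0.7 and otherwise rectifies via AutoRectifier:

```python
crop = CropByPolys(det_box_type="poly")
outputs = crop.apply("seal_text_det.png", dt_polys)  # dt_polys: list of N*2 polygons
for out in outputs:
    print(out["img"].shape, out["img_size"])
```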

+ 6 - 2
paddlex/inference/pipelines/ocr.py

@@ -23,13 +23,17 @@ class OCRPipeline(BasePipeline):
     entities = "ocr"
     entities = "ocr"
 
 
     def __init__(
     def __init__(
-        self, det_model, rec_model, rec_batch_size, predictor_kwargs=None, **kwargs
+        self, det_model, rec_model, rec_batch_size, predictor_kwargs=None, is_curve=False, **kwargs
     ):
         super().__init__(predictor_kwargs)
         self._det_predict = self._create_predictor(det_model)
         self._rec_predict = self._create_predictor(rec_model, batch_size=rec_batch_size)
         # TODO: foo
-        self._crop_by_polys = CropByPolys(det_box_type="foo")
+        if is_curve:
+            det_box_type = 'poly'
+        else:
+            det_box_type = 'quad'
+        self._crop_by_polys = CropByPolys(det_box_type=det_box_type)
 
     def predict(self, x):
         for det_res in self._det_predict(x):
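
A hedged sketch of wiring the new flag through the pipeline; constructor arguments follow the diff above, while the input path and surrounding setup are assumptions:

```python
pipeline = OCRPipeline(
    det_model="PP-OCRv4_mobile_seal_det",
    rec_model="PP-OCRv4_mobile_rec",
    rec_batch_size=1,
    is_curve=True,  # -> CropByPolys(det_box_type='poly')
)
pipeline.predict("seal_text_det.png")  # placeholder input image
```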

+ 4 - 0
paddlex/inference/predictors/official_models.py

@@ -183,6 +183,10 @@ PP-OCRv4_mobile_rec_infer.tar",
 PP-OCRv4_server_det_infer.tar",
 PP-OCRv4_server_det_infer.tar",
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 PP-OCRv4_mobile_det_infer.tar",
 PP-OCRv4_mobile_det_infer.tar",
+    "PP-OCRv4_server_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_server_seal_det_infer.tar",
+    "PP-OCRv4_mobile_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_mobile_seal_det_infer.tar",
     "ch_RepSVTR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
     "ch_RepSVTR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 openatom_rec_repsvtr_ch_infer.tar",
 openatom_rec_repsvtr_ch_infer.tar",
     "ch_SVTRv2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
     "ch_SVTRv2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\

+ 45 - 13
paddlex/inference/results/ocr.py

@@ -31,13 +31,35 @@ class OCRResult(BaseResult):
         if len(self["dt_polys"]) == 0:
         if len(self["dt_polys"]) == 0:
             logging.warning("No text detected!")
             logging.warning("No text detected!")
 
 
+    def get_minarea_rect(self, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = np.array([points[index_a], points[index_b], points[index_c], points[index_d]]).astype(np.int32)
+
+        return box
+
     def _get_res_img(
         self,
         drop_score=0.5,
         font_path=PINGFANG_FONT_FILE_PATH,
     ):
         """draw ocr result"""
-        boxes = np.array(self["dt_polys"])
+        boxes = self["dt_polys"]
         txts = self["rec_text"]
         txts = self["rec_text"]
         scores = self["rec_score"]
         scores = self["rec_score"]
         img = self._img_reader.read(self["img_path"])
         img = self._img_reader.read(self["img_path"])
@@ -46,23 +68,33 @@ class OCRResult(BaseResult):
         img_left = image.copy()
         img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
         random.seed(0)
-
         draw_left = ImageDraw.Draw(img_left)
         if txts is None or len(txts) != len(boxes):
             txts = [None] * len(boxes)
         for idx, (box, txt) in enumerate(zip(boxes, txts)):
-            if scores is not None and scores[idx] < drop_score:
+            try:
+                if scores is not None and scores[idx] < drop_score:
+                    continue
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = np.array(box)
+                if len(box) > 4:
+                    pts = [(x, y) for x, y in box.tolist()]
+                    draw_left.polygon(pts, outline=color, width=8)
+                    box = self.get_minarea_rect(box)
+                    height = int(0.5 * (max(box[:,1]) - min(box[:,1])))
+                    box[:2,1] = np.mean(box[:,1])
+                    box[2:,1] = np.mean(box[:,1]) + min(20, height)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except Exception:
                 continue
-            color = (
-                random.randint(0, 255),
-                random.randint(0, 255),
-                random.randint(0, 255),
-            )
-            draw_left.polygon(box, fill=color)
-            img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
-            pts = np.array(box, np.int32).reshape((-1, 1, 2))
-            cv2.polylines(img_right_text, [pts], True, color, 1)
-            img_right = cv2.bitwise_and(img_right, img_right_text)
         img_left = Image.blend(image, img_left, 0.5)
         img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
         img_show.paste(img_left, (0, 0, w, h))
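
For context on the new drawing path: when a seal detection returns more than four points, the drawer outlines the full polygon and then collapses it to a minimum-area quad so the label band stays printable. A minimal standalone sketch of that corner ordering, assuming only OpenCV and NumPy (the sample polygon is invented):

```python
import cv2
import numpy as np

def minarea_quad(points):
    # cv2.minAreaRect -> (center, size, angle); boxPoints -> its 4 corners.
    corners = sorted(cv2.boxPoints(cv2.minAreaRect(points)), key=lambda p: p[0])
    # Of the two left-most corners, the upper one becomes top-left and the
    # lower one bottom-left; the right-most pair is ordered the same way.
    a, d = (corners[0], corners[1]) if corners[1][1] > corners[0][1] else (corners[1], corners[0])
    b, c = (corners[2], corners[3]) if corners[3][1] > corners[2][1] else (corners[3], corners[2])
    return np.array([a, b, c, d]).astype(np.int32)  # clockwise from top-left

poly = np.array([[10, 40], [30, 12], [70, 10], [95, 35], [60, 60], [25, 62]], np.float32)
print(minarea_quad(poly))
```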

+ 3 - 3
paddlex/inference/results/text_det.py

@@ -23,10 +23,10 @@ class TextDetResult(BaseResult):
 
 
     def _get_res_img(self):
         """draw rectangle"""
-        boxes = np.array(self["dt_polys"])
+        boxes = self["dt_polys"]
         img = self._img_reader.read(self["img_path"])
         res_img = img.copy()
-        for box in boxes.astype(int):
-            box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+        for box in boxes:
+            box = np.reshape(np.array(box).astype(int), [-1, 1, 2]).astype(np.int64)
             cv2.polylines(res_img, [box], True, (0, 0, 255), 2)
         return res_img
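
The per-box cast above matters because seal polygons have varying point counts, so `dt_polys` can no longer be turned into one rectangular array up front. A small sketch of the reshape contract `cv2.polylines` expects (the boxes are invented):

```python
import cv2
import numpy as np

canvas = np.full((120, 160, 3), 255, dtype=np.uint8)
boxes = [
    [[10, 10], [60, 10], [60, 40], [10, 40]],                # 4-point quad
    [[80, 60], [120, 55], [140, 80], [120, 105], [82, 95]],  # 5-point seal poly
]
for box in boxes:
    # OpenCV wants one integer array shaped (N, 1, 2) per polygon.
    pts = np.array(box, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(canvas, [pts], True, (0, 0, 255), 2)
```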

+ 4 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -196,6 +196,10 @@ PP-OCRv4_mobile_rec_infer.tar",
 PP-OCRv4_server_det_infer.tar",
 PP-OCRv4_server_det_infer.tar",
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
 PP-OCRv4_mobile_det_infer.tar",
 PP-OCRv4_mobile_det_infer.tar",
+    "PP-OCRv4_server_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_server_seal_det_infer.tar",
+    "PP-OCRv4_mobile_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_mobile_seal_det_infer.tar",
     "ch_RepSVTR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
     "ch_RepSVTR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
 openatom_rec_repsvtr_ch_infer.tar",
 openatom_rec_repsvtr_ch_infer.tar",
     "ch_SVTRv2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
     "ch_SVTRv2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\

+ 2 - 0
paddlex/modules/text_detection/model_list.py

@@ -15,4 +15,6 @@
 MODELS = [
     "PP-OCRv4_mobile_det",
     "PP-OCRv4_server_det",
+    "PP-OCRv4_mobile_seal_det",
+    "PP-OCRv4_server_seal_det"
 ]

+ 34 - 13
paddlex/modules/text_detection/predictor/predictor.py

@@ -60,9 +60,15 @@ class TextDetPredictor(BasePredictor):
 
 
     def _get_pre_transforms_from_config(self):
         """get preprocess transforms"""
+
+        if self.model_name in ['PP-OCRv4_server_seal_det', 'PP-OCRv4_mobile_seal_det']:
+            limit_side_len = 736
+        else:
+            limit_side_len = 960
+
         return [
             image_common.ReadImage(),
-            T.DetResizeForTest(limit_side_len=960, limit_type="max"),
+            T.DetResizeForTest(limit_side_len=limit_side_len, limit_type="max"),
             T.NormalizeImage(
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
@@ -74,21 +80,36 @@ class TextDetPredictor(BasePredictor):
 
 
     def _get_post_transforms_from_config(self):
         """get postprocess transforms"""
-        post_transforms = [
-            T.DBPostProcess(
-                thresh=0.3,
-                box_thresh=0.6,
-                max_candidates=1000,
-                unclip_ratio=1.5,
-                use_dilation=False,
-                score_mode="fast",
-                box_type="quad",
-            )
-        ]
+        if self.model_name in ['PP-OCRv4_server_seal_det', 'PP-OCRv4_mobile_seal_det']:
+            task = 'poly'
+            post_transforms = [
+                T.DBPostProcess(
+                    thresh=0.2,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="poly",
+                )
+            ]
+        else:
+            task = 'quad'
+            post_transforms = [
+                T.DBPostProcess(
+                    thresh=0.3,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="quad",
+                )
+            ]
         if not self.disable_print:
             post_transforms.append(T.PrintResult())
         if not self.disable_save:
             post_transforms.append(
-                T.SaveTextDetResults(self.output),
+                T.SaveTextDetResults(self.output, task),
             )
         return post_transforms
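
Read together, the two branches say that seal models differ from the regular detectors in exactly three knobs: a 736 resize limit instead of 960, a DB binarization threshold of 0.2 instead of 0.3, and "poly" output boxes instead of "quad". A sketch that makes the mapping explicit (the lookup table itself is invented, not part of the predictor):

```python
SEAL_MODELS = {"PP-OCRv4_server_seal_det", "PP-OCRv4_mobile_seal_det"}

def det_settings(model_name):
    if model_name in SEAL_MODELS:
        return {"limit_side_len": 736, "thresh": 0.2, "box_type": "poly"}
    return {"limit_side_len": 960, "thresh": 0.3, "box_type": "quad"}

print(det_settings("PP-OCRv4_mobile_seal_det"))
# -> {'limit_side_len': 736, 'thresh': 0.2, 'box_type': 'poly'}
```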

+ 353 - 12
paddlex/modules/text_detection/predictor/transforms.py

@@ -20,6 +20,7 @@ import copy
 import math
 import pyclipper
 import numpy as np
+from numpy.linalg import norm
 from PIL import Image
 from shapely.geometry import Polygon
 
@@ -28,6 +29,7 @@ from ...base.predictor.io.writers import ImageWriter
 from ...base.predictor.io.readers import ImageReader
 from ...base.predictor import BaseTransform
 from .keys import TextDetKeys as K
+from .utils import AutoRectifier
 
 __all__ = [
     "DetResizeForTest",
@@ -461,17 +463,23 @@ class CropByPolys(BaseTransform):
     def apply(self, data):
         """apply"""
         ori_im = data[K.ORI_IM]
-        # TODO
-        # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
-        dt_boxes = np.array(data[K.DT_POLYS])
-        img_crop_list = []
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.det_box_type == "quad":
-                img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
-            else:
+        if self.det_box_type == "quad":
+            dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
+            dt_boxes = np.array(dt_boxes)
+            img_crop_list = []
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
                 img_crop = self.get_minarea_rect_crop(ori_im, tmp_box)
-            img_crop_list.append(img_crop)
+                img_crop_list.append(img_crop)
+        elif self.det_box_type == "poly":
+            img_crop_list = []
+            dt_boxes = data[K.DT_POLYS]
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_poly_rect_crop(ori_im.copy(), tmp_box)
+                img_crop_list.append(img_crop)
+        else:
+            raise NotImplementedError
         data[K.SUB_IMGS] = img_crop_list
         return data
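
Note that only the "quad" branch re-sorts detections into reading order before cropping. A common `sorted_boxes` implementation (top-to-bottom, then left-to-right for boxes on the same line) looks roughly like this sketch; it is not necessarily the exact helper used here:

```python
def sort_boxes_reading_order(boxes, y_tol=10):
    # boxes: list of 4x2 point lists, first point being the top-left corner
    boxes = sorted(boxes, key=lambda b: (b[0][1], b[0][0]))
    for i in range(len(boxes) - 1):
        # swap neighbours that share a text line but are out of x order
        same_line = abs(boxes[i + 1][0][1] - boxes[i][0][1]) < y_tol
        if same_line and boxes[i + 1][0][0] < boxes[i][0][0]:
            boxes[i], boxes[i + 1] = boxes[i + 1], boxes[i]
    return boxes
```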
 
 
@@ -533,6 +541,7 @@ class CropByPolys(BaseTransform):
         crop_img = self.get_rotate_crop_image(img, np.array(box))
         return crop_img
 
 
+
     def get_rotate_crop_image(self, img, points):
         """
         img_height, img_width = img.shape[0:2]
@@ -578,13 +587,330 @@ class CropByPolys(BaseTransform):
             dst_img = np.rot90(dst_img)
         return dst_img
 
 
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        orientation_thr = 2.0  # an empirical hyperparameter
+
+        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+        return head_edge, tail_edge, sideline1, sideline2
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8)) 
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(
+                    self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+                                                        adjacent_edge_vec[1])
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(
+                    pad_points[1:] - poly_center, axis=-1),
+                norm(
+                    pad_points[:-1] - poly_center, axis=-1))
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+                (x - 0.5) / 0.5, 2.) / 2)
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
+                    score) - 1)] * gaussian * 0.3
+
+            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
+                                                            score_matrix.shape)
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if self.vector_slope(points[1] - points[0]) + self.vector_slope(points[
+                    3] - points[2]) < self.vector_slope(points[2] - points[
+                        1]) + self.vector_slope(points[0] - points[3]):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+                    0]] - points[vertical_edge_inds[1][1]])
+            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
+                                                            [1]])
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+
+    def get_minarea_rect(self, img, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img, box
+
+    def sample_points_on_bbox_bp(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        # sanity-check the input arguments
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+            while current_edge_ind + 1 < len(
+                    length_cumsum) and current_line_len >= length_cumsum[
+                        current_edge_ind + 1]:
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
+                                                    - line[current_edge_ind]
+                                                    ) * end_shift_ratio
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+        return resampled_line
+
+    def sample_points_on_bbox(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        mean_length = total_length / (len(length_list) + 1e-8)
+        group = [[0]]
+        for i in range(len(length_list)):
+            point_id = i+1
+            if length_list[i] < 0.9 * mean_length:
+                for g in group:
+                    if i in g:
+                        g.append(point_id)
+                        break
+            else:
+                g = [point_id]
+                group.append(g)
+
+        top_tail_len = norm(line[0] - line[-1])
+        if top_tail_len < 0.9 * mean_length:
+            group[0].extend(g)
+            group.remove(g)
+        mean_positions = []  
+        for indices in group:  
+            x_sum = 0  
+            y_sum = 0  
+            for index in indices:  
+                x, y = line[index]  
+                x_sum += x  
+                y_sum += y  
+            num_points = len(indices)  
+            mean_x = x_sum / num_points  
+            mean_y = y_sum / num_points  
+            mean_positions.append((mean_x, mean_y)) 
+        resampled_line = np.array(mean_positions)
+        return resampled_line
+
+    def get_poly_rect_crop(self, img, points):
+        '''
+        Use the polygon to rectify and crop irregular or curved text.
+        args: img: image, as an ndarray
+        points: polygon point coordinates of shape N*2, as an ndarray
+        return: the rectified image, as an ndarray
+        '''
+        points = np.array(points).astype(np.int32).reshape(-1, 2)
+        temp_crop_img, temp_box = self.get_minarea_rect(img, points)
+        # compute the IoU between the minimum-area rectangle and the polygon
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / (get_union(pD, pG)+ 1e-10)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        cal_IoU = get_intersection_over_union(points, temp_box)
+
+        if cal_IoU >= 0.7:
+            points = self.sample_points_on_bbox_bp(points, 31)
+            return temp_crop_img
+
+        points_sample = self.sample_points_on_bbox(points)
+        points_sample = points_sample.astype(np.int32)
+        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
+
+        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
+        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
+
+        sideline_mean_shift = np.mean(
+            resample_top_line, axis=0) - np.mean(
+                resample_bot_line, axis=0)
+        if sideline_mean_shift[1] > 0:
+            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
+        rectifier = AutoRectifier()
+        new_points = np.concatenate([resample_top_line, resample_bot_line])
+        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
+
+        if len(img.shape) == 2:
+            img = np.stack((img,)*3, axis=-1)
+        img_crop, image = rectifier.run(img, new_points_list, mode='homography')
+        return img_crop[0]
+
 
 
 class SaveTextDetResults(BaseTransform):
     """Save Text Det Results"""
 
 
-    def __init__(self, save_dir):
+    def __init__(self, save_dir, task='quad'):
         super().__init__()
         self.save_dir = save_dir
+        self.task = task
         # We use pillow backend to save both numpy arrays and PIL Image objects
         self._writer = ImageWriter(backend="opencv")
 
 
@@ -598,7 +924,10 @@ class SaveTextDetResults(BaseTransform):
         fn = os.path.basename(data["input_path"])
         fn = os.path.basename(data["input_path"])
         save_path = os.path.join(self.save_dir, fn)
         save_path = os.path.join(self.save_dir, fn)
         bbox_res = data[K.DT_POLYS]
         bbox_res = data[K.DT_POLYS]
-        vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
+        if self.task == "quad":
+            vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
+        else:
+            vis_img = self.draw_polyline(data[K.IM_PATH], bbox_res)
         self._writer.write(save_path, vis_img)
         return data
 
 
@@ -621,6 +950,16 @@ class SaveTextDetResults(BaseTransform):
             box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
             cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
         return img_show
+    
+    def draw_polyline(self, img_path, boxes):
+        """draw polyline"""
+        img = cv2.imread(img_path)
+        img_show = img.copy()
+        for box in boxes:
+            box = np.array(box).astype(int)
+            box = np.reshape(box, [-1, 1, 2]).astype(np.int64)
+            cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
+        return img_show
 
 
 
 
 class PrintResult(BaseTransform):
@@ -644,3 +983,5 @@ class PrintResult(BaseTransform):
 
 
     # DT_SCORES = 'dt_scores'
     # DT_POLYS = 'dt_polys'
+
+
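
The heart of the new poly path is the IoU gate in `get_poly_rect_crop`: if the detected polygon and its minimum-area rectangle overlap well (IoU >= 0.7), a plain rotated crop suffices; only genuinely curved text goes through sideline resampling and the AutoRectifier. The gate in isolation, as a runnable sketch (shapely is already imported by this module; the sample shapes are invented):

```python
import numpy as np
from shapely.geometry import Polygon

def polygon_iou(pts_a, pts_b):
    pa, pb = Polygon(pts_a), Polygon(pts_b)
    return pa.intersection(pb).area / (pa.union(pb).area + 1e-10)

seal = np.array([[0, 10], [20, 0], [40, 10], [40, 40], [20, 50], [0, 40]])
rect = np.array([[0, 0], [40, 0], [40, 50], [0, 50]])
print(polygon_iou(seal, rect))  # a value below 0.7 would trigger rectification
```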

+ 698 - 0
paddlex/modules/text_detection/predictor/utils.py

@@ -0,0 +1,698 @@
+import os, sys
+import numpy as np
+from numpy import cos, sin, arctan, sqrt
+import cv2
+import copy
+import time
+
+def Homography(image, img_points, world_width, world_height,
+               interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+    """
+    将图像透视变换到新的视角,返回变换后的图像。
+    
+    Args:
+        image (np.ndarray): 输入的图像,应为numpy数组类型。
+        img_points (List[Tuple[int, int]]): 图像上的四个点的坐标,顺序为左上角、右上角、右下角、左下角。
+        world_width (int): 变换后图像在世界坐标系中的宽度。
+        world_height (int): 变换后图像在世界坐标系中的高度。
+        interpolation (int, optional): 插值方式,默认为cv2.INTER_CUBIC。
+        ratio_width (float, optional): 变换后图像在x轴上的缩放比例,默认为1.0。
+        ratio_height (float, optional): 变换后图像在y轴上的缩放比例,默认为1.0。
+    
+    Returns:
+        np.ndarray: 变换后的图像,为numpy数组类型。
+    
+    """
+    _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+    expand_x = int(0.5 * world_width * (ratio_width - 1))
+    expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+    pt_lefttop = [expand_x, expand_y]
+    pt_righttop = [expand_x + world_width, expand_y]
+    pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+    pt_rightbottom = [expand_x, expand_y + world_height]
+
+    pts_std = np.float32([pt_lefttop, pt_righttop,
+                          pt_leftbottom, pt_rightbottom])
+
+    img_crop_width = int(world_width * ratio_width)
+    img_crop_height = int(world_height * ratio_height)
+
+    M = cv2.getPerspectiveTransform(_points, pts_std)
+
+    dst_img = cv2.warpPerspective(
+        image,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+        flags=interpolation)
+
+    return dst_img
+
+
+class CurveTextRectifier:
+    """
+    spatial transformer via monocular vision
+    """
+    def __init__(self):
+        self.get_virtual_camera_parameter()
+
+
+    def get_virtual_camera_parameter(self):
+        vcam_thz = 0
+        vcam_thx1 = 180
+        vcam_thy = 180
+        vcam_thx2 = 0
+
+        vcam_x = 0
+        vcam_y = 0
+        vcam_z = 100
+
+        radian = np.pi / 180
+
+        angle_z = radian * vcam_thz
+        angle_x1 = radian * vcam_thx1
+        angle_y = radian * vcam_thy
+        angle_x2 = radian * vcam_thx2
+
+        optic_x = vcam_x
+        optic_y = vcam_y
+        optic_z = vcam_z
+
+        fu = 100
+        fv = 100
+
+        matT = np.zeros((4, 4))
+        matT[0, 0] = cos(angle_z) * cos(angle_y) - sin(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[0, 1] = cos(angle_z) * sin(angle_y) * sin(angle_x2) - sin(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[0, 2] = cos(angle_z) * sin(angle_y) * cos(angle_x2) + sin(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[0, 3] = optic_x
+        matT[1, 0] = sin(angle_z) * cos(angle_y) + cos(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[1, 1] = sin(angle_z) * sin(angle_y) * sin(angle_x2) + cos(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[1, 2] = sin(angle_z) * sin(angle_y) * cos(angle_x2) - cos(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[1, 3] = optic_y
+        matT[2, 0] = -cos(angle_x1) * sin(angle_y)
+        matT[2, 1] = cos(angle_x1) * cos(angle_y) * sin(angle_x2) + sin(angle_x1) * cos(angle_x2)
+        matT[2, 2] = cos(angle_x1) * cos(angle_y) * cos(angle_x2) - sin(angle_x1) * sin(angle_x2)
+        matT[2, 3] = optic_z
+        matT[3, 0] = 0
+        matT[3, 1] = 0
+        matT[3, 2] = 0
+        matT[3, 3] = 1
+
+        matS = np.zeros((4, 4))
+        matS[2, 3] = 0.5
+        matS[3, 2] = 0.5
+
+        self.ifu = 1 / fu
+        self.ifv = 1 / fv
+
+        self.matT = matT
+        self.matS = matS
+        self.K = np.dot(matT.T, matS)
+        self.K = np.dot(self.K, matT)
+
+
+    def vertical_text_process(self, points, org_size):
+        """
+        change sequence and process
+        :param points:
+        :param org_size:
+        :return:
+        """
+        org_w, org_h = org_size
+        _points = np.array(points).reshape(-1).tolist()
+        _points = np.array(_points[2:] + _points[:2]).reshape(-1, 2)
+
+        # convert to horizontal points
+        adjusted_points = np.zeros(_points.shape, dtype=np.float32)
+        adjusted_points[:, 0] = _points[:, 1]
+        adjusted_points[:, 1] = org_h - _points[:, 0] - 1
+
+        _image_coord, _world_coord, _new_image_size = self.horizontal_text_process(adjusted_points)
+
+        # # convert to vertical points back
+        image_coord = _points.reshape(1, -1, 2)
+        world_coord = np.zeros(_world_coord.shape, dtype=np.float32)
+        world_coord[:, :, 0] = 0 - _world_coord[:, :, 1]
+        world_coord[:, :, 1] = _world_coord[:, :, 0]
+        world_coord[:, :, 2] = _world_coord[:, :, 2]
+        new_image_size = (_new_image_size[1], _new_image_size[0])
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_process(self, points):
+        """
+        get image coordinate and world coordinate
+        :param points:
+        :return:
+        """
+        poly = np.array(points).reshape(-1)
+
+        dx_list = []
+        dy_list = []
+        for i in range(1, len(poly) // 2):
+            xdx = poly[i * 2] - poly[(i - 1) * 2]
+            xdy = poly[i * 2 + 1] - poly[(i - 1) * 2 + 1]
+            d = sqrt(xdx ** 2 + xdy ** 2)
+            dx_list.append(d)
+
+        for i in range(0, len(poly) // 4):
+            ydx = poly[i * 2] - poly[len(poly) - 1 - (i * 2 + 1)]
+            ydy = poly[i * 2 + 1] - poly[len(poly) - 1 - (i * 2)]
+            d = sqrt(ydx ** 2 + ydy ** 2)
+            dy_list.append(d)
+
+        dx_list = [(dx_list[i] + dx_list[len(dx_list) - 1 - i]) / 2 for i in range(len(dx_list) // 2)]
+
+        height = np.around(np.mean(dy_list))
+
+        rect_coord = [0, 0]
+        for i in range(0, len(poly) // 4 - 1):
+            x = rect_coord[-2]
+            x += dx_list[i]
+            y = 0
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        rect_coord_half = copy.deepcopy(rect_coord)
+        for i in range(0, len(poly) // 4):
+            x = rect_coord_half[len(rect_coord_half) - 2 * i - 2]
+            y = height
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        np_rect_coord = np.array(rect_coord).reshape(-1, 2)
+        x_min = np.min(np_rect_coord[:, 0])
+        y_min = np.min(np_rect_coord[:, 1])
+        x_max = np.max(np_rect_coord[:, 0])
+        y_max = np.max(np_rect_coord[:, 1])
+        new_image_size = (int(x_max - x_min + 0.5), int(y_max - y_min + 0.5))
+        x_mean = (x_max - x_min) / 2
+        y_mean = (y_max - y_min) / 2
+        np_rect_coord[:, 0] -= x_mean
+        np_rect_coord[:, 1] -= y_mean
+        rect_coord = np_rect_coord.reshape(-1).tolist()
+
+        rect_coord = np.array(rect_coord).reshape(-1, 2)
+        world_coord = np.ones((len(rect_coord), 3)) * 0
+
+        world_coord[:, :2] = rect_coord
+
+        image_coord = np.array(poly).reshape(1, -1, 2)
+        world_coord = world_coord.reshape(1, -1, 3)
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_estimate(self, points):
+        """
+        horizontal or vertical text
+        :param points:
+        :return:
+        """
+        pts = np.array(points).reshape(-1, 2)
+        x_min = int(np.min(pts[:, 0]))
+        y_min = int(np.min(pts[:, 1]))
+        x_max = int(np.max(pts[:, 0]))
+        y_max = int(np.max(pts[:, 1]))
+        x = x_max - x_min
+        y = y_max - y_min
+        is_horizontal_text = True
+        if y / x > 1.5: # vertical text condition
+            is_horizontal_text = False
+        return is_horizontal_text
+
+
+    def virtual_camera_to_world(self, size):
+        ifu, ifv = self.ifu, self.ifv
+        K, matT = self.K, self.matT
+
+        ppu = size[0] / 2 + 1e-6
+        ppv = size[1] / 2 + 1e-6
+
+        P = np.zeros((size[1], size[0], 3))
+
+        lu = np.array([i for i in range(size[0])])
+        lv = np.array([i for i in range(size[1])])
+        u, v = np.meshgrid(lu, lv)
+
+        yp = (v - ppv) * ifv
+        xp = (u - ppu) * ifu
+        angle_a = arctan(sqrt(xp * xp + yp * yp))
+        angle_b = arctan(yp / xp)
+
+        D0 = sin(angle_a) * cos(angle_b)
+        D1 = sin(angle_a) * sin(angle_b)
+        D2 = cos(angle_a)
+
+        D0[xp <= 0] = -D0[xp <= 0]
+        D1[xp <= 0] = -D1[xp <= 0]
+
+        ratio_a = K[0, 0] * D0 * D0 + K[1, 1] * D1 * D1 + K[2, 2] * D2 * D2 + \
+                  (K[0, 1] + K[1, 0]) * D0 * D1 + (K[0, 2] + K[2, 0]) * D0 * D2 + (K[1, 2] + K[2, 1]) * D1 * D2
+        ratio_b = (K[0, 3] + K[3, 0]) * D0 + (K[1, 3] + K[3, 1]) * D1 + (K[2, 3] + K[3, 2]) * D2
+        ratio_c = K[3, 3] * np.ones(ratio_b.shape)
+
+        delta = ratio_b * ratio_b - 4 * ratio_a * ratio_c
+        t = np.zeros(delta.shape)
+        t[ratio_a == 0] = -ratio_c[ratio_a == 0] / ratio_b[ratio_a == 0]
+        t[ratio_a != 0] = (-ratio_b[ratio_a != 0] + sqrt(delta[ratio_a != 0])) / (2 * ratio_a[ratio_a != 0])
+        t[delta < 0] = 0
+
+        P[:, :, 0] = matT[0, 3] + t * (matT[0, 0] * D0 + matT[0, 1] * D1 + matT[0, 2] * D2)
+        P[:, :, 1] = matT[1, 3] + t * (matT[1, 0] * D0 + matT[1, 1] * D1 + matT[1, 2] * D2)
+        P[:, :, 2] = matT[2, 3] + t * (matT[2, 0] * D0 + matT[2, 1] * D1 + matT[2, 2] * D2)
+
+        return P
+
+
+    def world_to_image(self, image_size, world, intrinsic, distCoeffs, rotation, tvec):
+        r11 = rotation[0, 0]
+        r12 = rotation[0, 1]
+        r13 = rotation[0, 2]
+        r21 = rotation[1, 0]
+        r22 = rotation[1, 1]
+        r23 = rotation[1, 2]
+        r31 = rotation[2, 0]
+        r32 = rotation[2, 1]
+        r33 = rotation[2, 2]
+
+        t1 = tvec[0]
+        t2 = tvec[1]
+        t3 = tvec[2]
+
+        k1 = distCoeffs[0]
+        k2 = distCoeffs[1]
+        p1 = distCoeffs[2]
+        p2 = distCoeffs[3]
+        k3 = distCoeffs[4]
+        k4 = distCoeffs[5]
+        k5 = distCoeffs[6]
+        k6 = distCoeffs[7]
+
+        if len(distCoeffs) > 8:
+            s1 = distCoeffs[8]
+            s2 = distCoeffs[9]
+            s3 = distCoeffs[10]
+            s4 = distCoeffs[11]
+        else:
+            s1 = s2 = s3 = s4 = 0
+
+        if len(distCoeffs) > 12:
+            tx = distCoeffs[12]
+            ty = distCoeffs[13]
+        else:
+            tx = ty = 0
+
+        fu = intrinsic[0, 0]
+        fv = intrinsic[1, 1]
+        ppu = intrinsic[0, 2]
+        ppv = intrinsic[1, 2]
+
+        cos_tx = cos(tx)
+        cos_ty = cos(ty)
+        sin_tx = sin(tx)
+        sin_ty = sin(ty)
+
+        tao11 = cos_ty * cos_tx * cos_ty + sin_ty * cos_tx * sin_ty
+        tao12 = cos_ty * cos_tx * sin_ty * sin_tx - sin_ty * cos_tx * cos_ty * sin_tx
+        tao13 = -cos_ty * cos_tx * sin_ty * cos_tx + sin_ty * cos_tx * cos_ty * cos_tx
+        tao21 = -sin_tx * sin_ty
+        tao22 = cos_ty * cos_tx * cos_tx + sin_tx * cos_ty * sin_tx
+        tao23 = cos_ty * cos_tx * sin_tx - sin_tx * cos_ty * cos_tx
+
+        P = np.zeros((image_size[1], image_size[0], 2))
+
+        c3 = r31 * world[:, :, 0] + r32 * world[:, :, 1] + r33 * world[:, :, 2] + t3
+        c1 = r11 * world[:, :, 0] + r12 * world[:, :, 1] + r13 * world[:, :, 2] + t1
+        c2 = r21 * world[:, :, 0] + r22 * world[:, :, 1] + r23 * world[:, :, 2] + t2
+
+        x1 = c1 / c3
+        y1 = c2 / c3
+        x12 = x1 * x1
+        y12 = y1 * y1
+        x1y1 = 2 * x1 * y1
+        r2 = x12 + y12
+        r4 = r2 * r2
+        r6 = r2 * r4
+
+        radial_distortion = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (1 + k4 * r2 + k5 * r4 + k6 * r6)
+        x2 = x1 * radial_distortion + p1 * x1y1 + p2 * (r2 + 2 * x12) + s1 * r2 + s2 * r4
+        y2 = y1 * radial_distortion + p2 * x1y1 + p1 * (r2 + 2 * y12) + s3 * r2 + s4 * r4
+
+        x3 = tao11 * x2 + tao12 * y2 + tao13
+        y3 = tao21 * x2 + tao22 * y2 + tao23
+
+        P[:, :, 0] = fu * x3 + ppu
+        P[:, :, 1] = fv * y3 + ppv
+        P[c3 <= 0] = 0
+
+        return P
+
+
+    def spatial_transform(self, image_data, new_image_size, mtx, dist, rvecs, tvecs, interpolation):
+        rotation, _ = cv2.Rodrigues(rvecs)
+        world_map = self.virtual_camera_to_world(new_image_size)
+        image_map = self.world_to_image(new_image_size, world_map, mtx, dist, rotation, tvecs)
+        image_map = image_map.astype(np.float32)
+        dst = cv2.remap(image_data, image_map[:, :, 0], image_map[:, :, 1], interpolation)
+        return dst
+
+
+    def calibrate(self, org_size, image_coord, world_coord):
+        """
+        calibration
+        :param org_size:
+        :param image_coord:
+        :param world_coord:
+        :return:
+        """
+        # flag = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL  | cv2.CALIB_THIN_PRISM_MODEL
+        flag = cv2.CALIB_RATIONAL_MODEL
+        flag2 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL
+        flag3 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_THIN_PRISM_MODEL
+        flag4 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_ZERO_TANGENT_DIST | cv2.CALIB_FIX_ASPECT_RATIO
+        flag5 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL | cv2.CALIB_ZERO_TANGENT_DIST
+        flag6 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_FIX_ASPECT_RATIO
+        flag_list = [flag2, flag3, flag4, flag5, flag6]
+
+        ret, mtx, dist, rvecs, tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                image_coord.astype(np.float32),
+                                                                org_size,
+                                                                None,
+                                                                None,
+                                                                flags=flag)
+        if ret > 2:
+            # strategies
+            min_ret = ret
+            for i, flag in enumerate(flag_list):
+                _ret, _mtx, _dist, _rvecs, _tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                   image_coord.astype(np.float32),
+                                                                   org_size,
+                                                                   None,
+                                                                   None,
+                                                                   flags=flag)
+                if _ret < min_ret:
+                    min_ret = _ret
+                    ret, mtx, dist, rvecs, tvecs = _ret, _mtx, _dist, _rvecs, _tvecs
+
+        return ret, mtx, dist, rvecs, tvecs
+
+
+    def dc_homo(self, img, img_points, obj_points, is_horizontal_text, interpolation=cv2.INTER_LINEAR,
+                ratio_width=1.0, ratio_height=1.0):
+        """
+        divide and conquer: homography
+        # ratio_width and ratio_height must be 1.0 here
+        """
+        _img_points = img_points.reshape(-1, 2)
+        _obj_points = obj_points.reshape(-1, 3)
+
+        homo_img_list = []
+        width_list = []
+        height_list = []
+        # divide and conquer
+        for i in range(len(_img_points) // 2 - 1):
+            new_img_points = np.zeros((4, 2)).astype(np.float32)
+            new_obj_points = np.zeros((4, 2)).astype(np.float32)
+
+            new_img_points[0:2, :] = _img_points[i:(i + 2), :2]
+            new_img_points[2:4, :] = _img_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            new_obj_points[0:2, :] = _obj_points[i:(i + 2), :2]
+            new_obj_points[2:4, :] = _obj_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            if is_horizontal_text:
+                world_width = np.abs(new_obj_points[1, 0] - new_obj_points[0, 0])
+                world_height = np.abs(new_obj_points[3, 1] - new_obj_points[0, 1])
+            else:
+                world_width = np.abs(new_obj_points[1, 1] - new_obj_points[0, 1])
+                world_height = np.abs(new_obj_points[3, 0] - new_obj_points[0, 0])
+
+            homo_img = Homography(img, new_img_points, world_width, world_height,
+                                              interpolation=interpolation,
+                                              ratio_width=ratio_width, ratio_height=ratio_height)
+
+            homo_img_list.append(homo_img)
+            _h, _w = homo_img.shape[:2]
+            width_list.append(_w)
+            height_list.append(_h)
+
+        # stitching
+        rectified_image = np.zeros((np.max(height_list), sum(width_list), 3)).astype(np.uint8)
+
+        st = 0
+        for (homo_img, w, h) in zip(homo_img_list, width_list, height_list):
+            rectified_image[:h, st:st + w, :] = homo_img
+            st += w
+
+        if not is_horizontal_text:
+            # vertical rotation
+            rectified_image = np.rot90(rectified_image, 3)
+
+        return rectified_image
+
+    def Homography(self, image, img_points, world_width, world_height,
+                interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        将图像透视变换到新的视角,返回变换后的图像。
+        
+        Args:
+            image (np.ndarray): 输入的图像,应为numpy数组类型。
+            img_points (List[Tuple[int, int]]): 图像上的四个点的坐标,顺序为左上角、右上角、右下角、左下角。
+            world_width (int): 变换后图像在世界坐标系中的宽度。
+            world_height (int): 变换后图像在世界坐标系中的高度。
+            interpolation (int, optional): 插值方式,默认为cv2.INTER_CUBIC。
+            ratio_width (float, optional): 变换后图像在x轴上的缩放比例,默认为1.0。
+            ratio_height (float, optional): 变换后图像在y轴上的缩放比例,默认为1.0。
+        
+        Returns:
+            np.ndarray: 变换后的图像,为numpy数组类型。
+        
+        """
+        _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+        expand_x = int(0.5 * world_width * (ratio_width - 1))
+        expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+        pt_lefttop = [expand_x, expand_y]
+        pt_righttop = [expand_x + world_width, expand_y]
+        pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+        pt_rightbottom = [expand_x, expand_y + world_height]
+
+        pts_std = np.float32([pt_lefttop, pt_righttop,
+                            pt_leftbottom, pt_rightbottom])
+
+        img_crop_width = int(world_width * ratio_width)
+        img_crop_height = int(world_height * ratio_height)
+
+        M = cv2.getPerspectiveTransform(_points, pts_std)
+
+        dst_img = cv2.warpPerspective(
+            image,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+            flags=interpolation)
+
+        return dst_img
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR, ratio_width=1.0, ratio_height=1.0, mode='calibration'):
+        """
+        spatial transform for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        org_h, org_w = image_data.shape[:2]
+        org_size = (org_w, org_h)
+        self.image = image_data
+
+        is_horizontal_text = self.horizontal_text_estimate(points)
+        if is_horizontal_text:
+            image_coord, world_coord, new_image_size = self.horizontal_text_process(points)
+        else:
+            image_coord, world_coord, new_image_size = self.vertical_text_process(points, org_size)
+
+        if mode.lower() == 'calibration':
+            ret, mtx, dist, rvecs, tvecs = self.calibrate(org_size, image_coord, world_coord)
+
+            st_size = (int(new_image_size[0]*ratio_width), int(new_image_size[1]*ratio_height))
+            dst = self.spatial_transform(image_data, st_size, mtx, dist[0], rvecs[0], tvecs[0], interpolation)
+        elif mode.lower() == 'homography':
+            # ratio_width and ratio_height must be 1.0 here and ret set to 0.01 without loss manually
+            ret = 0.01
+            dst = self.dc_homo(image_data, image_coord, world_coord, is_horizontal_text,
+                               interpolation=interpolation, ratio_width=1.0, ratio_height=1.0)
+        else:
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        return dst, ret
+
+
+class AutoRectifier:
+    def __init__(self):
+        self.npoints = 10
+        self.curveTextRectifier = CurveTextRectifier()
+
+    @staticmethod
+    def get_rotate_crop_image(img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        crop or homography
+        :param img:
+        :param points:
+        :param interpolation:
+        :param ratio_width:
+        :param ratio_height:
+        :return:
+        """
+        h, w = img.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+
+        if len(_points) != 4:
+            x_min = int(np.min(_points[:, 0]))
+            y_min = int(np.min(_points[:, 1]))
+            x_max = int(np.max(_points[:, 0]))
+            y_max = int(np.max(_points[:, 1]))
+            dx = x_max - x_min
+            dy = y_max - y_min
+            expand_x = int(0.5 * dx * (ratio_width - 1))
+            expand_y = int(0.5 * dy * (ratio_height - 1))
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+            dst_img = img[y_min:y_max, x_min:x_max, :].copy()
+        else:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[1]),
+                    np.linalg.norm(_points[2] - _points[3])))
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[3]),
+                    np.linalg.norm(_points[1] - _points[2])))
+
+            dst_img = Homography(img, _points, img_crop_width, img_crop_height, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def visualize(self, image_data, points_list):
+        visualization = image_data.copy()
+
+        for box in points_list:
+            box = np.array(box).reshape(-1, 2).astype(np.int32)
+            cv2.drawContours(visualization, [np.array(box).reshape((-1, 1, 2))], -1, (0, 0, 255), 2)
+            for i, p in enumerate(box):
+                if i != 0:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 0, 0), thickness=2)
+                else:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 255, 0), thickness=2)
+        return visualization
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR,
+                 ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        rectification in strategies for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        _points = np.array(points).reshape(-1,2)
+        if len(_points) >= self.npoints and len(_points) % 2 == 0:
+            try:
+                curveTextRectifier = CurveTextRectifier()
+
+                dst_img, loss = curveTextRectifier(image_data, points, interpolation, ratio_width, ratio_height, mode)
+                if loss >= 2:
+                    # for robust
+                    # large loss means it cannot be reconstruct correctly, we must find other way to reconstruct
+                    img_list, loss_list = [dst_img], [loss]
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh,
+                                              square=True)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh, square=False)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    min_loss = min(loss_list)
+                    dst_img = img_list[loss_list.index(min_loss)]
+
+                    if min_loss >= loss_thresh:
+                        print('calibration loss: {} is too large for the spatial transformer; falling back to get_rotate_crop_image'.format(loss))
+                        dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+            except Exception as e:
+                print(e)
+                dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+        else:
+            dst_img = self.get_rotate_crop_image(image_data, _points, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def run(self, image_data, points_list, interpolation=cv2.INTER_LINEAR,
+            ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        run for texts in an image
+        :param image_data: numpy.ndarray. The shape is [h, w, 3]
+        :param points_list: [[x1,y1,x2,y2,x3,y3,...], [x1,y1,x2,y2,x3,y3,...], ...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return: res: roi-image list, visualized_image: draw polys in original image
+        """
+        if image_data is None:
+            raise ValueError('image_data cannot be None')
+        if not isinstance(points_list, list):
+            raise ValueError('points_list must be a list')
+        for points in points_list:
+            if not isinstance(points, list):
+                raise ValueError('every element of points_list must be a list')
+
+        if ratio_width < 1.0 or ratio_height < 1.0:
+            raise ValueError('ratio_width and ratio_height cannot be smaller than 1, but got {}'.format((ratio_width, ratio_height)))
+
+        if mode.lower() != 'calibration' and mode.lower() != 'homography':
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        if mode.lower() == 'homography' and (ratio_width != 1.0 or ratio_height != 1.0):
+            raise ValueError('ratio_width and ratio_height must be 1.0 when mode is homography, but got mode:{}, ratio:({},{})'.format(mode, ratio_width, ratio_height))
+
+        res = []
+        for points in points_list:
+            rectified_img = self(image_data, points, interpolation, ratio_width, ratio_height,
+                                 loss_thresh=loss_thresh, mode=mode)
+            res.append(rectified_img)
+
+        # visualize
+        visualized_image = self.visualize(image_data, points_list)
+
+        return res, visualized_image
+
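
A hedged usage sketch for AutoRectifier, mirroring the call made from `get_poly_rect_crop`: one flattened `[x1, y1, x2, y2, ...]` polygon per text instance, in homography mode so both expansion ratios stay at 1.0. The image and the sidelines are invented, and `AutoRectifier` is assumed to be in scope from this module:

```python
import numpy as np

img = np.full((200, 300, 3), 255, dtype=np.uint8)
top = [(20 + 25 * i, 40 + 2 * i) for i in range(8)]             # top sideline, left to right
bottom = [(20 + 25 * i, 90 + 2 * i) for i in range(7, -1, -1)]  # bottom sideline, right to left
poly = [float(v) for pt in top + bottom for v in pt]            # clockwise, flattened

rectifier = AutoRectifier()  # the class defined above
crops, vis = rectifier.run(img, [poly], mode='homography')
print(crops[0].shape)  # the straightened text strip
```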

+ 33 - 13
paddlex/pipelines/OCR/pipeline.py

@@ -49,6 +49,11 @@ class OCRPipeline(BasePipeline):
         self.device = device
         self.text_det_kernel_option = text_det_kernel_option
         self.text_rec_kernel_option = text_rec_kernel_option
+        if self.text_det_model_name in ['PP-OCRv4_server_seal_det', 'PP-OCRv4_mobile_seal_det']:
+            self.task = "poly"
+        else:
+            self.task = "quad"
+
         if (
             self.text_det_model_name is not None
             and self.text_rec_model_name is not None
@@ -80,19 +85,34 @@ Only support: {text_rec_models}."
             if self.text_rec_kernel_option is None
             else self.text_rec_kernel_option
         )
-        text_det_post_transforms = [
-            text_det_T.DBPostProcess(
-                thresh=0.3,
-                box_thresh=0.6,
-                max_candidates=1000,
-                unclip_ratio=1.5,
-                use_dilation=False,
-                score_mode="fast",
-                box_type="quad",
-            ),
-            # TODO
-            text_det_T.CropByPolys(det_box_type="foo"),
-        ]
+        if self.task == "poly":
+            text_det_post_transforms = [
+                text_det_T.DBPostProcess(
+                    thresh=0.2,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="poly",
+                ),
+                # TODO
+                text_det_T.CropByPolys(det_box_type="poly"),
+            ]
+        else:
+            text_det_post_transforms = [
+                text_det_T.DBPostProcess(
+                    thresh=0.3,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="quad",
+                ),
+                # TODO
+                text_det_T.CropByPolys(det_box_type="quad"),
+            ]
 
 
         self.text_det_model = create_model(
             self.text_det_model_name,
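
With the branching above, pairing a seal detector with any recognizer flips the whole pipeline into poly mode. A hedged construction sketch; the keyword names are inferred from the attributes visible in this diff and may not match the constructor exactly:

```python
pipeline = OCRPipeline(
    text_det_model_name="PP-OCRv4_mobile_seal_det",  # inferred keyword
    text_rec_model_name="PP-OCRv4_mobile_rec",       # inferred keyword
)
assert pipeline.task == "poly"  # seal models select poly post-processing
```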

+ 41 - 7
paddlex/pipelines/OCR/utils.py

@@ -24,6 +24,28 @@ import copy
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
 
 
 
 
+def get_minarea_rect(points):
+    bounding_box = cv2.minAreaRect(points)
+    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+    index_a, index_b, index_c, index_d = 0, 1, 2, 3
+    if points[1][1] > points[0][1]:
+        index_a = 0
+        index_d = 1
+    else:
+        index_a = 1
+        index_d = 0
+    if points[3][1] > points[2][1]:
+        index_b = 2
+        index_c = 3
+    else:
+        index_b = 3
+        index_c = 2
+
+    box = np.array([points[index_a], points[index_b], points[index_c], points[index_d]]).astype(np.int32)
+
+    return box
+
 def draw_ocr_box_txt(
     img,
     boxes,
@@ -43,14 +65,26 @@ def draw_ocr_box_txt(
     if txts is None or len(txts) != len(boxes):
         txts = [None] * len(boxes)
     for idx, (box, txt) in enumerate(zip(boxes, txts)):
-        if scores is not None and scores[idx] < drop_score:
+        try:
+            if scores is not None and scores[idx] < drop_score:
+                continue
+            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+            box = np.array(box)
+            if len(box) > 4:
+                pts = [(x, y) for x, y in box.tolist()]
+                draw_left.polygon(pts, outline=color, width=8)
+                box = get_minarea_rect(box)
+                height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                box[:2, 1] = np.mean(box[:, 1])
+                box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
+            draw_left.polygon(box, fill=color)
+            img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+            pts = np.array(box, np.int32).reshape((-1, 1, 2))
+            cv2.polylines(img_right_text, [pts], True, color, 1)
+            img_right = cv2.bitwise_and(img_right, img_right_text)
+        except Exception:  # skip boxes that fail to render
             continue
-        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
-        draw_left.polygon(box, fill=color)
-        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
-        pts = np.array(box, np.int32).reshape((-1, 1, 2))
-        cv2.polylines(img_right_text, [pts], True, color, 1)
-        img_right = cv2.bitwise_and(img_right, img_right_text)
+
     img_left = Image.blend(image, img_left, 0.5)
     img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
     img_show.paste(img_left, (0, 0, w, h))
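
For reference, `get_minarea_rect` above orders the corners of a polygon's min-area bounding rectangle as top-left, top-right, bottom-right, bottom-left (image coordinates, y pointing down). A quick check on a synthetic curved box, assuming the function above is importable; the point values are illustrative only:

```python
import numpy as np

# A rough half-ellipse of points, like a seal text region outline.
poly = np.array(
    [[10, 40], [30, 20], [60, 12], [90, 20], [110, 40],
     [90, 60], [60, 68], [30, 60]],
    dtype=np.float32,
)
box = get_minarea_rect(poly)
# Four corners ordered: top-left, top-right, bottom-right, bottom-left.
print(box)
```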

+ 169 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv4_mobile_seal_det.yaml

@@ -0,0 +1,169 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: output
+  save_epoch_step: 1
+  eval_batch_step:
+  - 0
+  - 100
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_mobile_det_curve_trained.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: PPLCNetV3
+    scale: 0.75
+    det: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 1e-6
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.2
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 0.5
+  box_type: "poly"
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/train.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.8
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: 500
+    - MakeShrinkMap:
+        shrink_ratio: 0.8
+        min_text_size: 8
+        total_epoch: 500
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 3
+
+Eval:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+        resize_long: 736
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 0
+profiler_options: null
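
The `unclip_ratio: 0.5` configured above is applied the way PaddleOCR's `DBPostProcess` expands each shrunk text region: by an offset proportional to area over perimeter. A sketch of that step, following the upstream implementation rather than code from this diff — with the 0.5 ratio, seal regions are expanded far less than the 1.5 used for regular text:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

def unclip(box, unclip_ratio=0.5):
    """Expand a detected region outward, as in DBPostProcess.unclip."""
    poly = Polygon(box)
    distance = poly.area * unclip_ratio / poly.length  # offset ~ area / perimeter
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box.astype(np.int64).tolist(),
                   pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    # Execute may return several paths; the sketch keeps the first.
    return np.array(offset.Execute(distance)[0])
```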

+ 169 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv4_server_seal_det.yaml

@@ -0,0 +1,169 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: output
+  save_epoch_step: 1
+  eval_batch_step:
+  - 0
+  - 100
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_det_curve_trained.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: PPHGNet_small
+    det: True
+  Neck:
+    name: LKPAN
+    out_channels: 256
+    intracl: true
+  Head:
+    name: PFHeadLocal
+    k: 50
+    mode: "large"
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 1e-6
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.2
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 0.5
+  box_type: "poly"
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/train.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.8
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: 500
+    - MakeShrinkMap:
+        shrink_ratio: 0.8
+        min_text_size: 8
+        total_epoch: 500
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 4
+    num_workers: 3
+
+Eval:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+        resize_long: 736
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 0
+profiler_options: null
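
For clarity, the `NormalizeImage` step configured above (scale `1./255.`, ImageNet mean/std, `hwc` order) followed by `ToCHWImage` amounts to the following numpy operation — a sketch mirroring the yaml values, not the actual transform classes:

```python
import numpy as np

def normalize_chw(img_u8):
    """Normalize an HWC uint8 image and convert it to CHW float32."""
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = img_u8.astype(np.float32) * (1.0 / 255.0)  # scale: 1./255.
    x = (x - mean) / std                           # per-channel, hwc order
    return x.transpose(2, 0, 1)                    # ToCHWImage
```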

+ 20 - 0
paddlex/repo_apis/PaddleOCR_api/text_det/register.py

@@ -40,7 +40,7 @@ register_model_info(
     {
         "model_name": "PP-OCRv4_mobile_det",
         "model_name": "PP-OCRv4_mobile_det",
         "suite": "TextDet",
         "suite": "TextDet",
-        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv4_mobile_det.yaml"),
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv4_mobile_seal_det.yaml"),
         "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_apis": ["train", "evaluate", "predict", "export"],
         "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_mobile_det.yaml",
         "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_mobile_det.yaml",
     }
@@ -55,3 +55,23 @@ register_model_info(
         "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_server_det.yaml",
         "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_server_det.yaml",
     }
 )
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv4_server_seal_det",
+        "suite": "TextDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv4_server_seal_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_server_seal_det.yaml",
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv4_mobile_seal_det",
+        "suite": "TextDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv4_mobile_seal_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_mobile_seal_det.yaml",
+    }
+)
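
Registration like the above presumably lets the suite resolve a model name to its config and HPI files later. A hypothetical sketch of that mechanism — `MODEL_REGISTRY` and `get_model_info` are illustrative names, not the actual PaddleX internals:

```python
# Illustrative only: a name-to-info registry in the style of the calls above.
MODEL_REGISTRY = {}

def register_model_info(info):
    """Index a model's metadata by its name."""
    MODEL_REGISTRY[info["model_name"]] = info

def get_model_info(model_name):
    """Look up a registered model, e.g. 'PP-OCRv4_mobile_seal_det'."""
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}") from None
```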

+ 35 - 0
paddlex/utils/hpi_configs/PP-OCRv4_mobile_seal_det.yaml

@@ -0,0 +1,35 @@
+Hpi:
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      enable_log_info: false
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+    tensorrt:
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: paddle_tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt
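
The empty lists under `dynamic_shapes` look like placeholders for the min/opt/max shapes the TensorRT-style backends need for the input tensor `x`. A filled-in example might look like this (shape values are illustrative, not taken from the source):

```python
# Hypothetical dynamic-shape specification for input "x": one shape each for
# the minimum, optimal, and maximum case (N, C, H, W).
dynamic_shapes = {
    "x": [
        [1, 3, 160, 160],    # minimum input shape
        [1, 3, 736, 736],    # optimal input shape
        [1, 3, 1600, 1600],  # maximum input shape
    ]
}
```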

+ 35 - 0
paddlex/utils/hpi_configs/PP-OCRv4_server_seal_det.yaml

@@ -0,0 +1,35 @@
+Hpi:
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      enable_log_info: false
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+    tensorrt:
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: paddle_tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt