
support layout parsing pipeline

zhouchangda 1 year ago
parent
commit
2cc62ce417

The diff for this file is too large to display
+ 246 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/layout_parsing.md


+ 1 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/layout_parsing_en.md

@@ -0,0 +1 @@
+简体中文 | [English](layout_parsing_en.md)

+ 1 - 0
paddlex/inference/pipelines/__init__.py

@@ -36,6 +36,7 @@ from .formula_recognition import FormulaRecognitionPipeline
 from .table_recognition import TableRecPipeline
 from .seal_recognition import SealOCRPipeline
 from .ppchatocrv3 import PPChatOCRPipeline
+from .layout_parsing import LayoutParsingPipeline
 
 
 def load_pipeline_config(pipeline: str) -> Dict[str, Any]:

+ 15 - 0
paddlex/inference/pipelines/layout_parsing/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .layout_parsing import LayoutParsingPipeline

+ 361 - 0
paddlex/inference/pipelines/layout_parsing/layout_parsing.py

@@ -0,0 +1,361 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from ...results import *
+from ...components import *
+from ..ocr import OCRPipeline
+from ....utils import logging
+from ..ppchatocrv3.utils import *
+from ..table_recognition import _TableRecPipeline
+from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table
+
+
+class LayoutParsingPipeline(_TableRecPipeline):
+    """Layout Analysis Pileline"""
+
+    entities = "layout_parsing"
+
+    def __init__(
+        self,
+        layout_model,
+        text_det_model,
+        text_rec_model,
+        table_model,
+        formula_rec_model,
+        doc_image_ori_cls_model=None,
+        doc_image_unwarp_model=None,
+        seal_text_det_model=None,
+        layout_batch_size=1,
+        text_det_batch_size=1,
+        text_rec_batch_size=1,
+        table_batch_size=1,
+        doc_image_ori_cls_batch_size=1,
+        doc_image_unwarp_batch_size=1,
+        seal_text_det_batch_size=1,
+        formula_rec_batch_size=1,
+        recovery=True,
+        device=None,
+        predictor_kwargs=None,
+    ):
+        super().__init__(
+            device,
+            predictor_kwargs,
+        )
+        self._build_predictor(
+            layout_model=layout_model,
+            text_det_model=text_det_model,
+            text_rec_model=text_rec_model,
+            table_model=table_model,
+            doc_image_ori_cls_model=doc_image_ori_cls_model,
+            doc_image_unwarp_model=doc_image_unwarp_model,
+            seal_text_det_model=seal_text_det_model,
+            formula_rec_model=formula_rec_model,
+        )
+        self.set_predictor(
+            layout_batch_size=layout_batch_size,
+            text_det_batch_size=text_det_batch_size,
+            text_rec_batch_size=text_rec_batch_size,
+            table_batch_size=table_batch_size,
+            doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
+            doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
+            seal_text_det_batch_size=seal_text_det_batch_size,
+            formula_rec_batch_size=formula_rec_batch_size,
+        )
+        self.recovery = recovery
+
+    def _build_predictor(
+        self,
+        layout_model,
+        text_det_model,
+        text_rec_model,
+        table_model,
+        formula_rec_model,
+        seal_text_det_model=None,
+        doc_image_ori_cls_model=None,
+        doc_image_unwarp_model=None,
+    ):
+        super()._build_predictor(
+            layout_model, text_det_model, text_rec_model, table_model
+        )
+
+        self.formula_predictor = self._create(formula_rec_model)
+
+        if seal_text_det_model:
+            self.curve_pipeline = self._create(
+                pipeline=OCRPipeline,
+                text_det_model=seal_text_det_model,
+                text_rec_model=text_rec_model,
+            )
+        else:
+            self.curve_pipeline = None
+        if doc_image_ori_cls_model:
+            self.oricls_predictor = self._create(doc_image_ori_cls_model)
+        else:
+            self.oricls_predictor = None
+        if doc_image_unwarp_model:
+            self.uvdoc_predictor = self._create(doc_image_unwarp_model)
+        else:
+            self.uvdoc_predictor = None
+
+        self.img_reader = ReadImage(format="RGB")
+        self.cropper = CropByBoxes()
+
+    def set_predictor(
+        self,
+        layout_batch_size=None,
+        text_det_batch_size=None,
+        text_rec_batch_size=None,
+        table_batch_size=None,
+        doc_image_ori_cls_batch_size=None,
+        doc_image_unwarp_batch_size=None,
+        seal_text_det_batch_size=None,
+        formula_rec_batch_size=None,
+        device=None,
+    ):
+        if text_det_batch_size and text_det_batch_size > 1:
+            logging.warning(
+                f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
+            )
+        if layout_batch_size:
+            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
+        if text_rec_batch_size:
+            self.ocr_pipeline.text_rec_model.set_predictor(
+                batch_size=text_rec_batch_size
+            )
+        if table_batch_size:
+            self.table_predictor.set_predictor(batch_size=table_batch_size)
+        if formula_rec_batch_size:
+            self.formula_predictor.set_predictor(batch_size=formula_rec_batch_size)
+        if self.curve_pipeline and seal_text_det_batch_size:
+            self.curve_pipeline.text_det_model.set_predictor(
+                batch_size=seal_text_det_batch_size
+            )
+        if self.oricls_predictor and doc_image_ori_cls_batch_size:
+            self.oricls_predictor.set_predictor(batch_size=doc_image_ori_cls_batch_size)
+        if self.uvdoc_predictor and doc_image_unwarp_batch_size:
+            self.uvdoc_predictor.set_predictor(batch_size=doc_image_unwarp_batch_size)
+
+        if device:
+            if self.curve_pipeline:
+                self.curve_pipeline.set_predictor(device=device)
+            if self.oricls_predictor:
+                self.oricls_predictor.set_predictor(device=device)
+            if self.uvdoc_predictor:
+                self.uvdoc_predictor.set_predictor(device=device)
+            self.layout_predictor.set_predictor(device=device)
+            self.ocr_pipeline.set_predictor(device=device)
+
+    def predict(
+        self,
+        inputs,
+        use_doc_image_ori_cls_model=True,
+        use_doc_image_unwarp_model=True,
+        use_seal_text_det_model=True,
+        recovery=True,
+        **kwargs,
+    ):
+        self.set_predictor(**kwargs)
+        # get oricls and uvdoc results
+        img_info_list = list(self.img_reader(inputs))[0]
+        oricls_results = []
+        if self.oricls_predictor and use_doc_image_ori_cls_model:
+            oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
+        unwarp_result = []
+        if self.uvdoc_predictor and use_doc_image_unwarp_model:
+            unwarp_result = get_unwarp_results(img_info_list, self.uvdoc_predictor)
+        img_list = [img_info["img"] for img_info in img_info_list]
+        for idx, (img_info, layout_pred) in enumerate(
+            zip(img_info_list, self.layout_predictor(img_list))
+        ):
+            single_img_res = {
+                "input_path": "",
+                "layout_result": DetResult({}),
+                "ocr_result": OCRResult({}),
+                "table_ocr_result": [],
+                "table_result": StructureTableResult([]),
+                "layout_parsing_result": [],
+                "oricls_result": TopkResult({}),
+                "formula_result": TextRecResult({}),
+                "unwarp_result": DocTrResult({}),
+                "curve_result": [],
+            }
+            # update oricls and uvdoc result
+            if oricls_results:
+                single_img_res["oricls_result"] = oricls_results[idx]
+            if unwarp_result:
+                single_img_res["unwarp_result"] = unwarp_result[idx]
+            # update layout result
+            single_img_res["input_path"] = layout_pred["input_path"]
+            single_img_res["layout_result"] = layout_pred
+            single_img = img_info["img"]
+            table_subs = []
+            curve_subs = []
+            formula_subs = []
+            structure_res = []
+            ocr_res_with_layout = []
+            if len(layout_pred["boxes"]) > 0:
+                subs_of_img = list(self._crop_by_boxes(layout_pred))
+                # get cropped images
+                for sub in subs_of_img:
+                    box = sub["box"]
+                    xmin, ymin, xmax, ymax = [int(i) for i in box]
+                    mask_flag = True
+                    if sub["label"].lower() == "table":
+                        table_subs.append(sub)
+                    elif sub["label"].lower() == "seal":
+                        curve_subs.append(sub)
+                    elif sub["label"].lower() == "formula":
+                        formula_subs.append(sub)
+                    else:
+                        if self.recovery and recovery:
+                            # TODO: Why use the entire image?
+                            wht_im = (
+                                np.ones(single_img.shape, dtype=single_img.dtype) * 255
+                            )
+                            wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
+                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
+                        else:
+                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
+                            sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
+                                xmin, ymin, sub_ocr_res["dt_polys"]
+                            )
+                        layout_label = sub["label"].lower()
+                        # Adapt the user label definition to specify behavior.
+                        if sub_ocr_res and sub["label"].lower() in [
+                            "image",
+                            "figure",
+                            "img",
+                            "fig",
+                        ]:
+                            get_text_in_image = kwargs.get("get_text_in_image", False)
+                            mask_flag = not get_text_in_image
+                            text_in_image = ""
+                            if get_text_in_image:
+                                text_in_image = "".join(sub_ocr_res["rec_text"])
+                                ocr_res_with_layout.append(sub_ocr_res)
+                            structure_res.append(
+                                {
+                                    "input_path": sub_ocr_res["input_path"],
+                                    "layout_bbox": box,
+                                    f"{layout_label}": {
+                                        "img": sub["img"],
+                                        f"{layout_label}_text": text_in_image,
+                                    },
+                                }
+                            )
+                        else:
+                            ocr_res_with_layout.append(sub_ocr_res)
+                            structure_res.append(
+                                {
+                                    "input_path": sub_ocr_res["input_path"],
+                                    "layout_bbox": box,
+                                    f"{layout_label}": "\n".join(
+                                        sub_ocr_res["rec_text"]
+                                    ),
+                                }
+                            )
+                    if mask_flag:
+                        single_img[ymin:ymax, xmin:xmax, :] = 255
+
+            curve_pipeline = self.ocr_pipeline
+            if self.curve_pipeline and use_seal_text_det_model:
+                curve_pipeline = self.curve_pipeline
+
+            all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
+            single_img_res["curve_result"] = all_curve_res
+            if isinstance(all_curve_res, dict):
+                all_curve_res = [all_curve_res]
+            for sub, curve_res in zip(curve_subs, all_curve_res):
+                structure_res.append(
+                    {
+                        "input_path": curve_res["input_path"],
+                        "layout_bbox": sub["box"],
+                        "seal": "".join(curve_res["rec_text"]),
+                    }
+                )
+
+            all_formula_res = get_formula_res(self.formula_predictor, formula_subs)
+            single_img_res["formula_result"] = all_formula_res
+            for sub, formula_res in zip(formula_subs, all_formula_res):
+                structure_res.append(
+                    {
+                        "input_path": formula_res["input_path"],
+                        "layout_bbox": sub["box"],
+                        "formula": "".join(formula_res["rec_text"]),
+                    }
+                )
+
+            use_ocr_without_layout = kwargs.get("use_ocr_without_layout", True)
+            ocr_res = {
+                "dt_polys": [],
+                "rec_text": [],
+                "rec_score": [],  # extended below when merging per-layout OCR results
+                "input_path": layout_pred["input_path"],
+            }
+
+            if use_ocr_without_layout:
+                ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
+                ocr_res["input_path"] = layout_pred["input_path"]
+                for idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
+                    structure_res.append(
+                        {
+                            "input_path": ocr_res["input_path"],
+                            "layout_bbox": convert_4point2rect(single_dt_poly),
+                            "text_without_layout": ocr_res["rec_text"][idx],
+                        }
+                    )
+            # update ocr result
+            for layout_ocr_res in ocr_res_with_layout:
+                ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
+                ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
+                ocr_res["rec_score"].extend(layout_ocr_res["rec_score"])
+                ocr_res["input_path"] = single_img_res["input_path"]
+
+            all_table_ocr_res = []
+            all_table_res, _ = self.get_table_result(table_subs)
+            # get table text from html
+            structure_res_table, all_table_ocr_res = get_table_text_from_html(
+                all_table_res
+            )
+            structure_res.extend(structure_res_table)
+
+            # sort the layout result by the left top point of the box
+            structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
+            structure_res = LayoutParsingResult(
+                {
+                    "input_path": layout_pred["input_path"],
+                    "parsing_result": structure_res,
+                }
+            )
+
+            single_img_res["table_result"] = all_table_res
+            single_img_res["ocr_result"] = ocr_res
+            single_img_res["table_ocr_result"] = all_table_ocr_res
+            single_img_res["layout_parsing_result"] = structure_res
+
+            yield VisualResult(single_img_res)
+
+
+def get_formula_res(predictor, input):
+    """get formula res"""
+    res_list = []
+    if isinstance(input, list):
+        img = [im["img"] for im in input]
+    elif isinstance(input, dict):
+        img = input["img"]
+    else:
+        img = input
+    for res in predictor(img):
+        res_list.append(res)
+    return res_list
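
For orientation, here is a minimal usage sketch of the new pipeline. The constructor signature and the `VisualResult` keys follow the code above, and the model names come from the `layout_parsing.yaml` added at the end of this commit; direct instantiation (rather than the config-driven route) is an assumption made for illustration.

```python
# Minimal sketch, assuming LayoutParsingPipeline accepts model names
# directly (model names taken from paddlex/pipelines/layout_parsing.yaml).
from paddlex.inference.pipelines import LayoutParsingPipeline

pipeline = LayoutParsingPipeline(
    layout_model="RT-DETR-H_layout_17cls",
    text_det_model="PP-OCRv4_server_det",
    text_rec_model="PP-OCRv4_server_rec",
    table_model="SLANet_plus",
    formula_rec_model="LaTeX_OCR_rec",
    seal_text_det_model="PP-OCRv4_server_seal_det",
)

# predict() is a generator yielding one VisualResult per input image.
for res in pipeline.predict("demo_paper.png"):
    # "layout_parsing_result" holds the reading-order-sorted blocks
    # (text, table, seal, formula, ...) assembled in predict() above.
    print(res["layout_parsing_result"])
```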

+ 20 - 11
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -17,13 +17,12 @@ import re
 import json
 import numpy as np
 from .utils import *
+from ...results import *
 from copy import deepcopy
 from ...components import *
 from ..ocr import OCRPipeline
 from ....utils import logging
-from ...results import *
 from ...components.llm import ErnieBot
-from ...utils.io import ImageReader, PDFReader
 from ..table_recognition import _TableRecPipeline
 from ...components.llm import create_llm_api, ErnieBot
 from ....utils.file_interface import read_yaml_file
@@ -362,7 +361,12 @@ class PPChatOCRPipeline(_TableRecPipeline):
 
             # sort the layout result by the left top point of the box
             structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
-            structure_res = [LayoutStructureResult(item) for item in structure_res]
+            structure_res = LayoutParsingResult(
+                {
+                    "input_path": layout_pred["input_path"],
+                    "layout_parsing_result": structure_res,
+                }
+            )
 
             single_img_res["table_result"] = all_table_res
             single_img_res["ocr_result"] = ocr_res
@@ -376,7 +380,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
         table_text_list = []
         table_html = []
         for single_img_pred in visual_result:
-            layout_res = single_img_pred["structure_result"]
+            layout_res = single_img_pred["structure_result"]["layout_parsing_result"]
             layout_res_copy = deepcopy(layout_res)
             # layout_res is [{"layout_bbox": [x1, y1, x2, y2], "layout": "single", "words in text block": "xxx"}, {"layout_bbox": [x1, y1, x2, y2], "layout": "double", "印章": "xxx"}]
             ocr_res = {}
@@ -505,6 +509,11 @@ class PPChatOCRPipeline(_TableRecPipeline):
 
         prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
 
+        if llm_name:
+            llm_api = create_llm_api(llm_name, llm_params)
+        else:
+            llm_api = self.llm_api
+
         final_results = {}
         failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
         if html_list:
@@ -514,7 +523,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
             prompt_res["html_prompt"] = prompt_list
             for prompt, table_text in zip(prompt_list, table_text_list):
                 logging.debug(prompt)
-                res = self.get_llm_result(prompt)
+                res = self.get_llm_result(llm_api, prompt)
                 # TODO: why use one html but the whole table_text in next step
                 if list(res.values())[0] in failed_results:
                     logging.debug(
@@ -525,7 +534,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
                     )
                     logging.debug(prompt)
                     prompt_res["table_prompt"].append(prompt)
-                    res = self.get_llm_result(prompt)
+                    res = self.get_llm_result(llm_api, prompt)
                 for key, value in res.items():
                     if value not in failed_results and key in key_list:
                         key_list.remove(key)
@@ -558,22 +567,22 @@ class PPChatOCRPipeline(_TableRecPipeline):
             )
             logging.debug(prompt)
             prompt_res["ocr_prompt"] = prompt
-            res = self.get_llm_result(prompt)
+            res = self.get_llm_result(llm_api, prompt)
             if res:
                 final_results.update(res)
         if not res and not final_results:
-            final_results = self.llm_api.ERROR_MASSAGE
+            final_results = llm_api.ERROR_MASSAGE
         if save_prompt:
             return ChatResult({"chat_res": final_results, "prompt": prompt_res})
         else:
             return ChatResult({"chat_res": final_results, "prompt": ""})
 
-    def get_llm_result(self, prompt):
+    def get_llm_result(self, llm_api, prompt):
         """get llm result and decode to dict"""
-        llm_result = self.llm_api.pred(prompt)
+        llm_result = llm_api.pred(prompt)
         # when the llm pred failed, return None
         if not llm_result:
-            return None
+            return {}
 
         if "json" in llm_result or "```" in llm_result:
             llm_result = (
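
The substantive change in this file is that the LLM is now injected per call instead of always being read from `self.llm_api`, and `get_llm_result` returns `{}` rather than `None` on failure so callers can iterate the result unconditionally. A standalone toy sketch of the pattern (the names below are illustrative stand-ins, not PaddleX classes):

```python
# Toy illustration of per-call LLM injection; ChatPipeline and the
# callable LLM below are hypothetical, not the PaddleX API.
class ChatPipeline:
    def __init__(self, default_llm):
        self.llm_api = default_llm

    def get_llm_result(self, llm_api, prompt):
        llm_result = llm_api(prompt)
        # Returning {} keeps `for key, value in res.items()` safe
        # even when the LLM call fails and returns a falsy value.
        return llm_result or {}

    def chat(self, prompt, llm_override=None):
        llm_api = llm_override if llm_override else self.llm_api
        return self.get_llm_result(llm_api, prompt)


default_llm = lambda p: {"answer": p.upper()}
print(ChatPipeline(default_llm).chat("extract the invoice number"))
```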

+ 2 - 20
paddlex/inference/pipelines/ppchatocrv3/utils.py

@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import re
-import numpy as np
-from pathlib import Path
 from scipy.ndimage import rotate
 
 
@@ -133,33 +130,21 @@ def sorted_layout_boxes(res, w):
 
     new_res = []
     res_left = []
-    res_mid = []
     res_right = []
     i = 0
 
     while True:
         if i >= num_boxes:
             break
-        # Check if there are three columns of pictures
-        if (
-            _boxes[i]["layout_bbox"][0] > w / 4
-            and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 4
-        ):
-            _boxes[i]["layout"] = "double"
-            res_mid.append(_boxes[i])
-            i += 1
         # Check that the bbox is on the left
         elif (
             _boxes[i]["layout_bbox"][0] < w / 4
-            and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 5
+            and _boxes[i]["layout_bbox"][2] < 3 * w / 5
         ):
             _boxes[i]["layout"] = "double"
             res_left.append(_boxes[i])
             i += 1
-        elif (
-            _boxes[i]["layout_bbox"][0] > 2 * w / 5
-            and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < w
-        ):
+        elif _boxes[i]["layout_bbox"][0] > 2 * w / 5:
             _boxes[i]["layout"] = "double"
             res_right.append(_boxes[i])
             i += 1
@@ -173,13 +158,10 @@ def sorted_layout_boxes(res, w):
             i += 1
 
     res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
-    res_mid = sorted(res_mid, key=lambda x: (x["layout_bbox"][1]))
     res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
 
     if res_left:
         new_res += res_left
-    if res_mid:
-        new_res += res_mid
     if res_right:
         new_res += res_right
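
After this simplification, `sorted_layout_boxes` treats `layout_bbox` as `[x1, y1, x2, y2]` and only distinguishes left and right columns; the three-column (`res_mid`) branch is gone. A standalone sketch of the split (the single-column `else` handling is not shown in the hunk and is elided here):

```python
# Sketch of the two-column split; boxes matching neither test fall through
# to the single-column handling, which this sketch omits.
def split_double_columns(boxes, w):
    res_left, res_right = [], []
    for box in boxes:
        x1, y1, x2, y2 = box["layout_bbox"]
        if x1 < w / 4 and x2 < 3 * w / 5:   # clearly in the left half
            box["layout"] = "double"
            res_left.append(box)
        elif x1 > 2 * w / 5:                # clearly in the right half
            box["layout"] = "double"
            res_right.append(box)
    # Each column reads top to bottom; the left column precedes the right.
    res_left.sort(key=lambda b: b["layout_bbox"][1])
    res_right.sort(key=lambda b: b["layout_bbox"][1])
    return res_left + res_right

boxes = [
    {"layout_bbox": [50, 300, 400, 360]},
    {"layout_bbox": [500, 100, 900, 160]},
    {"layout_bbox": [50, 100, 400, 160]},
]
print(split_double_columns(boxes, w=1000))  # left column first, then right
```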
 

+ 37 - 34
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -73,40 +73,8 @@ class _TableRecPipeline(BasePipeline):
             self.ocr_pipeline.text_rec_model.set_predictor(device=device)
             self.table_predictor.set_predictor(device=device)
 
-    def predict(self, input, **kwargs):
-        self.set_predictor(**kwargs)
-        for layout_pred, ocr_pred in zip(
-            self.layout_predictor(input), self.ocr_pipeline(input)
-        ):
-            single_img_res = {
-                "input_path": "",
-                "layout_result": {},
-                "ocr_result": {},
-                "table_result": [],
-            }
-            # update layout result
-            single_img_res["input_path"] = layout_pred["input_path"]
-            single_img_res["layout_result"] = layout_pred
-            ocr_res = ocr_pred
-            table_subs = []
-            if len(layout_pred["boxes"]) > 0:
-                subs_of_img = list(self._crop_by_boxes(layout_pred))
-                # get cropped images with label "table"
-                for sub in subs_of_img:
-                    box = sub["box"]
-                    if sub["label"].lower() == "table":
-                        table_subs.append(sub)
-                        _, ocr_res = self.get_related_ocr_result(box, ocr_res)
-            table_res, all_table_ocr_res = self.get_table_result(table_subs)
-            for table_ocr_res in all_table_ocr_res:
-                ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
-                ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
-                ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
-
-            single_img_res["table_result"] = table_res
-            single_img_res["ocr_result"] = OCRResult(ocr_res)
-
-            yield TableResult(single_img_res)
+    def predict(self, inputs):
+        raise NotImplementedError("The method `predict` has not been implemented yet.")
 
     def get_related_ocr_result(self, box, ocr_res):
         dt_polys_list = []
@@ -188,3 +156,38 @@ class TableRecPipeline(_TableRecPipeline):
             text_rec_batch_size=text_rec_batch_size,
             table_batch_size=table_batch_size,
         )
+
+    def predict(self, input, **kwargs):
+        self.set_predictor(**kwargs)
+        for layout_pred, ocr_pred in zip(
+            self.layout_predictor(input), self.ocr_pipeline(input)
+        ):
+            single_img_res = {
+                "input_path": "",
+                "layout_result": {},
+                "ocr_result": {},
+                "table_result": [],
+            }
+            # update layout result
+            single_img_res["input_path"] = layout_pred["input_path"]
+            single_img_res["layout_result"] = layout_pred
+            ocr_res = ocr_pred
+            table_subs = []
+            if len(layout_pred["boxes"]) > 0:
+                subs_of_img = list(self._crop_by_boxes(layout_pred))
+                # get cropped images with label "table"
+                for sub in subs_of_img:
+                    box = sub["box"]
+                    if sub["label"].lower() == "table":
+                        table_subs.append(sub)
+                        _, ocr_res = self.get_related_ocr_result(box, ocr_res)
+            table_res, all_table_ocr_res = self.get_table_result(table_subs)
+            for table_ocr_res in all_table_ocr_res:
+                ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
+                ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
+                ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
+
+            single_img_res["table_result"] = table_res
+            single_img_res["ocr_result"] = OCRResult(ocr_res)
+
+            yield TableResult(single_img_res)
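
`_TableRecPipeline.predict` is now abstract, turning the class into a shared base for model wiring while each concrete pipeline (`TableRecPipeline` here, `LayoutParsingPipeline` in this commit) defines its own prediction flow. A minimal sketch of the contract:

```python
# Minimal sketch of the base/subclass contract after this refactor.
class _BasePipeline:
    def predict(self, inputs):
        raise NotImplementedError("The method `predict` has not been implemented yet.")


class ConcretePipeline(_BasePipeline):
    def predict(self, inputs):
        # Each pipeline decides how layout, OCR and table results combine.
        yield {"input_path": inputs}


print(next(ConcretePipeline().predict("page.png")))
```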

+ 1 - 1
paddlex/inference/pipelines/table_recognition/utils.py

@@ -284,7 +284,7 @@ def convert_4point2rect(bbox):
     y1 = min(bbox[:, 1])
     x2 = max(bbox[:, 0])
     y2 = max(bbox[:, 1])
-    return [x1, y1, x2, y2]
+    return np.array([x1, y1, x2, y2], dtype=np.float32)
 
 
 def get_ori_coordinate_for_table(x, y, table_bbox):
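
For reference, `convert_4point2rect` collapses a 4-point polygon of shape `(4, 2)` into an axis-aligned `[x1, y1, x2, y2]` box; returning a float32 `ndarray` instead of a plain list keeps its output type consistent with the other box arrays (that motivation is an inference, not stated in the commit):

```python
import numpy as np

# Equivalent computation to convert_4point2rect, on a sample polygon.
poly = np.array([[10, 5], [40, 6], [39, 30], [11, 29]], dtype=np.float32)
rect = np.array(
    [poly[:, 0].min(), poly[:, 1].min(), poly[:, 0].max(), poly[:, 1].max()],
    dtype=np.float32,
)
print(rect)  # [10.  5. 40. 30.]
```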

+ 12 - 8
paddlex/inference/results/chat_ocr.py

@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 from pathlib import Path
 from .base import BaseResult
 from .utils.mixin import Base64Mixin
 
 
-class LayoutStructureResult(BaseResult):
-    """LayoutStructureResult"""
+class LayoutParsingResult(BaseResult):
+    """LayoutParsingResult"""
 
     pass
 
@@ -32,6 +33,9 @@ class VisualInfoResult(BaseResult):
 class VisualResult(BaseResult):
     """VisualInfoResult"""
 
+    def _to_str(self):
+        return str({"layout_parsing_result": self["layout_parsing_result"]})
+
     def save_to_html(self, save_path):
         if not save_path.lower().endswith(("html")):
             input_path = self["input_path"]
@@ -67,28 +71,28 @@ class VisualResult(BaseResult):
         if unwarp_result:
             # unwarp_result._HARD_FLAG = True
             unwarp_result.save_to_img(uvdoc_save_path)
-        curve_save_path = f"{save_path}_curve.jpg"
+        curve_save_path = f"{save_path}_curve"
         curve_results = self["curve_result"]
         # TODO: support a list of results
         if isinstance(curve_results, dict):
             curve_results = [curve_results]
-        for curve_result in curve_results:
+        for idx, curve_result in enumerate(curve_results):
             curve_result._HARD_FLAG = True if not unwarp_result else False
-            curve_result.save_to_img(curve_save_path)
+            curve_result.save_to_img(f"{curve_save_path}_{idx}.jpg")
         layout_save_path = f"{save_path}_layout.jpg"
         layout_result = self["layout_result"]
         if layout_result:
             layout_result._HARD_FLAG = True if not unwarp_result else False
             layout_result.save_to_img(layout_save_path)
         ocr_save_path = f"{save_path}_ocr.jpg"
-        table_save_path = f"{save_path}_table.jpg"
+        table_save_path = f"{save_path}_table"
         ocr_result = self["ocr_result"]
         if ocr_result:
             ocr_result._HARD_FLAG = True if not unwarp_result else False
             ocr_result.save_to_img(ocr_save_path)
-        for table_result in self["table_result"]:
+        for idx, table_result in enumerate(self["table_result"]):
             table_result._HARD_FLAG = True if not unwarp_result else False
-            table_result.save_to_img(table_save_path)
+            table_result.save_to_img(f"{table_save_path}_{idx}.jpg")
 
 
 class VectorResult(BaseResult, Base64Mixin):
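
The `save_to_img` change indexes per-result filenames, so multiple seal or table crops on one page no longer overwrite each other. A sketch of the resulting naming scheme (the `output/page1` prefix is hypothetical):

```python
# Filename scheme after this change, for a page with two seals and two tables.
save_path = "output/page1"
curve_files = [f"{save_path}_curve_{idx}.jpg" for idx in range(2)]
table_files = [f"{save_path}_table_{idx}.jpg" for idx in range(2)]
print(curve_files)  # ['output/page1_curve_0.jpg', 'output/page1_curve_1.jpg']
print(table_files)  # ['output/page1_table_0.jpg', 'output/page1_table_1.jpg']
```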

+ 14 - 0
paddlex/pipelines/layout_parsing.yaml

@@ -0,0 +1,14 @@
+Global:
+  pipeline_name: layout_parsing
+  input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
+  # input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/PP-OCRv3.pdf
+  
+Pipeline:
+  layout_model: RT-DETR-H_layout_17cls
+  table_model: SLANet_plus
+  formula_rec_model: LaTeX_OCR_rec
+  text_det_model: PP-OCRv4_server_det
+  text_rec_model: PP-OCRv4_server_rec
+  seal_text_det_model: PP-OCRv4_server_seal_det
+  doc_image_unwarp_model: None
+  doc_image_ori_cls_model: None
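
With this config registered, the pipeline can presumably also be driven through PaddleX's config route; a hedged sketch follows (the `create_pipeline` entry point and the result API are assumed from the surrounding codebase, not shown in this diff):

```python
# Hedged sketch, assuming PaddleX's create_pipeline resolves
# paddlex/pipelines/layout_parsing.yaml by pipeline name.
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="layout_parsing")
for res in pipeline.predict("demo_paper.png"):
    res.save_to_img("./output/demo_paper")  # writes *_layout.jpg, *_ocr.jpg, ...
```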

Some files were not shown because too many files changed in this diff