liuhongen1234567 1 год назад
Родитель
Сommit
6993cb08dd

Разница между файлами не показана из-за своего большого размера
+ 84 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/formula_recognition.md


Разница между файлами не показана из-за своего большого размера
+ 83 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/formula_recognition_en.md


+ 1 - 0
paddlex/inference/pipelines/__init__.py

@@ -32,6 +32,7 @@ from .single_model_pipeline import (
     AnomalyDetection,
 )
 from .ocr import OCRPipeline
+from .formula_recognition import FormulaRecognitionPipeline
 from .table_recognition import TableRecPipeline
 from .seal_recognition import SealOCRPipeline
 from .ppchatocrv3 import PPChatOCRPipeline

+ 88 - 0
paddlex/inference/pipelines/formula_recognition.py

@@ -0,0 +1,88 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from ..components import CropByBoxes
+from ..results import FormulaRecResult
+from .base import BasePipeline
+from ...utils import logging
+
+
+class FormulaRecognitionPipeline(BasePipeline):
+    """Formula Recognition Pipeline"""
+
+    entities = "formula_recognition"
+
+    def __init__(
+        self,
+        layout_model,
+        formula_rec_model,
+        layout_batch_size=1,
+        formula_rec_batch_size=1,
+        predictor_kwargs=None,
+    ):
+        super().__init__(predictor_kwargs=predictor_kwargs)
+        self._build_predictor(layout_model, formula_rec_model)
+        self.set_predictor(layout_batch_size, formula_rec_batch_size)
+
+    def _build_predictor(self, layout_model, formula_rec_model):
+        self.layout_predictor = self._create_model(layout_model)
+        self.formula_predictor = self._create_model(formula_rec_model)
+        self._crop_by_boxes = CropByBoxes()
+
+    def set_predictor(self, layout_batch_size=None, formula_rec_batch_size=None):
+        if layout_batch_size:
+            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
+        if formula_rec_batch_size:
+            self.formula_predictor.set_predictor(batch_size=formula_rec_batch_size)
+
+    def predict(self, x, **kwargs):
+        device = kwargs.get("device", None)
+        for layout_pred in self.layout_predictor(x):
+            single_img_res = {
+                "input_path": "",
+                "layout_result": {},
+                "ocr_result": {},
+                "table_result": [],
+            }
+            # update layout result
+            single_img_res["input_path"] = layout_pred["input_path"]
+            single_img_res["layout_result"] = layout_pred
+            single_img_res["dt_polys"] = []
+            single_img_res["rec_formula"] = []
+            all_subs_of_formula_img = []
+            layout_pred["boxes"] = sorted(layout_pred["boxes"], key=lambda x : self.sorted_formula_box(x))
+            if len(layout_pred["boxes"]) > 0:
+                subs_of_img = list(self._crop_by_boxes(layout_pred))
+                # get cropped images with label "formula"
+                for sub in subs_of_img:
+                    if sub["label"].lower() == "formula":
+                        boxes = sub["box"]
+                        x1, y1, x2, y2 = list(boxes)
+                        poly = np.array([[x1, y1],[ x2, y1], [x2, y2], [x1, y2]])
+                        all_subs_of_formula_img.append(sub["img"])
+                        single_img_res["dt_polys"].append(poly)
+                if len(all_subs_of_formula_img)>0:
+                    for formula_res in self.formula_predictor(
+                        all_subs_of_formula_img,
+                        batch_size=kwargs.get("formula_rec_batch_size", 1),
+                        device=device,
+                    ):
+                        single_img_res["rec_formula"].append(str(formula_res["rec_text"]))
+            yield FormulaRecResult(single_img_res)
+
+    def sorted_formula_box(self, x):
+        coordinate = x["coordinate"]
+        x1, y1, x2, y2 = list(coordinate)
+        return (y1+y2)/2

+ 1 - 0
paddlex/inference/results/__init__.py

@@ -21,6 +21,7 @@ from .seal_rec import SealOCRResult
 from .ocr import OCRResult
 from .det import DetResult
 from .seg import SegResult
+from .formula_rec import FormulaRecResult
 from .instance_seg import InstanceSegResult
 from .ts import TSFcResult, TSAdResult, TSClsResult
 from .warp import DocTrResult

+ 116 - 0
paddlex/inference/results/formula_rec.py

@@ -0,0 +1,116 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ...utils.fonts import PINGFANG_FONT_FILE_PATH
+from .base import CVResult
+
+
+class FormulaRecResult(CVResult):
+    _HARD_FLAG = False
+
+    def _to_str(self):
+        rec_formula_str = ", ".join([str(formula) for formula in self['rec_formula']])
+        return str(self).replace("\\\\","\\")
+
+
+    def get_minarea_rect(self, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = np.array(
+            [points[index_a], points[index_b], points[index_c], points[index_d]]
+        ).astype(np.int32)
+
+        return box
+    
+   
+    def _to_img(
+        self,
+    ):
+        """draw ocr result"""
+        # TODO(gaotingquan): mv to postprocess
+        drop_score = 0.5
+
+        boxes = self["dt_polys"]
+        formula = self["rec_formula"]
+        image = self._img_reader.read(self["input_path"])
+        if self._HARD_FLAG:
+            image_np = np.array(image)
+            image = Image.fromarray(image_np[:, :, ::-1])
+        h, w = image.height, image.width
+        img_left = image.copy()
+        random.seed(0)
+        draw_left = ImageDraw.Draw(img_left)
+        if formula is None or len(formula) != len(boxes):
+            formula = [None] * len(boxes)
+        for idx, (box, txt) in enumerate(zip(boxes, formula)):
+            try:
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = np.array(box)
+
+                if len(box) > 4:
+                    pts = [(x, y) for x, y in box.tolist()]
+                    draw_left.polygon(pts, outline=color, width=8)
+                    box = self.get_minarea_rect(box)
+                    height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                    box[:2, 1] = np.mean(box[:, 1])
+                    box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
+                draw_left.polygon(box, fill=color)
+            except:
+                continue
+
+        img_left = Image.blend(image, img_left, 0.5)
+        img_show = Image.new("RGB", (w, h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        return img_show
+
+
+def create_font(txt, sz, font_path):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font

+ 8 - 0
paddlex/pipelines/formula_recognition.yaml

@@ -0,0 +1,8 @@
+Global:
+  pipeline_name: formula_recognition
+  input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png
+  
+Pipeline:
+  layout_model: RT-DETR-H_layout_17cls
+  formula_rec_model: LaTeX_OCR_rec
+  formula_rec_batch_size: 5

Некоторые файлы не были показаны из-за большого количества измененных файлов