
update markdown format & table_recognition_v1 add cell sorting (#2929)

* support inline formula embedding & update reference format

* update markdown format and add cell sorting to table_recognition

* layout_parsing_v2: remove formula_res from save_to_img & update table match

* layout_parsing_v2: add layout detection params
shuai.liu · 10 months ago · commit f0d163cb01

+ 40 - 35
paddlex/configs/pipelines/layout_parsing_v2.yaml

@@ -12,6 +12,11 @@ SubModules:
     module_name: layout_detection
     model_name: PP-DocLayout-L
     model_dir: null
+    threshold: 
+      7: 0.3
+    layout_nms: True
+    layout_unclip_ratio: 1.0
+    layout_merge_bboxes_mode: "large"
 
 SubPipelines:
   DocPreprocessor:
@@ -51,47 +56,47 @@ SubPipelines:
         batch_size: 1
         score_thresh: 0.0
 
-  # TableRecognition:
-  #   pipeline_name: table_recognition_v2
-  #   use_layout_detection: False
-  #   use_doc_preprocessor: False
-  #   use_ocr_model: True
-  #   SubModules:  
-  #     TableClassification:
-  #       module_name: table_classification
-  #       model_name: PP-LCNet_x1_0_table_cls
-  #       model_dir: null
-
-  #     WiredTableStructureRecognition:
-  #       module_name: table_structure_recognition
-  #       model_name: SLANeXt_wired
-  #       model_dir: null
-      
-  #     WirelessTableStructureRecognition:
-  #       module_name: table_structure_recognition
-  #       model_name: SLANeXt_wireless
-  #       model_dir: null
-      
-  #     WiredTableCellsDetection:
-  #       module_name: table_cells_detection
-  #       model_name: RT-DETR-L_wired_table_cell_det
-  #       model_dir: null
-      
-  #     WirelessTableCellsDetection:
-  #       module_name: table_cells_detection
-  #       model_name: RT-DETR-L_wireless_table_cell_det
-  #       model_dir: null
-
   TableRecognition:
-    pipeline_name: table_recognition
+    pipeline_name: table_recognition_v2
     use_layout_detection: False
     use_doc_preprocessor: False
     use_ocr_model: False
-    SubModules:
-      TableStructureRecognition:
+    SubModules:  
+      TableClassification:
+        module_name: table_classification
+        model_name: PP-LCNet_x1_0_table_cls
+        model_dir: null
+
+      WiredTableStructureRecognition:
+        module_name: table_structure_recognition
+        model_name: SLANeXt_wired
+        model_dir: null
+      
+      WirelessTableStructureRecognition:
         module_name: table_structure_recognition
-        model_name: SLANet_plus
+        model_name: SLANeXt_wireless
+        model_dir: null
+      
+      WiredTableCellsDetection:
+        module_name: table_cells_detection
+        model_name: RT-DETR-L_wired_table_cell_det
         model_dir: null
+      
+      WirelessTableCellsDetection:
+        module_name: table_cells_detection
+        model_name: RT-DETR-L_wireless_table_cell_det
+        model_dir: null
+
+  # TableRecognition:
+  #   pipeline_name: table_recognition
+  #   use_layout_detection: False
+  #   use_doc_preprocessor: False
+  #   use_ocr_model: False
+  #   SubModules:
+  #     TableStructureRecognition:
+  #       module_name: table_structure_recognition
+  #       model_name: SLANet_plus
+  #       model_dir: null
 
   SealRecognition:
     pipeline_name: seal_recognition

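The new keys expose the layout detector's post-processing knobs directly in the pipeline config: threshold takes either a global float or a per-category-id map (here category 7 drops to 0.3), layout_nms toggles non-maximum suppression across layout boxes, layout_unclip_ratio scales box expansion, and layout_merge_bboxes_mode picks which box survives when overlapping boxes merge. A minimal sketch of how a per-category threshold map is commonly interpreted; these are assumed semantics for illustration, not the PaddleX implementation:

    # Illustrative only: interpreting a per-category threshold map like
    # `threshold: {7: 0.3}`, assuming each detection carries a category id
    # and a confidence score.
    DEFAULT_THRESHOLD = 0.5  # assumed global fallback

    def filter_by_threshold(detections, threshold):
        """Keep detections whose score passes the per-category (or global) threshold."""
        if isinstance(threshold, (int, float)):
            return [d for d in detections if d["score"] >= threshold]
        # dict form: {category_id: threshold}; other categories use the fallback
        return [
            d for d in detections
            if d["score"] >= threshold.get(d["category_id"], DEFAULT_THRESHOLD)
        ]

    boxes = [
        {"category_id": 7, "score": 0.35, "box": [10, 10, 120, 40]},
        {"category_id": 2, "score": 0.45, "box": [10, 60, 200, 90]},
    ]
    print(filter_by_threshold(boxes, {7: 0.3}))  # keeps the id-7 box, drops id-2
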
+ 17 - 1
paddlex/inference/pipelines_new/layout_parsing/pipeline_v2.py

@@ -102,7 +102,22 @@ class LayoutParsingPipelineV2(BasePipeline):
             "LayoutDetection",
             {"model_config_error": "config error for layout_det_model!"},
         )
-        self.layout_det_model = self.create_model(layout_det_config)
+        layout_kwargs = {}
+        if (threshold := layout_det_config.get("threshold", None)) is not None:
+            layout_kwargs["threshold"] = threshold
+        if (layout_nms := layout_det_config.get("layout_nms", None)) is not None:
+            layout_kwargs["layout_nms"] = layout_nms
+        if (
+            layout_unclip_ratio := layout_det_config.get("layout_unclip_ratio", None)
+        ) is not None:
+            layout_kwargs["layout_unclip_ratio"] = layout_unclip_ratio
+        if (
+            layout_merge_bboxes_mode := layout_det_config.get(
+                "layout_merge_bboxes_mode", None
+            )
+        ) is not None:
+            layout_kwargs["layout_merge_bboxes_mode"] = layout_merge_bboxes_mode
+        self.layout_det_model = self.create_model(layout_det_config, **layout_kwargs)
 
         if self.use_general_ocr or self.use_table_recognition:
             general_ocr_config = config.get("SubPipelines", {}).get(
@@ -400,6 +415,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                         use_ocr_model=False,
                         overall_ocr_res=overall_ocr_res,
                         layout_det_res=layout_det_res,
+                        cell_sort_by_y_projection=True,
                     ),
                 )
                 table_res_list = table_res_all["table_res_list"]

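Each optional key is forwarded to create_model only when it is actually set, so absent keys fall through to the model's own defaults. The same pattern, factored into a helper for illustration (the helper name is hypothetical, not part of PaddleX):

    # Sketch of the "forward only the keys that are set" pattern used above.
    def collect_optional_kwargs(config, keys):
        """Return {key: value} for every key present in config with a non-None value."""
        return {k: v for k in keys if (v := config.get(k)) is not None}

    layout_det_config = {"threshold": {7: 0.3}, "layout_nms": True,
                         "layout_unclip_ratio": None}
    kwargs = collect_optional_kwargs(
        layout_det_config,
        ["threshold", "layout_nms", "layout_unclip_ratio", "layout_merge_bboxes_mode"],
    )
    print(kwargs)  # {'threshold': {7: 0.3}, 'layout_nms': True}
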
+ 21 - 29
paddlex/inference/pipelines_new/layout_parsing/result_v2.py

@@ -94,16 +94,16 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 key = f"seal_res_region{seal_region_id}"
                 res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
 
-        if (
-            model_settings["use_formula_recognition"]
-            and len(self["formula_res_list"]) > 0
-        ):
-            for sno in range(len(self["formula_res_list"])):
-                formula_res = self["formula_res_list"][sno]
-                formula_region_id = formula_res["formula_region_id"]
-                sub_formula_res_dict = formula_res.img
-                key = f"formula_res_region{formula_region_id}"
-                res_img_dict[key] = sub_formula_res_dict["res"]
+        # if (
+        #     model_settings["use_formula_recognition"]
+        #     and len(self["formula_res_list"]) > 0
+        # ):
+        #     for sno in range(len(self["formula_res_list"])):
+        #         formula_res = self["formula_res_list"][sno]
+        #         formula_region_id = formula_res["formula_region_id"]
+        #         sub_formula_res_dict = formula_res.img
+        #         key = f"formula_res_region{formula_region_id}"
+        #         res_img_dict[key] = sub_formula_res_dict["res"]
 
         return res_img_dict
 
@@ -415,31 +415,23 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 "table_title": lambda: format_centered_text("table_title"),
                 "figure_title": lambda: format_centered_text("figure_title"),
                 "chart_title": lambda: format_centered_text("chart_title"),
-                "text": lambda: sub_block["text"].strip("\n"),
+                "text": lambda: sub_block["text"]
+                .replace("-\n", " ")
+                .replace("\n", " "),
                 # 'number': lambda: str(sub_block['number']),
-                "abstract": lambda: "\n" + sub_block["abstract"].strip("\n"),
+                "abstract": lambda: sub_block["abstract"]
+                .replace("-\n", " ")
+                .replace("\n", " "),
                 "content": lambda: sub_block["content"]
-                .replace("-\n", "")
-                .replace("\n", " ")
-                .strip(),
+                .replace("-\n", " ")
+                .replace("\n", " "),
                 "image": format_image,
                 "chart": format_chart,
-                "formula": lambda: f"$${sub_block['formula']}$$".replace(
-                    "-\n",
-                    "",
-                ).replace("\n", " "),
+                "formula": lambda: f"$${sub_block['formula']}$$",
                 "table": format_table,
                 "reference": format_reference,
-                "algorithm": lambda: "\n"
-                + f"**Algorithm**: {sub_block['algorithm']}".replace("-\n", "").replace(
-                    "\n",
-                    " ",
-                ),
-                "seal": lambda: "\n"
-                + f"**Seal**: {sub_block['seal']}".replace("-\n", "").replace(
-                    "\n",
-                    " ",
-                ),
+                "algorithm": lambda: sub_block["algorithm"].strip("\n"),
+                "seal": lambda: sub_block["seal"].strip("\n"),
             }
             parsing_result = obj["layout_parsing_result"]
             markdown_content = ""

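The rewritten handlers change how hard-wrapped source lines are joined: a hyphenated break ("-\n") is now replaced with a space rather than removed, plain newlines become spaces, and formulas are emitted as bare $$...$$ blocks with no newline munging. A worked example of the new normalization (input text is illustrative):

    text = "Layout ana-\nlysis output\nspans lines."
    print(text.replace("-\n", " ").replace("\n", " "))
    # -> "Layout ana lysis output spans lines."

Note the behavioral shift: the old content handler stripped "-\n" entirely, rejoining the hyphenated word, while the new rule keeps a space at the break.
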
+ 75 - 38
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -106,7 +106,7 @@ def get_sub_regions_ocr_res(
     return sub_regions_ocr_res
 
 
-def calculate_iou(box1, box2):
+def _calculate_iou(box1, box2):
     """
     Calculate Intersection over Union (IoU) between two bounding boxes.
 
@@ -142,39 +142,69 @@ def calculate_iou(box1, box2):
     return iou
 
 
-def _whether_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.6):
-    _, y0_1, _, y1_1 = bbox1
-    _, y0_2, _, y1_2 = bbox2
+def _whether_y_overlap_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.6):
+    """
+    Determines whether the vertical overlap between two bounding boxes exceeds a given threshold.
+
+    Args:
+        bbox1 (tuple): The first bounding box defined as (left, top, right, bottom).
+        bbox2 (tuple): The second bounding box defined as (left, top, right, bottom).
+        overlap_ratio_threshold (float): The threshold ratio to determine if the overlap is significant.
+                                         Defaults to 0.6.
 
-    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
-    min_height = min(y1_1 - y0_1, y1_2 - y0_2)
+    Returns:
+        bool: True if the vertical overlap divided by the minimum height of the two bounding boxes
+              exceeds the overlap_ratio_threshold, otherwise False.
+    """
+    _, y1_0, _, y1_1 = bbox1
+    _, y2_0, _, y2_1 = bbox2
+
+    overlap = max(0, min(y1_1, y2_1) - max(y1_0, y2_0))
+    min_height = min(y1_1 - y1_0, y2_1 - y2_0)
 
     return (overlap / min_height) > overlap_ratio_threshold
 
 
-def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
-    assert ocr_res["boxes"] and ocr_res["rec_texts"]
+def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_iou_threshold=0.7):
+    """
+    Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.
+
+    Args:
+        layout_bbox (tuple): A tuple representing the layout bounding box, defined as (left, top, right, bottom).
+        ocr_res (dict): A dictionary containing OCR results with the following keys:
+                        - "boxes": A list of bounding boxes, each defined as [left, top, right, bottom].
+                        - "rec_texts": A corresponding list of recognized text strings for each box.
+        line_height_iou_threshold (float): The threshold for determining whether two boxes belong to
+                                           the same line based on their vertical overlap. Defaults to 0.7.
+
+    Returns:
+        dict: A dictionary with the same structure as `ocr_res`, but with boxes and texts sorted
+              and grouped into lines and blocks.
+    """
+    assert (
+        ocr_res["boxes"] and ocr_res["rec_texts"]
+    ), "OCR results must contain 'boxes' and 'rec_texts'"
 
-    # span->line->block
     boxes = ocr_res["boxes"]
-    rec_text = ocr_res["rec_texts"]
-    x_min, x_max = layout_bbox[0], layout_bbox[2]
+    rec_texts = ocr_res["rec_texts"]
+
+    x_min, _, x_max, _ = layout_bbox
+
+    spans = list(zip(boxes, rec_texts))
 
-    spans = list(zip(boxes, rec_text))
     spans.sort(key=lambda span: span[0][1])
     spans = [list(span) for span in spans]
 
     lines = []
-    first_span = spans[0]
-    current_line = [first_span]
-    current_y0, current_y1 = first_span[0][1], first_span[0][3]
+    current_line = [spans[0]]
+    current_y0, current_y1 = spans[0][0][1], spans[0][0][3]
 
     for span in spans[1:]:
         y0, y1 = span[0][1], span[0][3]
-        if _whether_overlaps_y_exceeds_threshold(
+        if _whether_y_overlap_exceeds_threshold(
             (0, current_y0, 0, current_y1),
             (0, y0, 0, y1),
-            line_height_threshold,
+            line_height_iou_threshold,
         ):
             current_line.append(span)
             current_y0 = min(current_y0, y0)
@@ -191,13 +221,16 @@ def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
         line.sort(key=lambda span: span[0][0])
         first_span = line[0]
         end_span = line[-1]
-        if first_span[0][0] - x_min > 20:
+
+        if first_span[0][0] - x_min > 15:
             first_span[1] = "\n" + first_span[1]
-        if x_max - end_span[0][2] > 20:
+        if x_max - end_span[0][2] > 15:
             end_span[1] = end_span[1] + "\n"
 
+    # Flatten lines back into a single list for boxes and texts
     ocr_res["boxes"] = [span[0] for line in lines for span in line]
     ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
+
     return ocr_res
 
 
@@ -241,7 +274,12 @@ def get_structure_res(
 
         if label == "table":
             for i, table_res in enumerate(table_res_list):
-                if calculate_iou(layout_bbox, table_res["cell_box_list"][0]) > 0.5:
+                if (
+                    _calculate_iou(
+                        layout_bbox, table_res["table_ocr_pred"]["rec_boxes"][0]
+                    )
+                    > 0.5
+                ):
                     structure_boxes.append(
                         {
                             "label": label,
@@ -256,7 +294,7 @@ def get_structure_res(
         else:
             overall_text_boxes = overall_ocr_res["rec_boxes"]
             for box_no in range(len(overall_text_boxes)):
-                if calculate_iou(layout_bbox, overall_text_boxes[box_no]) > 0.5:
+                if _calculate_iou(layout_bbox, overall_text_boxes[box_no]) > 0.5:
                     rec_res["boxes"].append(overall_text_boxes[box_no])
                     rec_res["rec_texts"].append(
                         overall_ocr_res["rec_texts"][box_no],
@@ -306,7 +344,7 @@ def get_structure_res(
     return structure_boxes
 
 
-def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
+def _projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
     """
     Generate a 1D projection histogram from bounding boxes along a specified axis.
 
@@ -328,7 +366,7 @@ def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
     return projection
 
 
-def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
+def _split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
     """
     Split the projection profile into segments based on specified thresholds.
 
@@ -363,7 +401,7 @@ def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap:
     return segment_starts, segment_ends
 
 
-def recursive_yx_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_gap=1):
+def _recursive_yx_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_gap=1):
     """
     Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
 
@@ -380,8 +418,8 @@ def recursive_yx_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_
     y_sorted_indices = np.array(indices)[y_sorted_indices]
 
     # Perform Y-axis projection
-    y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
-    y_intervals = split_projection_profile(y_projection, 0, 1)
+    y_projection = _projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
+    y_intervals = _split_projection_profile(y_projection, 0, 1)
 
     if not y_intervals:
         return
@@ -401,8 +439,8 @@ def recursive_yx_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_
         x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
 
         # Perform X-axis projection
-        x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
-        x_intervals = split_projection_profile(x_projection, 0, min_gap)
+        x_projection = _projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
+        x_intervals = _split_projection_profile(x_projection, 0, min_gap)
 
         if not x_intervals:
             continue
@@ -417,14 +455,14 @@ def recursive_yx_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_
             x_interval_indices = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
                 x_sorted_boxes_chunk[:, 0] < x_end
             )
-            recursive_yx_cut(
+            _recursive_yx_cut(
                 x_sorted_boxes_chunk[x_interval_indices],
                 x_sorted_indices_chunk[x_interval_indices],
                 res,
             )
 
 
-def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_gap=1):
+def _recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_gap=1):
     """
     Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
 
@@ -442,8 +480,8 @@ def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_
     x_sorted_indices = np.array(indices)[x_sorted_indices]
 
     # Perform X-axis projection
-    x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
-    x_intervals = split_projection_profile(x_projection, 0, 1)
+    x_projection = _projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
+    x_intervals = _split_projection_profile(x_projection, 0, 1)
 
     if not x_intervals:
         return
@@ -463,8 +501,8 @@ def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_
         y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
 
         # Perform Y-axis projection
-        y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
-        y_intervals = split_projection_profile(y_projection, 0, min_gap)
+        y_projection = _projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
+        y_intervals = _split_projection_profile(y_projection, 0, min_gap)
 
         if not y_intervals:
             continue
@@ -479,7 +517,7 @@ def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_
             y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
                 y_sorted_boxes_chunk[:, 1] < y_end
             )
-            recursive_xy_cut(
+            _recursive_xy_cut(
                 y_sorted_boxes_chunk[y_interval_indices],
                 y_sorted_indices_chunk[y_interval_indices],
                 res,
@@ -490,7 +528,7 @@ def sort_by_xycut(block_bboxes, direction=0, min_gap=1):
     block_bboxes = np.asarray(block_bboxes).astype(int)
     res = []
     if direction == 1:
-        recursive_yx_cut(
+        _recursive_yx_cut(
             block_bboxes,
             np.arange(
                 len(block_bboxes),
@@ -499,7 +537,7 @@ def sort_by_xycut(block_bboxes, direction=0, min_gap=1):
             min_gap,
         )
     else:
-        recursive_xy_cut(
+        _recursive_xy_cut(
             block_bboxes,
             np.arange(
                 len(block_bboxes),
@@ -1125,7 +1163,6 @@ def get_layout_ordering(data, no_mask_labels=[], already_sorted=False):
                 ),
             )
             block_bboxes = np.array(block_bboxes)
-            print("sort by yxcut...")
             sorted_indices = sort_by_xycut(
                 block_bboxes,
                 direction=1,

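The renamed helpers implement a span, line, block grouping: spans are sorted top-to-bottom, merged into the current line while the vertical overlap ratio (overlap divided by the smaller box height) exceeds the threshold, then sorted left-to-right within each line. A self-contained sketch with toy boxes (box = [left, top, right, bottom]), mirroring the diff's logic:

    def y_overlap_exceeds(b1, b2, thr=0.6):
        overlap = max(0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
        return overlap / min(b1[3] - b1[1], b2[3] - b2[1]) > thr

    boxes = sorted([[120, 10, 200, 30], [10, 12, 100, 32], [10, 40, 90, 60]],
                   key=lambda b: b[1])          # top-to-bottom
    lines, cur = [], [boxes[0]]
    y0, y1 = boxes[0][1], boxes[0][3]
    for b in boxes[1:]:
        if y_overlap_exceeds((0, y0, 0, y1), (0, b[1], 0, b[3])):
            cur.append(b)                        # same visual line: widen the y-range
            y0, y1 = min(y0, b[1]), max(y1, b[3])
        else:
            lines.append(sorted(cur, key=lambda b: b[0]))  # close line, sort left-to-right
            cur, (y0, y1) = [b], (b[1], b[3])
    lines.append(sorted(cur, key=lambda b: b[0]))
    print(lines)
    # [[[10, 12, 100, 32], [120, 10, 200, 30]], [[10, 40, 90, 60]]]
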
+ 15 - 2
paddlex/inference/pipelines_new/table_recognition/pipeline.py

@@ -224,6 +224,7 @@ class TableRecognitionPipeline(BasePipeline):
         overall_ocr_res: OCRResult,
         table_box: list,
         flag_find_nei_text: bool = True,
+        cell_sort_by_y_projection: bool = False,
     ) -> SingleTableRecognitionResult:
         """
         Predict table recognition results from an image array, layout detection results, and OCR results.
@@ -234,12 +235,16 @@ class TableRecognitionPipeline(BasePipeline):
                 The overall OCR results containing text recognition information.
             table_box (list): The table box coordinates.
             flag_find_nei_text (bool): Whether to find neighboring text.
+            cell_sort_by_y_projection (bool): Whether to sort the matched OCR boxes by y-projection.
         Returns:
             SingleTableRecognitionResult: single table recognition result.
         """
         table_structure_pred = next(self.table_structure_model(image_array))
         single_table_recognition_res = get_table_recognition_res(
-            table_box, table_structure_pred, overall_ocr_res
+            table_box,
+            table_structure_pred,
+            overall_ocr_res,
+            cell_sort_by_y_projection=cell_sort_by_y_projection,
         )
         neighbor_text = ""
         if flag_find_nei_text:
@@ -267,6 +272,7 @@ class TableRecognitionPipeline(BasePipeline):
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
+        cell_sort_by_y_projection: Optional[bool] = None,
         **kwargs,
     ) -> TableRecognitionResult:
         """
@@ -281,6 +287,7 @@ class TableRecognitionPipeline(BasePipeline):
                 It will be used if it is not None and use_ocr_model is False.
             layout_det_res (DetResult): The layout detection result.
                 It will be used if it is not None and use_layout_detection is False.
+            cell_sort_by_y_projection (bool): Whether to sort the matched OCR boxes by y-projection.
             **kwargs: Additional keyword arguments.
 
         Returns:
@@ -293,6 +300,8 @@ class TableRecognitionPipeline(BasePipeline):
             use_layout_detection,
             use_ocr_model,
         )
+        if cell_sort_by_y_projection is None:
+            cell_sort_by_y_projection = False
 
         if not self.check_model_settings_valid(
             model_settings, overall_ocr_res, layout_det_res
@@ -339,6 +348,7 @@ class TableRecognitionPipeline(BasePipeline):
                     overall_ocr_res,
                     table_box,
                     flag_find_nei_text=False,
+                    cell_sort_by_y_projection=cell_sort_by_y_projection,
                 )
                 single_table_rec_res["table_region_id"] = table_region_id
                 table_res_list.append(single_table_rec_res)
@@ -354,7 +364,10 @@ class TableRecognitionPipeline(BasePipeline):
                         table_box = crop_img_info["box"]
                         single_table_rec_res = (
                             self.predict_single_table_recognition_res(
-                                crop_img_info["img"], overall_ocr_res, table_box
+                                crop_img_info["img"],
+                                overall_ocr_res,
+                                table_box,
+                                cell_sort_by_y_projection=cell_sort_by_y_projection,
                             )
                         )
                         single_table_rec_res["table_region_id"] = table_region_id

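Callers opt in per predict() call; None is normalized to False, so existing call sites are unaffected. A hypothetical call site (pipeline construction and the input path are assumptions; only the cell_sort_by_y_projection keyword comes from this diff):

    # pipeline = ...  # a TableRecognitionPipeline instance, construction elided
    for res in pipeline.predict(
        "table_page.jpg",                # illustrative input
        cell_sort_by_y_projection=True,  # sort multi-line cell text into reading order
    ):
        print(res)  # inspect the per-table results; exact helpers may differ
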
+ 95 - 3
paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing.py

@@ -118,13 +118,91 @@ def compute_iou(rec1: list, rec2: list) -> float:
         return (intersect / (sum_area - intersect)) * 1.0
 
 
-def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
+def _whether_y_overlap_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.6):
+    """
+    Determines whether the vertical overlap between two bounding boxes exceeds a given threshold.
+
+    Args:
+        bbox1 (tuple): The first bounding box defined as (left, top, right, bottom).
+        bbox2 (tuple): The second bounding box defined as (left, top, right, bottom).
+        overlap_ratio_threshold (float): The threshold ratio to determine if the overlap is significant.
+                                         Defaults to 0.6.
+
+    Returns:
+        bool: True if the vertical overlap divided by the minimum height of the two bounding boxes
+              exceeds the overlap_ratio_threshold, otherwise False.
+    """
+    _, y1_0, _, y1_1 = bbox1
+    _, y2_0, _, y2_1 = bbox2
+
+    overlap = max(0, min(y1_1, y2_1) - max(y1_0, y2_0))
+    min_height = min(y1_1 - y1_0, y2_1 - y2_0)
+
+    return (overlap / min_height) > overlap_ratio_threshold
+
+
+def _sort_box_by_y_projection(boxes, line_height_iou_threshold=0.6):
+    """
+    Sorts a list of bounding boxes based on their spatial arrangement.
+
+    The function first sorts the boxes by their top y-coordinate to group them into lines.
+    Within each line, the boxes are then sorted by their x-coordinate.
+
+    Args:
+        boxes (list): A list of bounding boxes, where each box is defined as [left, top, right, bottom].
+        line_height_iou_threshold (float): The Intersection over Union (IoU) threshold for grouping boxes into the same line.
+
+    Returns:
+        list: A list of indices representing the order of the boxes after sorting by their spatial arrangement.
+    """
+
+    if not boxes:
+        return []
+
+    indexed_boxes = list(enumerate(boxes))
+    indexed_boxes.sort(key=lambda item: item[1][1])
+
+    lines = []
+    first_index, first_box = indexed_boxes[0]
+    current_line = [(first_index, first_box)]
+    current_y0, current_y1 = first_box[1], first_box[3]
+
+    for index, box in indexed_boxes[1:]:
+        y0, y1 = box[1], box[3]
+        if _whether_y_overlap_exceeds_threshold(
+            (0, current_y0, 0, current_y1),
+            (0, y0, 0, y1),
+            line_height_iou_threshold,
+        ):
+            current_line.append((index, box))
+            current_y0 = min(current_y0, y0)
+            current_y1 = max(current_y1, y1)
+        else:
+            lines.append(current_line)
+            current_line = [(index, box)]
+            current_y0, current_y1 = y0, y1
+
+    if current_line:
+        lines.append(current_line)
+
+    for line in lines:
+        line.sort(key=lambda item: item[1][0])
+
+    sorted_indices = [index for line in lines for index, _ in line]
+
+    return sorted_indices
+
+
+def match_table_and_ocr(
+    cell_box_list: list, ocr_dt_boxes: list, cell_sort_by_y_projection: bool = False
+) -> dict:
     """
     match table and ocr
 
     Args:
         cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
         ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
+        cell_sort_by_y_projection (bool): Whether to sort the matched OCR boxes by y-projection.
 
     Returns:
         dict: matched dict, key is table index, value is ocr index
@@ -144,6 +222,14 @@ def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
             matched[distances.index(sorted_distances[0])] = [i]
         else:
             matched[distances.index(sorted_distances[0])].append(i)
+
+    if cell_sort_by_y_projection:
+        for cell_index in matched:
+            input_boxes = [ocr_dt_boxes[i] for i in matched[cell_index]]
+            sorted_indices = _sort_box_by_y_projection(input_boxes, 0.7)
+            sorted_indices = [matched[cell_index][i] for i in sorted_indices]
+            matched[cell_index] = sorted_indices
+
     return matched
 
 
@@ -210,7 +296,10 @@ def get_html_result(
 
 
 def get_table_recognition_res(
-    table_box: list, table_structure_pred: dict, overall_ocr_res: OCRResult
+    table_box: list,
+    table_structure_pred: dict,
+    overall_ocr_res: OCRResult,
+    cell_sort_by_y_projection: bool = False,
 ) -> SingleTableRecognitionResult:
     """
     Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
@@ -219,6 +308,7 @@ def get_table_recognition_res(
         table_box (list): Information about the location of cropped image, including the bounding box.
         table_structure_pred (dict): Predicted table structure.
         overall_ocr_res (OCRResult): Overall OCR result from the input image.
+        cell_sort_by_y_projection (bool): Whether to sort the matched OCR boxes by y-projection.
 
     Returns:
         SingleTableRecognitionResult: An object containing the single table recognition result.
@@ -237,7 +327,9 @@ def get_table_recognition_res(
     ocr_dt_boxes = table_ocr_pred["rec_boxes"]
     ocr_texts_res = table_ocr_pred["rec_texts"]
 
-    matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
+    matched_index = match_table_and_ocr(
+        cell_box_list, ocr_dt_boxes, cell_sort_by_y_projection=cell_sort_by_y_projection
+    )
     pred_html = get_html_result(matched_index, ocr_texts_res, structures)
 
     single_img_res = {
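
With the flag enabled, match_table_and_ocr keeps the nearest-cell assignment but re-ranks each cell's OCR indices into reading order before get_html_result concatenates the texts, which matters for cells whose contents wrap onto several OCR lines. A sketch of the effect, assuming the module imports cleanly at this commit:

    from paddlex.inference.pipelines_new.table_recognition.table_recognition_post_processing import (
        match_table_and_ocr,
    )

    cell_box_list = [[0, 0, 200, 80]]    # one table cell
    ocr_dt_boxes = [
        [10, 40, 90, 60],    # index 0: lower line
        [60, 10, 120, 30],   # index 1: upper line, right
        [10, 12, 50, 32],    # index 2: upper line, left
    ]
    print(match_table_and_ocr(cell_box_list, ocr_dt_boxes))
    # distance order: {0: [0, 1, 2]}
    print(match_table_and_ocr(cell_box_list, ocr_dt_boxes,
                              cell_sort_by_y_projection=True))
    # reading order: {0: [2, 1, 0]}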