
Add new table-related algorithm (#3229)

* fix bugs

* refine codes

* fix bugs

* refine code

* add new algorithm

* refine codes

* refine codes
Liu Jiaxuan, 9 months ago
commit e8d2d34824
+ 1 - 0
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import copy
 from pathlib import Path
+from PIL import Image, ImageDraw
 from typing import Dict
 
 import cv2

+ 235 - 6
paddlex/inference/pipelines/table_recognition/pipeline_v2.py

@@ -16,6 +16,7 @@ import os, sys
 from typing import Any, Dict, Optional, Union, List, Tuple
 import numpy as np
 import cv2
+from sklearn.cluster import KMeans
 from ..base import BasePipeline
 from ..components import CropByBoxes
 from .utils import get_neighbor_boxes_idx
@@ -259,6 +260,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         elif task == "det":
             threshold = 0.0
             result = []
+            cell_score = []
            if "boxes" in pred and isinstance(pred["boxes"], list):
                 for box in pred["boxes"]:
                     if isinstance(box, dict) and "score" in box and "coordinate" in box:
@@ -266,11 +268,229 @@ class TableRecognitionPipelineV2(BasePipeline):
                         coordinate = box["coordinate"]
                         if isinstance(score, float) and score > threshold:
                             result.append(coordinate)
-            return result
+                            cell_score.append(score)
+            return result, cell_score
         elif task == "table_stru":
         elif task == "table_stru":
             return pred["structure"]
             return pred["structure"]
         else:
         else:
             return None
             return None
+    
+    def cells_det_results_nms(self, cells_det_results, cells_det_scores, cells_det_threshold=0.3):
+        """
+        Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.
+
+        Args:
+            cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
+            cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
+            cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
+                                        will be suppressed. Default is 0.3.
+
+        Returns:
+            Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
+                while maintaining one-to-one correspondence.
+        """
+        # Convert lists to numpy arrays for efficient computation
+        boxes = np.array(cells_det_results)
+        scores = np.array(cells_det_scores)
+        # Initialize list for picked indices
+        picked_indices = []
+        # Get coordinates of bounding boxes
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        # Compute the area of the bounding boxes
+        areas = (x2 - x1) * (y2 - y1)
+        # Sort the bounding boxes by the confidence scores in descending order
+        order = scores.argsort()[::-1]
+        # Process the boxes
+        while order.size > 0:
+            # Index of the current highest score box
+            i = order[0]
+            picked_indices.append(i)
+            # Compute IoU between the highest score box and the rest
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            # Compute the width and height of the overlapping area
+            w = np.maximum(0.0, xx2 - xx1)
+            h = np.maximum(0.0, yy2 - yy1)
+            # Compute the ratio of overlap (IoU)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            # Indices of boxes with IoU less than threshold
+            inds = np.where(ovr <= cells_det_threshold)[0]
+            # Update order, only keep boxes with IoU less than threshold
+            order = order[inds + 1]  # inds shifted by 1 because order[0] is the current box
+        # Select the boxes and scores based on picked indices
+        final_boxes = boxes[picked_indices].tolist()
+        final_scores = scores[picked_indices].tolist()
+        return final_boxes, final_scores
+    
+    def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
+        """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.
+
+        Args:
+            ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
+            table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.
+
+        Returns:
+            list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
+        """
+        tol = 5
+        # Extract coordinates from table_box
+        x_min_t, y_min_t, x_max_t, y_max_t = table_box
+
+        adjusted_boxes = []
+        for box in ocr_det_boxes:
+            x_min_b, y_min_b, x_max_b, y_max_b = box
+
+            # Check if the box is fully inside table_box
+            if (x_min_b + tol >= x_min_t and y_min_b + tol >= y_min_t and
+                x_max_b - tol <= x_max_t and y_max_b - tol <= y_max_t):
+                # Adjust the coordinates to be relative to table_box
+                adjusted_box = [
+                    x_min_b - x_min_t,  # Adjust x1
+                    y_min_b - y_min_t,  # Adjust y1
+                    x_max_b - x_min_t,  # Adjust x2
+                    y_max_b - y_min_t   # Adjust y2
+                ]
+                adjusted_boxes.append(adjusted_box)
+            # Discard boxes not fully inside table_box
+        return adjusted_boxes
+
+    def cells_det_results_reprocessing(self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums):
+        """
+        Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.
+
+        Args:
+            cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
+            cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
+            ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
+            html_pred_boxes_nums (int): The desired number of rectangles in the final output.
+
+        Returns:
+            List[List[float]]: The processed list of rectangles.
+        """
+        # Function to compute IoU between two rectangles
+        def compute_iou(box1, box2):
+            """
+            Compute the Intersection over Union (IoU) between two rectangles.
+
+            Args:
+                box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
+                box2 (array-like): [x1, y1, x2, y2] of the second rectangle.
+
+            Returns:
+                float: The IoU between the two rectangles.
+            """
+            # Determine the coordinates of the intersection rectangle
+            x_left = max(box1[0], box2[0])
+            y_top = max(box1[1], box2[1])
+            x_right = min(box1[2], box2[2])
+            y_bottom = min(box1[3], box2[3])
+            if x_right <= x_left or y_bottom <= y_top:
+                return 0.0
+            # Calculate the area of intersection rectangle
+            intersection_area = (x_right - x_left) * (y_bottom - y_top)
+            # Calculate the area of both rectangles
+            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+            box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+            # Calculate the IoU
+            iou = intersection_area / float(box1_area + box2_area - intersection_area)
+            return iou
+
+        # Function to combine rectangles into N rectangles
+        def combine_rectangles(rectangles, N):
+            """
+            Combine rectangles into N rectangles based on geometric proximity.
+
+            Args:
+            rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
+            N (int): The desired number of combined rectangles.
+
+            Returns:
+            list of list of int: A list of N combined rectangles.
+            """
+            # Number of input rectangles
+            num_rects = len(rectangles)
+            # If N is greater than or equal to the number of rectangles, return the original rectangles
+            if N >= num_rects:
+                return rectangles
+            # Compute the center points of the rectangles
+            centers = np.array([
+                [
+                    (rect[0] + rect[2]) / 2,  # Center x-coordinate
+                    (rect[1] + rect[3]) / 2   # Center y-coordinate
+                ]
+                for rect in rectangles
+            ])
+            # Perform KMeans clustering on the center points to group them into N clusters
+            kmeans = KMeans(n_clusters=N, random_state=0, n_init='auto')
+            labels = kmeans.fit_predict(centers)
+            # Initialize a list to store the combined rectangles
+            combined_rectangles = []
+            # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
+            for i in range(N):
+                # Get the indices of rectangles that belong to cluster i
+                indices = np.where(labels == i)[0]
+                if len(indices) == 0:
+                    # If no rectangles in this cluster, skip it
+                    continue
+                # Extract the rectangles in cluster i
+                cluster_rects = np.array([rectangles[idx] for idx in indices])
+                # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
+                x1_min = np.min(cluster_rects[:, 0])
+                y1_min = np.min(cluster_rects[:, 1])
+                x2_max = np.max(cluster_rects[:, 2])
+                y2_max = np.max(cluster_rects[:, 3])
+                # Append the combined rectangle to the list
+                combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
+            return combined_rectangles
+
+        # Ensure that the inputs are numpy arrays for efficient computation
+        cells_det_results = np.array(cells_det_results)
+        cells_det_scores = np.array(cells_det_scores)
+        ocr_det_results = np.array(ocr_det_results)
+        if len(cells_det_results) == html_pred_boxes_nums:
+            return cells_det_results.tolist()
+        # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
+        elif len(cells_det_results) > html_pred_boxes_nums:
+            return combine_rectangles(cells_det_results, html_pred_boxes_nums)
+        else:
+            # Threshold for IoU
+            iou_threshold = 0.1
+            # List to store ocr_miss_boxes
+            ocr_miss_boxes = []
+            # For each rectangle in ocr_det_results
+            for ocr_rect in ocr_det_results:
+                # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
+                has_large_iou = False
+                # For each rectangle in cells_det_results
+                for cell_rect in cells_det_results:
+                    # Compute IoU
+                    iou = compute_iou(ocr_rect, cell_rect)
+                    if iou >= iou_threshold:
+                        has_large_iou = True
+                        break
+                if not has_large_iou:
+                    ocr_miss_boxes.append(ocr_rect)
+            # If no ocr_miss_boxes, return cells_det_results
+            if len(ocr_miss_boxes) == 0:
+                return cells_det_results.tolist()
+            else:
+                # Need to combine ocr_miss_boxes into N rectangles
+                N = html_pred_boxes_nums - len(cells_det_results)
+                if len(ocr_miss_boxes) == N:
+                    return cells_det_results.tolist() + [box.tolist() for box in ocr_miss_boxes]
+                else:
+                    # Combine ocr_miss_boxes into N rectangles
+                    ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
+                    # Combine cells_det_results and ocr_supp_boxes
+                    final_results = np.concatenate((cells_det_results, ocr_supp_boxes), axis=0)
+                    return final_results.tolist()
 
     def predict_single_table_recognition_res(
         self,
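
A minimal standalone sketch of what the new cells_det_results_nms step does: the nms helper below re-creates the same greedy IoU suppression outside the pipeline class (the helper name and the toy boxes are illustrative only, not part of this commit).

    import numpy as np

    def nms(boxes, scores, iou_thr=0.3):
        # Greedy NMS: repeatedly keep the highest-scoring box and drop
        # every remaining box whose IoU with it exceeds the threshold.
        boxes = np.asarray(boxes, dtype=float)
        scores = np.asarray(scores, dtype=float)
        x1, y1, x2, y2 = boxes.T
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
            iou = inter / (areas[i] + areas[order[1:]] - inter)
            order = order[1:][iou <= iou_thr]
        return boxes[keep].tolist(), scores[keep].tolist()

    # Two near-duplicate detections of one cell plus one distinct cell:
    boxes = [[0, 0, 100, 50], [2, 1, 101, 52], [120, 0, 220, 50]]
    scores = [0.95, 0.60, 0.90]
    print(nms(boxes, scores))
    # -> ([[0.0, 0.0, 100.0, 50.0], [120.0, 0.0, 220.0, 50.0]], [0.95, 0.9])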
@@ -295,16 +515,25 @@ class TableRecognitionPipelineV2(BasePipeline):
         table_cls_result = self.extract_results(table_cls_pred, "cls")
         if table_cls_result == "wired_table":
             table_structure_pred = next(self.wired_table_rec_model(image_array))
-            table_cells_pred = next(self.wired_table_cells_detection_model(image_array))
+            table_cells_pred = next(
+                self.wired_table_cells_detection_model(image_array, threshold=0.3)
+            )  # Setting the threshold to 0.3 improves table cell detection accuracy;
+               # adjust it if more or fewer cell detection boxes are desired.
         elif table_cls_result == "wireless_table":
         elif table_cls_result == "wireless_table":
             table_structure_pred = next(self.wireless_table_rec_model(image_array))
             table_structure_pred = next(self.wireless_table_rec_model(image_array))
             table_cells_pred = next(
             table_cells_pred = next(
-                self.wireless_table_cells_detection_model(image_array)
-            )
+                self.wireless_table_cells_detection_model(image_array, threshold=0.3)
+            ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection. 
+              # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         table_structure_result = self.extract_results(
             table_structure_pred, "table_stru"
         )
-        table_cells_result = self.extract_results(table_cells_pred, "det")
+        table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+        table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+        ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
+        table_cells_result = self.cells_det_results_reprocessing(
+            table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
+        )
         single_table_recognition_res = get_table_recognition_res(
             table_box, table_structure_result, table_cells_result, overall_ocr_res
         )
@@ -316,7 +545,7 @@ class TableRecognitionPipelineV2(BasePipeline):
             if len(match_idx_list) > 0:
                 for idx in match_idx_list:
                     neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
-        single_table_recognition_res["neighbor_text"] = neighbor_text
+        single_table_recognition_res["neighbor_texts"] = neighbor_text
         return single_table_recognition_res
 
     def predict(
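
When more cells are detected than the predicted HTML structure can hold, cells_det_results_reprocessing merges them via KMeans clustering of box centers. A minimal sketch of that merging, assuming scikit-learn >= 1.2 (for n_init="auto"); the combine helper re-creates the commit's combine_rectangles logic for illustration only.

    import numpy as np
    from sklearn.cluster import KMeans

    def combine(rects, n):
        # Cluster box centers into n groups, then take the minimal
        # bounding rectangle of each group.
        rects = np.asarray(rects, dtype=float)
        centers = np.column_stack([(rects[:, 0] + rects[:, 2]) / 2,
                                   (rects[:, 1] + rects[:, 3]) / 2])
        labels = KMeans(n_clusters=n, random_state=0, n_init="auto").fit_predict(centers)
        merged = []
        for k in range(n):
            group = rects[labels == k]
            if len(group):
                merged.append([group[:, 0].min(), group[:, 1].min(),
                               group[:, 2].max(), group[:, 3].max()])
        return merged

    # Four boxes forming two spatial groups, merged down to two rectangles:
    rects = [[0, 0, 10, 10], [12, 0, 22, 10], [200, 0, 210, 10], [212, 0, 222, 10]]
    print(combine(rects, 2))
    # -> [[0.0, 0.0, 22.0, 10.0], [200.0, 0.0, 222.0, 10.0]] (cluster order may vary)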

+ 27 - 38
paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py

@@ -156,7 +156,6 @@ def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
             matched[distances.index(sorted_distances[0])].append(i)
     return matched
 
-
 def get_html_result(
     matched_index: dict, ocr_contents: dict, pred_structures: list
 ) -> str:
@@ -181,6 +180,8 @@ def get_html_result(
             if "<td></td>" == tag:
             if "<td></td>" == tag:
                 pred_html.extend("<td>")
                 pred_html.extend("<td>")
             if td_index in matched_index.keys():
             if td_index in matched_index.keys():
+                if len(matched_index[td_index]) == 0:
+                    continue
                 b_with = False
                 if (
                     "<b>" in ocr_contents[matched_index[td_index][0]]
@@ -218,10 +219,9 @@ def get_html_result(
     html += "".join(end_structure)
     html += "".join(end_structure)
     return html
     return html
 
 
-
 def sort_table_cells_boxes(boxes):
 def sort_table_cells_boxes(boxes):
     """
     """
-    Sort the input list of bounding boxes by using the DBSCAN algorithm to cluster based on the top-left y-coordinate (y1), and then sort within each line from left to right based on the top-left x-coordinate (x1).
+    Sort the input list of bounding boxes into reading order: group boxes into rows by their top-left y-coordinate (within a fixed pixel tolerance), then sort each row from left to right by the top-left x-coordinate.
 
     Args:
         boxes (list of lists): The input list of bounding boxes, where each bounding box is formatted as [x1, y1, x2, y2].
@@ -229,43 +229,31 @@ def sort_table_cells_boxes(boxes):
     Returns:
         sorted_boxes (list of lists): The list of bounding boxes, sorted into reading order.
     """
-    import numpy as np
-    from sklearn.cluster import DBSCAN
-
-    # Extract the top-left y-coordinates (y1)
-    y1_coords = np.array([box[1] for box in boxes])
-    y1_coords = y1_coords.reshape(-1, 1)  # Convert to a 2D array
-
-    # Choose an appropriate eps parameter based on the range of y-values
-    y_range = y1_coords.max() - y1_coords.min()
-    eps = y_range / 50  # Adjust the denominator as needed
-
-    # Perform clustering using DBSCAN
-    db = DBSCAN(eps=eps, min_samples=1).fit(y1_coords)
-    labels = db.labels_
-
-    # Group bounding boxes by their labels
-    clusters = {}
-    for label, box in zip(labels, boxes):
-        if label not in clusters:
-            clusters[label] = []
-        clusters[label].append(box)
-
-    # Sort rows based on y-coordinates
-    # Compute the average y1 value for each row and sort from top to bottom
-    sorted_rows = sorted(
-        clusters.items(), key=lambda item: np.mean([box[1] for box in item[1]])
-    )
-
-    # Within each row, sort by x1 coordinate
-    sorted_boxes = []
-    for label, row in sorted_rows:
-        row_sorted = sorted(row, key=lambda x: x[0])
-        sorted_boxes.extend(row_sorted)
 
+    boxes_sorted_by_y = sorted(boxes, key=lambda box: box[1])
+    rows = []
+    current_row = []
+    current_y = None
+    tolerance = 10
+    for box in boxes_sorted_by_y:
+        x1, y1, x2, y2 = box
+        if current_y is None:
+            current_row.append(box)
+            current_y = y1
+        else:
+            if abs(y1 - current_y) <= tolerance:
+                current_row.append(box)
+            else:
+                current_row.sort(key=lambda x: x[0])
+                rows.append(current_row)
+                current_row = [box]
+                current_y = y1
+    if current_row:
+        current_row.sort(key=lambda x: x[0])
+        rows.append(current_row)
+    sorted_boxes = [box for row in rows for box in row] 
     return sorted_boxes
 
-
 def convert_to_four_point_coordinates(boxes):
     """
     Convert bounding boxes from [x1, y1, x2, y2] format to 
@@ -324,7 +312,8 @@ def get_table_recognition_res(
     Returns:
         SingleTableRecognitionResult: An object containing the single table recognition result.
     """
-    table_cells_result =convert_to_four_point_coordinates(table_cells_result)
+
+    table_cells_result = convert_to_four_point_coordinates(table_cells_result)
 
     table_box = np.array([table_box])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
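
The rewritten sort_table_cells_boxes in this file replaces the DBSCAN clustering with a fixed-tolerance row scan. A toy run of the same logic; the sort_cells helper mirrors the new implementation (10-pixel tolerance, as in the commit) and is illustrative only.

    def sort_cells(boxes, tolerance=10):
        # Group boxes into rows keyed on the y1 of each row's first box,
        # then sort every row left to right by x1.
        rows, current, current_y = [], [], None
        for box in sorted(boxes, key=lambda b: b[1]):
            if current_y is None or abs(box[1] - current_y) <= tolerance:
                current.append(box)
                current_y = box[1] if current_y is None else current_y
            else:
                rows.append(sorted(current, key=lambda b: b[0]))
                current, current_y = [box], box[1]
        if current:
            rows.append(sorted(current, key=lambda b: b[0]))
        return [box for row in rows for box in row]

    # Two rows (y near 0 and y near 40), listed out of order:
    boxes = [[50, 2, 90, 20], [0, 0, 40, 20], [50, 42, 90, 60], [0, 40, 40, 60]]
    print(sort_cells(boxes))
    # -> [[0, 0, 40, 20], [50, 2, 90, 20], [0, 40, 40, 60], [50, 42, 90, 60]]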

+ 1 - 0
paddlex/utils/pipeline_arguments.py

@@ -143,6 +143,7 @@ PIPELINE_ARGUMENTS = {
         },
     ],
     "table_recognition": None,
+    "table_recognition_v2": None,
     "seal_recognition": [
     "seal_recognition": [
         {
         {
             "name": "--use_doc_orientation_classify",
             "name": "--use_doc_orientation_classify",