瀏覽代碼

Add new table related algorithm (#3229)

* fix bugs

* refine codes

* fix bugs

* refine code

* add new algorithm

* refine codes

* refine codes
Liu Jiaxuan 9 月之前
父節點
當前提交
e8d2d34824

+ 1 - 0
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import copy
 from pathlib import Path
+from PIL import Image, ImageDraw
 from typing import Dict
 
 import cv2

+ 235 - 6
paddlex/inference/pipelines/table_recognition/pipeline_v2.py

@@ -16,6 +16,7 @@ import os, sys
 from typing import Any, Dict, Optional, Union, List, Tuple
 import numpy as np
 import cv2
+from sklearn.cluster import KMeans
 from ..base import BasePipeline
 from ..components import CropByBoxes
 from .utils import get_neighbor_boxes_idx
@@ -259,6 +260,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         elif task == "det":
             threshold = 0.0
             result = []
+            cell_score = []
             if "boxes" in pred and isinstance(pred["boxes"], list):
                 for box in pred["boxes"]:
                     if isinstance(box, dict) and "score" in box and "coordinate" in box:
@@ -266,11 +268,229 @@ class TableRecognitionPipelineV2(BasePipeline):
                         coordinate = box["coordinate"]
                         if isinstance(score, float) and score > threshold:
                             result.append(coordinate)
-            return result
+                            cell_score.append(score)
+            return result, cell_score
         elif task == "table_stru":
             return pred["structure"]
         else:
             return None
+    
+    def cells_det_results_nms(self, cells_det_results, cells_det_scores, cells_det_threshold=0.3):
+        """
+        Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.
+
+        Args:
+            cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
+            cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
+            cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
+                                        will be suppressed. Default is 0.5.
+
+        Returns:
+        Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
+                            while maintaining one-to-one correspondence.
+        """
+        # Convert lists to numpy arrays for efficient computation
+        boxes = np.array(cells_det_results)
+        scores = np.array(cells_det_scores)
+        # Initialize list for picked indices
+        picked_indices = []
+        # Get coordinates of bounding boxes
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        # Compute the area of the bounding boxes
+        areas = (x2 - x1) * (y2 - y1)
+        # Sort the bounding boxes by the confidence scores in descending order
+        order = scores.argsort()[::-1]
+        # Process the boxes
+        while order.size > 0:
+            # Index of the current highest score box
+            i = order[0]
+            picked_indices.append(i)
+            # Compute IoU between the highest score box and the rest
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            # Compute the width and height of the overlapping area
+            w = np.maximum(0.0, xx2 - xx1)
+            h = np.maximum(0.0, yy2 - yy1)
+            # Compute the ratio of overlap (IoU)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            # Indices of boxes with IoU less than threshold
+            inds = np.where(ovr <= cells_det_threshold)[0]
+            # Update order, only keep boxes with IoU less than threshold
+            order = order[inds + 1]  # inds shifted by 1 because order[0] is the current box
+        # Select the boxes and scores based on picked indices
+        final_boxes = boxes[picked_indices].tolist()
+        final_scores = scores[picked_indices].tolist()
+        return final_boxes, final_scores
+    
+    def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
+        """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.
+
+        Args:
+            ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
+            table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.
+
+        Returns:
+            list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
+        """
+        tol = 5
+        # Extract coordinates from table_box
+        x_min_t, y_min_t, x_max_t, y_max_t = table_box
+
+        adjusted_boxes = []
+        for box in ocr_det_boxes:
+            x_min_b, y_min_b, x_max_b, y_max_b = box
+
+            # Check if the box is fully inside table_box
+            if (x_min_b+tol >= x_min_t and y_min_b+tol >= y_min_t and
+                x_max_b+tol <= x_max_t and y_max_b+tol <= y_max_t):
+                # Adjust the coordinates to be relative to table_box
+                adjusted_box = [
+                    x_min_b - x_min_t,  # Adjust x1
+                    y_min_b - y_min_t,  # Adjust y1
+                    x_max_b - x_min_t,  # Adjust x2
+                    y_max_b - y_min_t   # Adjust y2
+                ]
+                adjusted_boxes.append(adjusted_box)
+            # Discard boxes not fully inside table_box
+        return adjusted_boxes
+
+    def cells_det_results_reprocessing(self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums):
+        """
+        Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.
+
+        Args:
+            cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
+            cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
+            ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
+            html_pred_boxes_nums (int): The desired number of rectangles in the final output.
+
+        Returns:
+            List[List[float]]: The processed list of rectangles.
+        """
+        # Function to compute IoU between two rectangles
+        def compute_iou(box1, box2):
+            """
+            Compute the Intersection over Union (IoU) between two rectangles.
+
+            Args:
+                box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
+                box2 (array-like): [x1, y1, x2, y2] of the second rectangle.
+
+            Returns:
+                float: The IoU between the two rectangles.
+            """
+            # Determine the coordinates of the intersection rectangle
+            x_left = max(box1[0], box2[0])
+            y_top = max(box1[1], box2[1])
+            x_right = min(box1[2], box2[2])
+            y_bottom = min(box1[3], box2[3])
+            if x_right <= x_left or y_bottom <= y_top:
+                return 0.0
+            # Calculate the area of intersection rectangle
+            intersection_area = (x_right - x_left) * (y_bottom - y_top)
+            # Calculate the area of both rectangles
+            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+            box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+            # Calculate the IoU
+            iou = intersection_area / float(box1_area + box2_area - intersection_area)
+            return iou
+
+        # Function to combine rectangles into N rectangles
+        def combine_rectangles(rectangles, N):
+            """
+            Combine rectangles into N rectangles based on geometric proximity.
+
+            Args:
+            rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
+            N (int): The desired number of combined rectangles.
+
+            Returns:
+            list of list of int: A list of N combined rectangles.
+            """
+            # Number of input rectangles
+            num_rects = len(rectangles)
+            # If N is greater than or equal to the number of rectangles, return the original rectangles
+            if N >= num_rects:
+                return rectangles
+            # Compute the center points of the rectangles
+            centers = np.array([
+                [
+                    (rect[0] + rect[2]) / 2,  # Center x-coordinate
+                    (rect[1] + rect[3]) / 2   # Center y-coordinate
+                ]
+                for rect in rectangles
+            ])
+            # Perform KMeans clustering on the center points to group them into N clusters
+            kmeans = KMeans(n_clusters=N, random_state=0, n_init='auto')
+            labels = kmeans.fit_predict(centers)
+            # Initialize a list to store the combined rectangles
+            combined_rectangles = []
+            # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
+            for i in range(N):
+                # Get the indices of rectangles that belong to cluster i
+                indices = np.where(labels == i)[0]
+                if len(indices) == 0:
+                    # If no rectangles in this cluster, skip it
+                    continue
+                # Extract the rectangles in cluster i
+                cluster_rects = np.array([rectangles[idx] for idx in indices])
+                # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
+                x1_min = np.min(cluster_rects[:, 0])
+                y1_min = np.min(cluster_rects[:, 1])
+                x2_max = np.max(cluster_rects[:, 2])
+                y2_max = np.max(cluster_rects[:, 3])
+                # Append the combined rectangle to the list
+                combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
+            return combined_rectangles
+
+        # Ensure that the inputs are numpy arrays for efficient computation
+        cells_det_results = np.array(cells_det_results)
+        cells_det_scores = np.array(cells_det_scores)
+        ocr_det_results = np.array(ocr_det_results)
+        if len(cells_det_results) == html_pred_boxes_nums:
+            return cells_det_results
+        # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
+        elif len(cells_det_results) > html_pred_boxes_nums:
+                return combine_rectangles(cells_det_results, html_pred_boxes_nums)
+        else:
+            # return cells_det_results
+            # Threshold for IoU
+            iou_threshold = 0.1
+            # List to store ocr_miss_boxes
+            ocr_miss_boxes = []
+            # For each rectangle in ocr_det_results
+            for ocr_rect in ocr_det_results:
+                # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
+                has_large_iou = False
+                # For each rectangle in cells_det_results
+                for cell_rect in cells_det_results:
+                    # Compute IoU
+                    iou = compute_iou(ocr_rect, cell_rect)
+                    if iou >= iou_threshold:
+                        has_large_iou = True
+                        break
+                if not has_large_iou:
+                    ocr_miss_boxes.append(ocr_rect)
+            # If no ocr_miss_boxes, return cells_det_results
+            if len(ocr_miss_boxes) == 0:
+                return cells_det_results.tolist()
+            else:
+                # Need to combine ocr_miss_boxes into N rectangles
+                N = html_pred_boxes_nums - len(cells_det_results)
+                if len(ocr_miss_boxes) == N:
+                    return cells_det_results.tolist() + ocr_miss_boxes
+                else:
+                    # Combine ocr_miss_boxes into N rectangles
+                    ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
+                    # Combine cells_det_results and ocr_supp_boxes
+                    final_results = np.concatenate((cells_det_results, ocr_supp_boxes), axis=0)
+                    return final_results.tolist()
 
     def predict_single_table_recognition_res(
         self,
@@ -295,16 +515,25 @@ class TableRecognitionPipelineV2(BasePipeline):
         table_cls_result = self.extract_results(table_cls_pred, "cls")
         if table_cls_result == "wired_table":
             table_structure_pred = next(self.wired_table_rec_model(image_array))
-            table_cells_pred = next(self.wired_table_cells_detection_model(image_array))
+            table_cells_pred = next(
+                self.wired_table_cells_detection_model(image_array, threshold=0.3)
+            ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection. 
+              # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         elif table_cls_result == "wireless_table":
             table_structure_pred = next(self.wireless_table_rec_model(image_array))
             table_cells_pred = next(
-                self.wireless_table_cells_detection_model(image_array)
-            )
+                self.wireless_table_cells_detection_model(image_array, threshold=0.3)
+            ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection. 
+              # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         table_structure_result = self.extract_results(
             table_structure_pred, "table_stru"
         )
-        table_cells_result = self.extract_results(table_cells_pred, "det")
+        table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+        table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+        ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
+        table_cells_result = self.cells_det_results_reprocessing(
+            table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
+        )
         single_table_recognition_res = get_table_recognition_res(
             table_box, table_structure_result, table_cells_result, overall_ocr_res
         )
@@ -316,7 +545,7 @@ class TableRecognitionPipelineV2(BasePipeline):
             if len(match_idx_list) > 0:
                 for idx in match_idx_list:
                     neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
-        single_table_recognition_res["neighbor_text"] = neighbor_text
+        single_table_recognition_res["neighbor_texts"] = neighbor_text
         return single_table_recognition_res
 
     def predict(

+ 27 - 38
paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py

@@ -156,7 +156,6 @@ def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
             matched[distances.index(sorted_distances[0])].append(i)
     return matched
 
-
 def get_html_result(
     matched_index: dict, ocr_contents: dict, pred_structures: list
 ) -> str:
@@ -181,6 +180,8 @@ def get_html_result(
             if "<td></td>" == tag:
                 pred_html.extend("<td>")
             if td_index in matched_index.keys():
+                if len(matched_index[td_index])==0:
+                    continue
                 b_with = False
                 if (
                     "<b>" in ocr_contents[matched_index[td_index][0]]
@@ -218,10 +219,9 @@ def get_html_result(
     html += "".join(end_structure)
     return html
 
-
 def sort_table_cells_boxes(boxes):
     """
-    Sort the input list of bounding boxes by using the DBSCAN algorithm to cluster based on the top-left y-coordinate (y1), and then sort within each line from left to right based on the top-left x-coordinate (x1).
+    Sort the input list of bounding boxes.
 
     Args:
         boxes (list of lists): The input list of bounding boxes, where each bounding box is formatted as [x1, y1, x2, y2].
@@ -229,43 +229,31 @@ def sort_table_cells_boxes(boxes):
     Returns:
         sorted_boxes (list of lists): The list of bounding boxes sorted.
     """
-    import numpy as np
-    from sklearn.cluster import DBSCAN
-
-    # Extract the top-left y-coordinates (y1)
-    y1_coords = np.array([box[1] for box in boxes])
-    y1_coords = y1_coords.reshape(-1, 1)  # Convert to a 2D array
-
-    # Choose an appropriate eps parameter based on the range of y-values
-    y_range = y1_coords.max() - y1_coords.min()
-    eps = y_range / 50  # Adjust the denominator as needed
-
-    # Perform clustering using DBSCAN
-    db = DBSCAN(eps=eps, min_samples=1).fit(y1_coords)
-    labels = db.labels_
-
-    # Group bounding boxes by their labels
-    clusters = {}
-    for label, box in zip(labels, boxes):
-        if label not in clusters:
-            clusters[label] = []
-        clusters[label].append(box)
-
-    # Sort rows based on y-coordinates
-    # Compute the average y1 value for each row and sort from top to bottom
-    sorted_rows = sorted(
-        clusters.items(), key=lambda item: np.mean([box[1] for box in item[1]])
-    )
-
-    # Within each row, sort by x1 coordinate
-    sorted_boxes = []
-    for label, row in sorted_rows:
-        row_sorted = sorted(row, key=lambda x: x[0])
-        sorted_boxes.extend(row_sorted)
 
+    boxes_sorted_by_y = sorted(boxes, key=lambda box: box[1])
+    rows = []
+    current_row = []
+    current_y = None
+    tolerance = 10
+    for box in boxes_sorted_by_y:
+        x1, y1, x2, y2 = box
+        if current_y is None:
+            current_row.append(box)
+            current_y = y1
+        else:
+            if abs(y1 - current_y) <= tolerance:
+                current_row.append(box)
+            else:
+                current_row.sort(key=lambda x: x[0])
+                rows.append(current_row)
+                current_row = [box]
+                current_y = y1
+    if current_row:
+        current_row.sort(key=lambda x: x[0])
+        rows.append(current_row)
+    sorted_boxes = [box for row in rows for box in row] 
     return sorted_boxes
 
-
 def convert_to_four_point_coordinates(boxes):
     """
     Convert bounding boxes from [x1, y1, x2, y2] format to 
@@ -324,7 +312,8 @@ def get_table_recognition_res(
     Returns:
         SingleTableRecognitionResult: An object containing the single table recognition result.
     """
-    table_cells_result =convert_to_four_point_coordinates(table_cells_result)
+
+    table_cells_result = convert_to_four_point_coordinates(table_cells_result)
 
     table_box = np.array([table_box])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)

+ 1 - 0
paddlex/utils/pipeline_arguments.py

@@ -143,6 +143,7 @@ PIPELINE_ARGUMENTS = {
         },
     ],
     "table_recognition": None,
+    "table_recognition_v2": None,
     "seal_recognition": [
         {
             "name": "--use_doc_orientation_classify",