|
@@ -16,6 +16,7 @@ import os, sys
|
|
|
from typing import Any, Dict, Optional, Union, List, Tuple
|
|
from typing import Any, Dict, Optional, Union, List, Tuple
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import cv2
|
|
import cv2
|
|
|
|
|
+from sklearn.cluster import KMeans
|
|
|
from ..base import BasePipeline
|
|
from ..base import BasePipeline
|
|
|
from ..components import CropByBoxes
|
|
from ..components import CropByBoxes
|
|
|
from .utils import get_neighbor_boxes_idx
|
|
from .utils import get_neighbor_boxes_idx
|
|
@@ -259,6 +260,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
elif task == "det":
|
|
elif task == "det":
|
|
|
threshold = 0.0
|
|
threshold = 0.0
|
|
|
result = []
|
|
result = []
|
|
|
|
|
+ cell_score = []
|
|
|
if "boxes" in pred and isinstance(pred["boxes"], list):
|
|
if "boxes" in pred and isinstance(pred["boxes"], list):
|
|
|
for box in pred["boxes"]:
|
|
for box in pred["boxes"]:
|
|
|
if isinstance(box, dict) and "score" in box and "coordinate" in box:
|
|
if isinstance(box, dict) and "score" in box and "coordinate" in box:
|
|
@@ -266,11 +268,229 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
coordinate = box["coordinate"]
|
|
coordinate = box["coordinate"]
|
|
|
if isinstance(score, float) and score > threshold:
|
|
if isinstance(score, float) and score > threshold:
|
|
|
result.append(coordinate)
|
|
result.append(coordinate)
|
|
|
- return result
|
|
|
|
|
|
|
+ cell_score.append(score)
|
|
|
|
|
+ return result, cell_score
|
|
|
elif task == "table_stru":
|
|
elif task == "table_stru":
|
|
|
return pred["structure"]
|
|
return pred["structure"]
|
|
|
else:
|
|
else:
|
|
|
return None
|
|
return None
|
|
|
|
|
+
|
|
|
|
|
+ def cells_det_results_nms(self, cells_det_results, cells_det_scores, cells_det_threshold=0.3):
|
|
|
|
|
+ """
|
|
|
|
|
+ Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
|
|
|
|
|
+ cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
|
|
|
|
|
+ cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
|
|
|
|
|
+ will be suppressed. Default is 0.5.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
|
|
|
|
|
+ while maintaining one-to-one correspondence.
|
|
|
|
|
+ """
|
|
|
|
|
+ # Convert lists to numpy arrays for efficient computation
|
|
|
|
|
+ boxes = np.array(cells_det_results)
|
|
|
|
|
+ scores = np.array(cells_det_scores)
|
|
|
|
|
+ # Initialize list for picked indices
|
|
|
|
|
+ picked_indices = []
|
|
|
|
|
+ # Get coordinates of bounding boxes
|
|
|
|
|
+ x1 = boxes[:, 0]
|
|
|
|
|
+ y1 = boxes[:, 1]
|
|
|
|
|
+ x2 = boxes[:, 2]
|
|
|
|
|
+ y2 = boxes[:, 3]
|
|
|
|
|
+ # Compute the area of the bounding boxes
|
|
|
|
|
+ areas = (x2 - x1) * (y2 - y1)
|
|
|
|
|
+ # Sort the bounding boxes by the confidence scores in descending order
|
|
|
|
|
+ order = scores.argsort()[::-1]
|
|
|
|
|
+ # Process the boxes
|
|
|
|
|
+ while order.size > 0:
|
|
|
|
|
+ # Index of the current highest score box
|
|
|
|
|
+ i = order[0]
|
|
|
|
|
+ picked_indices.append(i)
|
|
|
|
|
+ # Compute IoU between the highest score box and the rest
|
|
|
|
|
+ xx1 = np.maximum(x1[i], x1[order[1:]])
|
|
|
|
|
+ yy1 = np.maximum(y1[i], y1[order[1:]])
|
|
|
|
|
+ xx2 = np.minimum(x2[i], x2[order[1:]])
|
|
|
|
|
+ yy2 = np.minimum(y2[i], y2[order[1:]])
|
|
|
|
|
+ # Compute the width and height of the overlapping area
|
|
|
|
|
+ w = np.maximum(0.0, xx2 - xx1)
|
|
|
|
|
+ h = np.maximum(0.0, yy2 - yy1)
|
|
|
|
|
+ # Compute the ratio of overlap (IoU)
|
|
|
|
|
+ inter = w * h
|
|
|
|
|
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
|
|
|
|
+ # Indices of boxes with IoU less than threshold
|
|
|
|
|
+ inds = np.where(ovr <= cells_det_threshold)[0]
|
|
|
|
|
+ # Update order, only keep boxes with IoU less than threshold
|
|
|
|
|
+ order = order[inds + 1] # inds shifted by 1 because order[0] is the current box
|
|
|
|
|
+ # Select the boxes and scores based on picked indices
|
|
|
|
|
+ final_boxes = boxes[picked_indices].tolist()
|
|
|
|
|
+ final_scores = scores[picked_indices].tolist()
|
|
|
|
|
+ return final_boxes, final_scores
|
|
|
|
|
+
|
|
|
|
|
+ def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
|
|
|
|
|
+ """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
|
|
|
|
|
+ table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
|
|
|
|
|
+ """
|
|
|
|
|
+ tol = 5
|
|
|
|
|
+ # Extract coordinates from table_box
|
|
|
|
|
+ x_min_t, y_min_t, x_max_t, y_max_t = table_box
|
|
|
|
|
+
|
|
|
|
|
+ adjusted_boxes = []
|
|
|
|
|
+ for box in ocr_det_boxes:
|
|
|
|
|
+ x_min_b, y_min_b, x_max_b, y_max_b = box
|
|
|
|
|
+
|
|
|
|
|
+ # Check if the box is fully inside table_box
|
|
|
|
|
+ if (x_min_b+tol >= x_min_t and y_min_b+tol >= y_min_t and
|
|
|
|
|
+ x_max_b+tol <= x_max_t and y_max_b+tol <= y_max_t):
|
|
|
|
|
+ # Adjust the coordinates to be relative to table_box
|
|
|
|
|
+ adjusted_box = [
|
|
|
|
|
+ x_min_b - x_min_t, # Adjust x1
|
|
|
|
|
+ y_min_b - y_min_t, # Adjust y1
|
|
|
|
|
+ x_max_b - x_min_t, # Adjust x2
|
|
|
|
|
+ y_max_b - y_min_t # Adjust y2
|
|
|
|
|
+ ]
|
|
|
|
|
+ adjusted_boxes.append(adjusted_box)
|
|
|
|
|
+ # Discard boxes not fully inside table_box
|
|
|
|
|
+ return adjusted_boxes
|
|
|
|
|
+
|
|
|
|
|
+ def cells_det_results_reprocessing(self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums):
|
|
|
|
|
+ """
|
|
|
|
|
+ Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
|
|
|
|
|
+ cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
|
|
|
|
|
+ ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
|
|
|
|
|
+ html_pred_boxes_nums (int): The desired number of rectangles in the final output.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ List[List[float]]: The processed list of rectangles.
|
|
|
|
|
+ """
|
|
|
|
|
+ # Function to compute IoU between two rectangles
|
|
|
|
|
+ def compute_iou(box1, box2):
|
|
|
|
|
+ """
|
|
|
|
|
+ Compute the Intersection over Union (IoU) between two rectangles.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
|
|
|
|
|
+ box2 (array-like): [x1, y1, x2, y2] of the second rectangle.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ float: The IoU between the two rectangles.
|
|
|
|
|
+ """
|
|
|
|
|
+ # Determine the coordinates of the intersection rectangle
|
|
|
|
|
+ x_left = max(box1[0], box2[0])
|
|
|
|
|
+ y_top = max(box1[1], box2[1])
|
|
|
|
|
+ x_right = min(box1[2], box2[2])
|
|
|
|
|
+ y_bottom = min(box1[3], box2[3])
|
|
|
|
|
+ if x_right <= x_left or y_bottom <= y_top:
|
|
|
|
|
+ return 0.0
|
|
|
|
|
+ # Calculate the area of intersection rectangle
|
|
|
|
|
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
|
|
|
|
|
+ # Calculate the area of both rectangles
|
|
|
|
|
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
|
|
|
|
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
|
|
|
|
+ # Calculate the IoU
|
|
|
|
|
+ iou = intersection_area / float(box1_area + box2_area - intersection_area)
|
|
|
|
|
+ return iou
|
|
|
|
|
+
|
|
|
|
|
+ # Function to combine rectangles into N rectangles
|
|
|
|
|
+ def combine_rectangles(rectangles, N):
|
|
|
|
|
+ """
|
|
|
|
|
+ Combine rectangles into N rectangles based on geometric proximity.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
|
|
|
|
|
+ N (int): The desired number of combined rectangles.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ list of list of int: A list of N combined rectangles.
|
|
|
|
|
+ """
|
|
|
|
|
+ # Number of input rectangles
|
|
|
|
|
+ num_rects = len(rectangles)
|
|
|
|
|
+ # If N is greater than or equal to the number of rectangles, return the original rectangles
|
|
|
|
|
+ if N >= num_rects:
|
|
|
|
|
+ return rectangles
|
|
|
|
|
+ # Compute the center points of the rectangles
|
|
|
|
|
+ centers = np.array([
|
|
|
|
|
+ [
|
|
|
|
|
+ (rect[0] + rect[2]) / 2, # Center x-coordinate
|
|
|
|
|
+ (rect[1] + rect[3]) / 2 # Center y-coordinate
|
|
|
|
|
+ ]
|
|
|
|
|
+ for rect in rectangles
|
|
|
|
|
+ ])
|
|
|
|
|
+ # Perform KMeans clustering on the center points to group them into N clusters
|
|
|
|
|
+ kmeans = KMeans(n_clusters=N, random_state=0, n_init='auto')
|
|
|
|
|
+ labels = kmeans.fit_predict(centers)
|
|
|
|
|
+ # Initialize a list to store the combined rectangles
|
|
|
|
|
+ combined_rectangles = []
|
|
|
|
|
+ # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
|
|
|
|
|
+ for i in range(N):
|
|
|
|
|
+ # Get the indices of rectangles that belong to cluster i
|
|
|
|
|
+ indices = np.where(labels == i)[0]
|
|
|
|
|
+ if len(indices) == 0:
|
|
|
|
|
+ # If no rectangles in this cluster, skip it
|
|
|
|
|
+ continue
|
|
|
|
|
+ # Extract the rectangles in cluster i
|
|
|
|
|
+ cluster_rects = np.array([rectangles[idx] for idx in indices])
|
|
|
|
|
+ # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
|
|
|
|
|
+ x1_min = np.min(cluster_rects[:, 0])
|
|
|
|
|
+ y1_min = np.min(cluster_rects[:, 1])
|
|
|
|
|
+ x2_max = np.max(cluster_rects[:, 2])
|
|
|
|
|
+ y2_max = np.max(cluster_rects[:, 3])
|
|
|
|
|
+ # Append the combined rectangle to the list
|
|
|
|
|
+ combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
|
|
|
|
|
+ return combined_rectangles
|
|
|
|
|
+
|
|
|
|
|
+ # Ensure that the inputs are numpy arrays for efficient computation
|
|
|
|
|
+ cells_det_results = np.array(cells_det_results)
|
|
|
|
|
+ cells_det_scores = np.array(cells_det_scores)
|
|
|
|
|
+ ocr_det_results = np.array(ocr_det_results)
|
|
|
|
|
+ if len(cells_det_results) == html_pred_boxes_nums:
|
|
|
|
|
+ return cells_det_results
|
|
|
|
|
+ # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
|
|
|
|
|
+ elif len(cells_det_results) > html_pred_boxes_nums:
|
|
|
|
|
+ return combine_rectangles(cells_det_results, html_pred_boxes_nums)
|
|
|
|
|
+ else:
|
|
|
|
|
+ # return cells_det_results
|
|
|
|
|
+ # Threshold for IoU
|
|
|
|
|
+ iou_threshold = 0.1
|
|
|
|
|
+ # List to store ocr_miss_boxes
|
|
|
|
|
+ ocr_miss_boxes = []
|
|
|
|
|
+ # For each rectangle in ocr_det_results
|
|
|
|
|
+ for ocr_rect in ocr_det_results:
|
|
|
|
|
+ # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
|
|
|
|
|
+ has_large_iou = False
|
|
|
|
|
+ # For each rectangle in cells_det_results
|
|
|
|
|
+ for cell_rect in cells_det_results:
|
|
|
|
|
+ # Compute IoU
|
|
|
|
|
+ iou = compute_iou(ocr_rect, cell_rect)
|
|
|
|
|
+ if iou >= iou_threshold:
|
|
|
|
|
+ has_large_iou = True
|
|
|
|
|
+ break
|
|
|
|
|
+ if not has_large_iou:
|
|
|
|
|
+ ocr_miss_boxes.append(ocr_rect)
|
|
|
|
|
+ # If no ocr_miss_boxes, return cells_det_results
|
|
|
|
|
+ if len(ocr_miss_boxes) == 0:
|
|
|
|
|
+ return cells_det_results.tolist()
|
|
|
|
|
+ else:
|
|
|
|
|
+ # Need to combine ocr_miss_boxes into N rectangles
|
|
|
|
|
+ N = html_pred_boxes_nums - len(cells_det_results)
|
|
|
|
|
+ if len(ocr_miss_boxes) == N:
|
|
|
|
|
+ return cells_det_results.tolist() + ocr_miss_boxes
|
|
|
|
|
+ else:
|
|
|
|
|
+ # Combine ocr_miss_boxes into N rectangles
|
|
|
|
|
+ ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
|
|
|
|
|
+ # Combine cells_det_results and ocr_supp_boxes
|
|
|
|
|
+ final_results = np.concatenate((cells_det_results, ocr_supp_boxes), axis=0)
|
|
|
|
|
+ return final_results.tolist()
|
|
|
|
|
|
|
|
def predict_single_table_recognition_res(
|
|
def predict_single_table_recognition_res(
|
|
|
self,
|
|
self,
|
|
@@ -295,16 +515,25 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
table_cls_result = self.extract_results(table_cls_pred, "cls")
|
|
table_cls_result = self.extract_results(table_cls_pred, "cls")
|
|
|
if table_cls_result == "wired_table":
|
|
if table_cls_result == "wired_table":
|
|
|
table_structure_pred = next(self.wired_table_rec_model(image_array))
|
|
table_structure_pred = next(self.wired_table_rec_model(image_array))
|
|
|
- table_cells_pred = next(self.wired_table_cells_detection_model(image_array))
|
|
|
|
|
|
|
+ table_cells_pred = next(
|
|
|
|
|
+ self.wired_table_cells_detection_model(image_array, threshold=0.3)
|
|
|
|
|
+ ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
|
|
|
|
|
+ # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
|
|
|
elif table_cls_result == "wireless_table":
|
|
elif table_cls_result == "wireless_table":
|
|
|
table_structure_pred = next(self.wireless_table_rec_model(image_array))
|
|
table_structure_pred = next(self.wireless_table_rec_model(image_array))
|
|
|
table_cells_pred = next(
|
|
table_cells_pred = next(
|
|
|
- self.wireless_table_cells_detection_model(image_array)
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ self.wireless_table_cells_detection_model(image_array, threshold=0.3)
|
|
|
|
|
+ ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
|
|
|
|
|
+ # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
|
|
|
table_structure_result = self.extract_results(
|
|
table_structure_result = self.extract_results(
|
|
|
table_structure_pred, "table_stru"
|
|
table_structure_pred, "table_stru"
|
|
|
)
|
|
)
|
|
|
- table_cells_result = self.extract_results(table_cells_pred, "det")
|
|
|
|
|
|
|
+ table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
|
|
|
|
|
+ table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
|
|
|
|
|
+ ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
|
|
|
|
|
+ table_cells_result = self.cells_det_results_reprocessing(
|
|
|
|
|
+ table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
|
|
|
|
|
+ )
|
|
|
single_table_recognition_res = get_table_recognition_res(
|
|
single_table_recognition_res = get_table_recognition_res(
|
|
|
table_box, table_structure_result, table_cells_result, overall_ocr_res
|
|
table_box, table_structure_result, table_cells_result, overall_ocr_res
|
|
|
)
|
|
)
|
|
@@ -316,7 +545,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
if len(match_idx_list) > 0:
|
|
if len(match_idx_list) > 0:
|
|
|
for idx in match_idx_list:
|
|
for idx in match_idx_list:
|
|
|
neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
|
|
neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
|
|
|
- single_table_recognition_res["neighbor_text"] = neighbor_text
|
|
|
|
|
|
|
+ single_table_recognition_res["neighbor_texts"] = neighbor_text
|
|
|
return single_table_recognition_res
|
|
return single_table_recognition_res
|
|
|
|
|
|
|
|
def predict(
|
|
def predict(
|