@@ -28,6 +28,7 @@ from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
+from .._parallel import AutoParallelImageSimpleInferencePipeline
 from ..base import BasePipeline
 from ..components import CropByBoxes
 from ..doc_preprocessor.result import DocPreprocessorResult
@@ -43,12 +44,9 @@ if is_dep_available("scikit-learn"):
     from sklearn.cluster import KMeans


-@pipeline_requires_extra("ocr")
-class TableRecognitionPipelineV2(BasePipeline):
+class _TableRecognitionPipelineV2(BasePipeline):
     """Table Recognition Pipeline"""

-    entities = ["table_recognition_v2"]
-
     def __init__(
         self,
         config: Dict,
@@ -146,7 +144,7 @@ class TableRecognitionPipelineV2(BasePipeline):

         self._crop_by_boxes = CropByBoxes()

-        self.batch_sampler = ImageBatchSampler(batch_size=1)
+        self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
         self.img_reader = ReadImage(format="BGR")

     def get_model_settings(
@@ -192,7 +190,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         self,
         model_settings: Dict,
         overall_ocr_res: OCRResult,
-        layout_det_res: DetResult,
+        layout_det_res: Union[DetResult, List[DetResult]],
     ) -> bool:
         """
         Check if the input parameters are valid based on the initialized models.
@@ -201,7 +199,7 @@ class TableRecognitionPipelineV2(BasePipeline):
             model_settings (Dict): A dictionary containing input parameters.
             overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
                 The overall OCR result with convert_points_to_boxes information.
-            layout_det_res (DetResult): The layout detection result.
+            layout_det_res (Union[DetResult, List[DetResult]]): The layout detection result(s).
         Returns:
             bool: True if all required models are initialized according to input parameters, False otherwise.
         """
@@ -566,122 +564,131 @@ class TableRecognitionPipelineV2(BasePipeline):
         # Return the list of recognized texts from each cell.
         return texts_list

-    def predict_single_table_recognition_res(
+    def _predict(
         self,
-        image_array: np.ndarray,
-        overall_ocr_res: OCRResult,
-        table_box: list,
+        image_arrays: List[np.ndarray],
+        overall_ocr_results: List[OCRResult],
+        table_boxes: List[list],
         use_table_cells_ocr_results: bool = False,
         use_e2e_wired_table_rec_model: bool = False,
         use_e2e_wireless_table_rec_model: bool = False,
         flag_find_nei_text: bool = True,
-    ) -> SingleTableRecognitionResult:
+    ) -> List[SingleTableRecognitionResult]:
         """
-        Predict table recognition results from an image array, layout detection results, and OCR results.
+        Predict table recognition results from image arrays, layout detection results, and OCR results.

         Args:
-            image_array (np.ndarray): The input image represented as a numpy array.
-            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
-                The overall OCR results containing text recognition information.
-            table_box (list): The table box coordinates.
+            image_arrays (List[np.ndarray]): The input image arrays.
+            overall_ocr_results (List[OCRResult]): Overall OCR results obtained after running the OCR pipeline.
+                The overall OCR results contain text recognition information.
+            table_boxes (List[list]): The table box coordinates.
             use_table_cells_ocr_results (bool): whether to use OCR results with cells.
             use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
             use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
             flag_find_nei_text (bool): Whether to find neighboring text.
         Returns:
-            SingleTableRecognitionResult: single table recognition result.
+            List[SingleTableRecognitionResult]: Single table recognition results.
         """
+        # TODO: Batch inference

-        table_cls_pred = next(self.table_cls_model(image_array))
-        table_cls_result = self.extract_results(table_cls_pred, "cls")
-        use_e2e_model = False
+        results = []

-        if table_cls_result == "wired_table":
-            table_structure_pred = next(self.wired_table_rec_model(image_array))
-            if use_e2e_wired_table_rec_model == True:
-                use_e2e_model = True
-            else:
-                table_cells_pred = next(
-                    self.wired_table_cells_detection_model(image_array, threshold=0.3)
-                ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
-                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
-        elif table_cls_result == "wireless_table":
-            table_structure_pred = next(self.wireless_table_rec_model(image_array))
-            if use_e2e_wireless_table_rec_model == True:
-                use_e2e_model = True
-            else:
-                table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(
-                        image_array, threshold=0.3
+        for image_array, overall_ocr_res, table_box in zip(
+            image_arrays, overall_ocr_results, table_boxes
+        ):
+            table_cls_pred = next(self.table_cls_model(image_array))
+            table_cls_result = self.extract_results(table_cls_pred, "cls")
+            use_e2e_model = False
+
+            if table_cls_result == "wired_table":
+                table_structure_pred = next(self.wired_table_rec_model(image_array))
+                if use_e2e_wired_table_rec_model == True:
+                    use_e2e_model = True
+                else:
+                    table_cells_pred = next(
+                        self.wired_table_cells_detection_model(
+                            image_array, threshold=0.3
+                        )
+                    ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                    # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
+            elif table_cls_result == "wireless_table":
+                table_structure_pred = next(self.wireless_table_rec_model(image_array))
+                if use_e2e_wireless_table_rec_model == True:
+                    use_e2e_model = True
+                else:
+                    table_cells_pred = next(
+                        self.wireless_table_cells_detection_model(
+                            image_array, threshold=0.3
+                        )
+                    ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                    # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
+            if use_e2e_model == False:
+                table_structure_result = self.extract_results(
+                    table_structure_pred, "table_stru"
+                )
+                table_cells_result, table_cells_score = self.extract_results(
+                    table_cells_pred, "det"
+                )
+                table_cells_result, table_cells_score = self.cells_det_results_nms(
+                    table_cells_result, table_cells_score
+                )
+                ocr_det_boxes = self.get_region_ocr_det_boxes(
+                    overall_ocr_res["rec_boxes"].tolist(), table_box
+                )
+                table_cells_result = self.cells_det_results_reprocessing(
+                    table_cells_result,
+                    table_cells_score,
+                    ocr_det_boxes,
+                    len(table_structure_pred["bbox"]),
+                )
+                if use_table_cells_ocr_results == True:
+                    cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                        image_array, table_cells_result
                     )
-                ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
-                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
-
-        if use_e2e_model == False:
-            table_structure_result = self.extract_results(
-                table_structure_pred, "table_stru"
-            )
-            table_cells_result, table_cells_score = self.extract_results(
-                table_cells_pred, "det"
-            )
-            table_cells_result, table_cells_score = self.cells_det_results_nms(
-                table_cells_result, table_cells_score
-            )
-            ocr_det_boxes = self.get_region_ocr_det_boxes(
-                overall_ocr_res["rec_boxes"].tolist(), table_box
-            )
-            table_cells_result = self.cells_det_results_reprocessing(
-                table_cells_result,
-                table_cells_score,
-                ocr_det_boxes,
-                len(table_structure_pred["bbox"]),
-            )
-            if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result
+                else:
+                    cells_texts_list = []
+                single_table_recognition_res = get_table_recognition_res(
+                    table_box,
+                    table_structure_result,
+                    table_cells_result,
+                    overall_ocr_res,
+                    cells_texts_list,
+                    use_table_cells_ocr_results,
                 )
             else:
-                cells_texts_list = []
-            single_table_recognition_res = get_table_recognition_res(
-                table_box,
-                table_structure_result,
-                table_cells_result,
-                overall_ocr_res,
-                cells_texts_list,
-                use_table_cells_ocr_results,
-            )
-        else:
-            if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(
-                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
+                if use_table_cells_ocr_results == True:
+                    table_cells_result_e2e = list(
+                        map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
+                    )
+                    table_cells_result_e2e = [
+                        [rect[0], rect[1], rect[4], rect[5]]
+                        for rect in table_cells_result_e2e
+                    ]
+                    cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                        image_array, table_cells_result_e2e
+                    )
+                else:
+                    cells_texts_list = []
+                single_table_recognition_res = get_table_recognition_res_e2e(
+                    table_box,
+                    table_structure_pred,
+                    overall_ocr_res,
+                    cells_texts_list,
+                    use_table_cells_ocr_results,
                 )
-                table_cells_result_e2e = [
-                    [rect[0], rect[1], rect[4], rect[5]]
-                    for rect in table_cells_result_e2e
-                ]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result_e2e
+
+            neighbor_text = ""
+            if flag_find_nei_text:
+                match_idx_list = get_neighbor_boxes_idx(
+                    overall_ocr_res["rec_boxes"], table_box
                 )
-            else:
-                cells_texts_list = []
-            single_table_recognition_res = get_table_recognition_res_e2e(
-                table_box,
-                table_structure_pred,
-                overall_ocr_res,
-                cells_texts_list,
-                use_table_cells_ocr_results,
-            )
+                if len(match_idx_list) > 0:
+                    for idx in match_idx_list:
+                        neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
+                single_table_recognition_res["neighbor_texts"] = neighbor_text
+            results.append(single_table_recognition_res)

-        neighbor_text = ""
-        if flag_find_nei_text:
-            match_idx_list = get_neighbor_boxes_idx(
-                overall_ocr_res["rec_boxes"], table_box
-            )
-            if len(match_idx_list) > 0:
-                for idx in match_idx_list:
-                    neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
-            single_table_recognition_res["neighbor_texts"] = neighbor_text
-        return single_table_recognition_res
+        return results

     def predict(
         self,
@@ -690,8 +697,8 @@ class TableRecognitionPipelineV2(BasePipeline):
         use_doc_unwarping: Optional[bool] = None,
         use_layout_detection: Optional[bool] = None,
         use_ocr_model: Optional[bool] = None,
-        overall_ocr_res: Optional[OCRResult] = None,
-        layout_det_res: Optional[DetResult] = None,
+        overall_ocr_res: Optional[Union[OCRResult, List[OCRResult]]] = None,
+        layout_det_res: Optional[Union[DetResult, List[DetResult]]] = None,
         text_det_limit_side_len: Optional[int] = None,
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
@@ -711,9 +718,9 @@ class TableRecognitionPipelineV2(BasePipeline):
             use_layout_detection (bool): Whether to use layout detection.
             use_doc_orientation_classify (bool): Whether to use document orientation classification.
             use_doc_unwarping (bool): Whether to use document unwarping.
-            overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
+            overall_ocr_res (Union[OCRResult, List[OCRResult]]): The overall OCR results with convert_points_to_boxes information.
                 It will be used if it is not None and use_ocr_model is False.
-            layout_det_res (DetResult): The layout detection result.
+            layout_det_res (Union[DetResult, List[DetResult]]): The layout detection result(s).
                 It will be used if it is not None and use_layout_detection is False.
             use_table_cells_ocr_results (bool): whether to use OCR results with cells.
             use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
@@ -737,26 +744,40 @@ class TableRecognitionPipelineV2(BasePipeline):
         ):
             yield {"error": "the input params for model settings are invalid!"}

-        for img_id, batch_data in enumerate(self.batch_sampler(input)):
-            image_array = self.img_reader(batch_data.instances)[0]
+        external_overall_ocr_results = overall_ocr_res
+        if external_overall_ocr_results is not None:
+            if not isinstance(external_overall_ocr_results, list):
+                external_overall_ocr_results = [external_overall_ocr_results]
+            external_overall_ocr_results = iter(external_overall_ocr_results)
+
+        external_layout_det_results = layout_det_res
+        if external_layout_det_results is not None:
+            if not isinstance(external_layout_det_results, list):
+                external_layout_det_results = [external_layout_det_results]
+            external_layout_det_results = iter(external_layout_det_results)
+
+        for _, batch_data in enumerate(self.batch_sampler(input)):
+            image_arrays = self.img_reader(batch_data.instances)

             if model_settings["use_doc_preprocessor"]:
-                doc_preprocessor_res = next(
+                doc_preprocessor_results = list(
                     self.doc_preprocessor_pipeline(
-                        image_array,
+                        image_arrays,
                         use_doc_orientation_classify=use_doc_orientation_classify,
                         use_doc_unwarping=use_doc_unwarping,
                     )
                 )
             else:
-                doc_preprocessor_res = {"output_img": image_array}
+                doc_preprocessor_results = [{"output_img": arr} for arr in image_arrays]

-            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+            doc_preprocessor_images = [
+                item["output_img"] for item in doc_preprocessor_results
+            ]

             if model_settings["use_ocr_model"]:
-                overall_ocr_res = next(
+                overall_ocr_results = list(
                     self.general_ocr_pipeline(
-                        doc_preprocessor_image,
+                        doc_preprocessor_images,
                         text_det_limit_side_len=text_det_limit_side_len,
                         text_det_limit_type=text_det_limit_type,
                         text_det_thresh=text_det_thresh,
@@ -765,60 +786,131 @@ class TableRecognitionPipelineV2(BasePipeline):
                         text_rec_score_thresh=text_rec_score_thresh,
                     )
                 )
-            elif use_table_cells_ocr_results == True:
-                assert self.general_ocr_config_bak != None
-                self.general_ocr_pipeline = self.create_pipeline(
-                    self.general_ocr_config_bak
-                )
+            else:
+                overall_ocr_results = []
+                for _ in doc_preprocessor_images:
+                    try:
+                        overall_ocr_res = next(external_overall_ocr_results)
+                    except StopIteration:
+                        raise ValueError("No more overall OCR results")
+                    overall_ocr_results.append(overall_ocr_res)
+
+            if use_table_cells_ocr_results:
+                # FIXME: This creates a new pipeline on each call.
+                assert self.general_ocr_config_bak is not None
+                self.general_ocr_pipeline = self.create_pipeline(
+                    self.general_ocr_config_bak
+                )

-            table_res_list = []
-            table_region_id = 1
-            if not model_settings["use_layout_detection"] and layout_det_res is None:
-                layout_det_res = {}
-                img_height, img_width = doc_preprocessor_image.shape[:2]
-                table_box = [0, 0, img_width - 1, img_height - 1]
-                single_table_rec_res = self.predict_single_table_recognition_res(
-                    doc_preprocessor_image,
-                    overall_ocr_res,
-                    table_box,
+            if (
+                not model_settings["use_layout_detection"]
+                and external_layout_det_results is None
+            ):
+                layout_det_results = [{} for _ in doc_preprocessor_images]
+
+                table_boxes = []
+                for img in doc_preprocessor_images:
+                    img_height, img_width = img.shape[:2]
+                    table_box = [0, 0, img_width - 1, img_height - 1]
+                    table_boxes.append(table_box)
+
+                flat_table_results = self._predict(
+                    doc_preprocessor_images,
+                    overall_ocr_results,
+                    table_boxes,
                     use_table_cells_ocr_results,
                     use_e2e_wired_table_rec_model,
                     use_e2e_wireless_table_rec_model,
                     flag_find_nei_text=False,
                 )
-                single_table_rec_res["table_region_id"] = table_region_id
-                table_res_list.append(single_table_rec_res)
-                table_region_id += 1
+
+                for table_res in flat_table_results:
+                    table_res["table_region_id"] = 1
+                table_results = [[item] for item in flat_table_results]
             else:
                 if model_settings["use_layout_detection"]:
-                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
-
-                for box_info in layout_det_res["boxes"]:
-                    if box_info["label"].lower() in ["table"]:
-                        crop_img_info = self._crop_by_boxes(image_array, [box_info])
-                        crop_img_info = crop_img_info[0]
-                        table_box = crop_img_info["box"]
-                        single_table_rec_res = (
-                            self.predict_single_table_recognition_res(
-                                crop_img_info["img"],
-                                overall_ocr_res,
-                                table_box,
-                                use_table_cells_ocr_results,
-                                use_e2e_wired_table_rec_model,
-                                use_e2e_wireless_table_rec_model,
-                            )
-                        )
-                        single_table_rec_res["table_region_id"] = table_region_id
-                        table_res_list.append(single_table_rec_res)
+                    layout_det_results = list(
+                        self.layout_det_model(doc_preprocessor_images)
+                    )
+                else:
+                    layout_det_results = []
+                    for _ in doc_preprocessor_images:
+                        try:
+                            layout_det_res = next(external_layout_det_results)
+                        except StopIteration:
+                            raise ValueError("No more layout det results")
+                        layout_det_results.append(layout_det_res)
+
+                cropped_imgs = []
+                table_boxes = []
+                repeated_overall_ocr_results = []
+                chunk_indices = [0]
+                for image_array, layout_det_res, overall_ocr_res in zip(
+                    image_arrays, layout_det_results, overall_ocr_results
+                ):
+                    for box_info in layout_det_res["boxes"]:
+                        if box_info["label"].lower() in ["table"]:
+                            crop_img_info = self._crop_by_boxes(image_array, [box_info])
+                            crop_img_info = crop_img_info[0]
+                            cropped_imgs.append(crop_img_info["img"])
+                            table_boxes.append(crop_img_info["box"])
+                            repeated_overall_ocr_results.append(overall_ocr_res)
+                    chunk_indices.append(len(cropped_imgs))
+
+                flat_table_results = self._predict(
+                    cropped_imgs,
+                    repeated_overall_ocr_results,
+                    table_boxes,
+                    use_table_cells_ocr_results,
+                    use_e2e_wired_table_rec_model,
+                    use_e2e_wireless_table_rec_model,
+                )
+
+                table_results = [
+                    flat_table_results[i:j]
+                    for i, j in zip(chunk_indices[:-1], chunk_indices[1:])
+                ]
+
+                for table_results_for_img in table_results:
+                    table_region_id = 1
+                    for table_res in table_results_for_img:
+                        table_res["table_region_id"] = table_region_id
                         table_region_id += 1

-            single_img_res = {
-                "input_path": batch_data.input_paths[0],
-                "page_index": batch_data.page_indexes[0],
-                "doc_preprocessor_res": doc_preprocessor_res,
-                "layout_det_res": layout_det_res,
-                "overall_ocr_res": overall_ocr_res,
-                "table_res_list": table_res_list,
-                "model_settings": model_settings,
-            }
-            yield TableRecognitionResult(single_img_res)
+            for (
+                input_path,
+                page_index,
+                doc_preprocessor_res,
+                layout_det_res,
+                overall_ocr_res,
+                table_results_for_img,
+            ) in zip(
+                batch_data.input_paths,
+                batch_data.page_indexes,
+                doc_preprocessor_results,
+                layout_det_results,
+                overall_ocr_results,
+                table_results,
+            ):
+                single_img_res = {
+                    "input_path": input_path,
+                    "page_index": page_index,
+                    "doc_preprocessor_res": doc_preprocessor_res,
+                    "layout_det_res": layout_det_res,
+                    "overall_ocr_res": overall_ocr_res,
+                    "table_res_list": table_results_for_img,
+                    "model_settings": model_settings,
+                }
+                yield TableRecognitionResult(single_img_res)
+
+
+@pipeline_requires_extra("ocr")
+class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline):
+    entities = ["table_recognition_v2"]
+
+    @property
+    def _pipeline_cls(self):
+        return _TableRecognitionPipelineV2
+
+    def _get_batch_size(self, config):
+        return config.get("batch_size", 1)