zhangyubo0722 1 жил өмнө
parent
commit
860f4bb860

+ 3 - 3
paddlex/inference/components/paddle_predictor/predictor.py

@@ -172,10 +172,10 @@ No need to generate again."
 
 
 class ImagePredictor(BasePaddlePredictor):
-    DEAULT_INPUTS = {"batch_data": "img"}
+    DEAULT_INPUTS = {"img": "img"}
 
-    def to_batch(self, imgs):
-        return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]
+    def to_batch(self, img):
+        return [np.stack(img, axis=0).astype(dtype=np.float32, copy=False)]
 
 
 class ImageDetPredictor(BasePaddlePredictor):

+ 1 - 0
paddlex/inference/pipelines/__init__.py

@@ -16,3 +16,4 @@ from .image_classification import ClasPipeline
 from .ocr import OCRPipeline
 from .object_detection import DetPipeline
 from .instance_segmentation import InstanceSegPipeline
+from .semantic_segmentation import SegPipeline

+ 33 - 0
paddlex/inference/pipelines/semantic_segmentation.py

@@ -0,0 +1,33 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BasePipeline
+from ..predictors import create_predictor
+
+
+class SegPipeline(BasePipeline):
+    """Det Pipeline"""
+
+    entities = "semantic_segmentation"
+
+    def __init__(self, model, batch_size=1, device="gpu"):
+        super().__init__()
+        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+
+    def predict(self, x):
+        self._check_input(x)
+        yield from self._predict(x)
+
+    def _check_input(self, x):
+        pass

+ 1 - 0
paddlex/inference/predictors/__init__.py

@@ -22,6 +22,7 @@ from .text_recognition import TextRecPredictor
 from .table_recognition import TablePredictor
 from .object_detection import DetPredictor
 from .instance_segmentation import InstanceSegPredictor
+from .semantic_segmentation import SegPredictor
 from .official_models import official_models
 
 

+ 2 - 2
paddlex/inference/predictors/instance_segmentation.py

@@ -18,7 +18,7 @@ from .object_detection import DetPredictor
 from ...utils.func_register import FuncRegister
 from ...modules.instance_segmentation.model_list import MODELS
 from ..components import *
-from ..results import InstanceSegResults
+from ..results import InstanceSegResult
 from ..utils.process_hook import batchable_method
 
 
@@ -56,4 +56,4 @@ class InstanceSegPredictor(DetPredictor):
     @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "boxes", "masks", "labels"]
-        return {"result": InstanceSegResults({key: data[key] for key in keys})}
+        return {"result": InstanceSegResult({key: data[key] for key in keys})}

+ 2 - 2
paddlex/inference/predictors/object_detection.py

@@ -17,7 +17,7 @@ import numpy as np
 from ...utils.func_register import FuncRegister
 from ...modules.object_detection.model_list import MODELS
 from ..components import *
-from ..results import DetResults
+from ..results import DetResult
 from ..utils.process_hook import batchable_method
 from .base import BasicPredictor
 
@@ -94,4 +94,4 @@ class DetPredictor(BasicPredictor):
     @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "boxes", "labels"]
-        return {"result": DetResults({key: data[key] for key in keys})}
+        return {"result": DetResult({key: data[key] for key in keys})}

+ 96 - 0
paddlex/inference/predictors/semantic_segmentation.py

@@ -0,0 +1,96 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...utils.func_register import FuncRegister
+from ...modules.semantic_segmentation.model_list import MODELS
+from ..components import *
+from ..results import SegResult
+from ..utils.process_hook import batchable_method
+from .base import BasicPredictor
+
+
+class SegPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def _check_args(self, kwargs):
+        assert set(kwargs.keys()).issubset(set(["batch_size"]))
+        return kwargs
+
+    def _build_components(self):
+        ops = {}
+        ops["ReadImage"] = ReadImage(
+            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
+        )
+        ops["ToCHWImage"] = ToCHWImage()
+        for cfg in self.config["Deploy"]["transforms"]:
+            tf_key = cfg["type"]
+            func = self._FUNC_MAP.get(tf_key)
+            cfg.pop("type")
+            args = cfg
+            op = func(self, **args) if args else func(self)
+            ops[tf_key] = op
+
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        ops["predictor"] = predictor
+        return ops
+
+    @register("Resize")
+    def build_resize(
+        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
+    ):
+        assert target_size
+        op = Resize(
+            target_size=target_size,
+            keep_ratio=keep_ratio,
+            size_divisor=size_divisor,
+            interp=interp,
+        )
+        return op
+
+    @register("ResizeByLong")
+    def build_resizebylong(self, long_size):
+        assert long_size
+        return ResizeByLong(
+            target_long_edge=long_size, size_divisor=size_divisor, interp=interp
+        )
+
+    @register("ResizeByShort")
+    def build_resizebylong(self, short_size):
+        assert short_size
+        return ResizeByLong(
+            target_long_edge=short_size, size_divisor=size_divisor, interp=interp
+        )
+
+    @register("Normalize")
+    def build_normalize(
+        self,
+        mean=0.5,
+        std=0.5,
+    ):
+        return Normalize(mean=mean, std=std)
+
+    @batchable_method
+    def _pack_res(self, data):
+        keys = ["img_path", "pred"]
+        return {"result": SegResult({key: data[key] for key in keys})}

+ 3 - 2
paddlex/inference/results/__init__.py

@@ -17,5 +17,6 @@ from .text_det import TextDetResult
 from .text_rec import TextRecResult
 from .table_rec import TableRecResult
 from .ocr import OCRResult
-from .det import DetResults
-from .instance_seg import InstanceSegResults
+from .det import DetResult
+from .seg import SegResult
+from .instance_seg import InstanceSegResult

+ 1 - 1
paddlex/inference/results/det.py

@@ -81,7 +81,7 @@ def draw_box(img, np_boxes, labels):
     return img
 
 
-class DetResults(BaseResult):
+class DetResult(BaseResult):
     """Save Result Transform"""
 
     def __init__(self, data):

+ 1 - 1
paddlex/inference/results/instance_seg.py

@@ -60,7 +60,7 @@ def draw_mask(im, np_boxes, np_masks, labels):
     return Image.fromarray(im.astype("uint8"))
 
 
-class InstanceSegResults(BaseResult):
+class InstanceSegResult(BaseResult):
     """Save Result Transform"""
 
     def __init__(self, data):

+ 65 - 0
paddlex/inference/results/seg.py

@@ -0,0 +1,65 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import PIL
+from PIL import Image
+
+from .base import BaseResult
+
+
+class SegResult(BaseResult):
+    """Save Result Transform"""
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.data = data
+        # We use pillow backend to save both numpy arrays and PIL Image objects
+        self._img_writer.set_backend("pillow", format_="PNG")
+
+    def _get_res_img(self):
+        """apply"""
+        seg_map = self.data["pred"]
+        pc_map = self.get_pseudo_color_map(seg_map[0])
+        return pc_map
+
+    def get_pseudo_color_map(self, pred):
+        """get_pseudo_color_map"""
+        if pred.min() < 0 or pred.max() > 255:
+            raise ValueError("`pred` cannot be cast to uint8.")
+        pred = pred.astype(np.uint8)
+        pred_mask = Image.fromarray(pred, mode="P")
+        color_map = self._get_color_map_list(256)
+        pred_mask.putpalette(color_map)
+        return pred_mask
+
+    @staticmethod
+    def _get_color_map_list(num_classes, custom_color=None):
+        """_get_color_map_list"""
+        num_classes += 1
+        color_map = num_classes * [0, 0, 0]
+        for i in range(0, num_classes):
+            j = 0
+            lab = i
+            while lab:
+                color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
+                color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
+                color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
+                j += 1
+                lab >>= 3
+        color_map = color_map[3:]
+
+        if custom_color:
+            color_map[: len(custom_color)] = custom_color
+        return color_map