Browse Source

add training code

warrentdrew 10 months ago
parent
commit
12c154c096

+ 38 - 0
paddlex/configs/modules/3d_bev_detection/BEVFusion.yaml

@@ -0,0 +1,38 @@
+Global:
+  model: BEVFusion
+  mode: check_dataset
+  dataset_dir: "/paddle/dataset/paddlex/3d/nuscenes_demo"
+  device: gpu:0,1,2,3
+  output: "output"
+  load_cam_from: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/BEVFusion_camera_pretrained.pdparams"
+  load_lidar_from: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/BEVFusion_lidar_pretrained.pdparams"
+  datart_prefix: True
+  version: "mini"
+
+CheckDataset:
+  convert:
+    enable: False
+  split:
+    enable: False
+
+Train:
+  epochs_iters: 2
+  batch_size: 2
+  learning_rate: 0.001
+  warmup_steps: 150
+
+Evaluate:
+  batch_size: 1
+  weight_path: output/train/best_model/model.pdparams
+
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/BEVFusion_pretrained.pdparams"
+
+
+Predict:
+  batch_size: 1
+  model_dir: "output_bevfusion"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/det_3d/demo_det_3d/nuscenes_infos_val.pkl"
+  kernel_option:
+    run_mode: paddle

+ 18 - 0
paddlex/modules/3d_bev_detection/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .trainer import BEVFusionTrainer
+from .dataset_checker import BEVFusionDatasetChecker
+from .evaluator import BEVFusionEvaluator
+from .exportor import BEVFusionExportor

+ 95 - 0
paddlex/modules/3d_bev_detection/dataset_checker/__init__.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import pickle
+
+from ...base import BaseDatasetChecker
+from .dataset_src import check, deep_analyse
+from ..model_list import MODELS
+
+
+class BEVFusionDatasetChecker(BaseDatasetChecker):
+    entities = MODELS
+
+    def check_dataset(self, dataset_dir: str) -> dict:
+        """check if the dataset meets the specifications and get dataset summary
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+            sample_num (int): the number to be sampled.
+        Returns:
+            dict: dataset summary.
+        """
+        return check(dataset_dir)
+
+    def analyse(self, dataset_dir: str) -> dict:
+        """deep analyse dataset
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            dict: the deep analysis results.
+        """
+        return deep_analyse(dataset_dir, self.output)
+
+    def get_data(self, ann_file, max_sample_num):
+        infos = self.data_infos(ann_file, max_sample_num)
+        meta = []
+        for info in infos:
+            image_paths = []
+            cam_orders = [
+                "CAM_FRONT_LEFT",
+                "CAM_FRONT",
+                "CAM_FRONT_RIGHT",
+                "CAM_BACK_RIGHT",
+                "CAM_BACK",
+                "CAM_BACK_LEFT",
+            ]
+            for cam_type in cam_orders:
+                cam_info = info["cams"][cam_type]
+                cam_data_path = cam_info["data_path"]
+                image_paths.append(cam_data_path)
+
+            meta.append(
+                {
+                    "sample_idx": info["token"],
+                    "lidar_path": info["lidar_path"],
+                    "image_paths": image_paths,
+                }
+            )
+        return meta
+
+    def data_infos(self, ann_file, max_sample_num):
+        data = pickle.load(open(ann_file, "rb"))
+        data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
+        data_infos = data_infos[:max_sample_num]
+        return data_infos
+
+    def get_show_type(self) -> str:
+        """get the show type of dataset
+
+        Returns:
+            str: show type
+        """
+        return "txt"
+
+    def get_dataset_type(self) -> str:
+        """return the dataset type
+
+        Returns:
+            str: dataset type
+        """
+        return "NuscenesMMDataset"

+ 17 - 0
paddlex/modules/3d_bev_detection/dataset_checker/dataset_src/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .check_dataset import check
+from .analyse_dataset import deep_analyse

+ 107 - 0
paddlex/modules/3d_bev_detection/dataset_checker/dataset_src/analyse_dataset.py

@@ -0,0 +1,107 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import json
+import platform
+import pickle
+
+from collections import defaultdict
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from nuscenes.nuscenes import NuScenes
+
+from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def deep_analyse(dataset_dir, output):
+    """class analysis for dataset"""
+    tags = ["train", "val"]
+    all_instances = 0
+    class_name_train = defaultdict(int)
+    class_name_val = defaultdict(int)
+    for tag in tags:
+        anno_file = os.path.join(dataset_dir, f"nuscenes_infos_{tag}.pkl")
+        with open(anno_file, "rb") as f:
+            datas = pickle.load(f)
+        data_infos = datas["infos"]
+        for item in data_infos:
+            gts = item["gt_names"]
+            for gt_name in gts:
+                if tag == "train":
+                    class_name_train[gt_name] = (
+                        0
+                        if gt_name not in class_name_train
+                        else class_name_train[gt_name] + 1
+                    )
+                elif tag == "val":
+                    class_name_val[gt_name] = (
+                        0
+                        if gt_name not in class_name_val
+                        else class_name_val[gt_name] + 1
+                    )
+
+    classes = set()
+    for key in class_name_train:
+        classes.add(key)
+    for key in class_name_val:
+        classes.add(key)
+
+    # set cnt to 0 if class not in cnt dict
+    for key in classes:
+        if key not in class_name_train:
+            class_name_train[key] = 0
+        if key not in class_name_val:
+            class_name_val[key] = 0
+
+    cnts_train = [cat_ids for cat_name, cat_ids in class_name_train.items()]
+    cnts_val = [cat_ids for cat_name, cat_ids in class_name_val.items()]
+
+    # sort class name
+    classes = [cat_name for cat_name, cat_ids in class_name_train.items()]
+    sorted_id = sorted(
+        range(len(cnts_train)), key=lambda k: cnts_train[k], reverse=True
+    )
+    cnts_train_sorted = [cnts_train[index] for index in sorted_id]
+    cnts_val_sorted = [cnts_val[index] for index in sorted_id]
+    classes_sorted = [classes[index] for index in sorted_id]
+
+    x = np.arange(len(classes))
+    width = 0.5
+
+    # bar
+    os_system = platform.system().lower()
+    if os_system == "windows":
+        plt.rcParams["font.sans-serif"] = "FangSong"
+    else:
+        font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH)
+    fig, ax = plt.subplots(figsize=(max(8, int(len(classes) / 5)), 5), dpi=120)
+    ax.bar(x, cnts_train_sorted, width=0.5, label="train")
+    ax.bar(x + width, cnts_val_sorted, width=0.5, label="val")
+    plt.xticks(
+        x + width / 2,
+        classes_sorted,
+        rotation=90,
+        fontproperties=None if os_system == "windows" else font,
+    )
+    ax.set_ylabel("Counts")
+    plt.legend()
+    fig.tight_layout()
+    fig_path = os.path.join(output, "histogram.png")
+    fig.savefig(fig_path)
+    return {"histogram": os.path.join("check_dataset", "histogram.png")}

+ 102 - 0
paddlex/modules/3d_bev_detection/dataset_checker/dataset_src/check_dataset.py

@@ -0,0 +1,102 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import pickle
+
+from .....utils.errors import DatasetFileNotFoundError
+from .....utils.misc import abspath
+
+
+def check(dataset_dir):
+    dataset_dir = abspath(dataset_dir)
+    max_sample_num = 5
+
+    if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+        raise DatasetFileNotFoundError(file_path=dataset_dir)
+
+    anno_file = osp.join(dataset_dir, "nuscenes_infos_train.pkl")
+    if not osp.exists(anno_file):
+        raise DatasetFileNotFoundError(file_path=anno_file)
+    train_mate, train_classes = get_data(anno_file, max_sample_num)
+
+    anno_file = osp.join(dataset_dir, "nuscenes_infos_val.pkl")
+    if not osp.exists(anno_file):
+        raise DatasetFileNotFoundError(file_path=anno_file)
+    val_mate, val_classes = get_data(anno_file, max_sample_num)
+    train_sample_paths = []
+    val_sample_paths = []
+
+    for item in train_mate:
+        train_sample_paths.append(item["lidar_path"])
+
+    for item in val_mate:
+        val_sample_paths.append(item["lidar_path"])
+
+    all_classes = set()
+    for tc in train_classes:
+        all_classes.add(tc)
+    for vc in val_classes:
+        all_classes.add(vc)
+    num_classes = len(all_classes)
+    meta = {
+        "num_classes": num_classes,
+        "train_meta": train_mate,
+        "val_meta": val_mate,
+        "train_sample_paths": train_sample_paths,
+        "val_sample_paths": val_sample_paths,
+    }
+    return meta
+
+
+def get_data(ann_file, max_sample_num):
+    infos = data_infos(ann_file, max_sample_num)
+    meta = []
+    gt_class = set()
+    for info in infos:
+        image_paths = []
+        cam_orders = [
+            "CAM_FRONT_LEFT",
+            "CAM_FRONT",
+            "CAM_FRONT_RIGHT",
+            "CAM_BACK_RIGHT",
+            "CAM_BACK",
+            "CAM_BACK_LEFT",
+        ]
+        for cam_type in cam_orders:
+            cam_info = info["cams"][cam_type]
+            cam_data_path = cam_info["data_path"]
+            image_paths.append(cam_data_path)
+
+        meta.append(
+            {
+                "sample_idx": info["token"],
+                "lidar_path": info["lidar_path"],
+                "image_paths": image_paths,
+            }
+        )
+        class_names = info["gt_names"]
+
+        for cls_name in class_names:
+            gt_class.add(cls_name)
+
+    return meta, gt_class
+
+
+def data_infos(ann_file, max_sample_num):
+    data = pickle.load(open(ann_file, "rb"))
+    data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
+    data_infos = data_infos[:max_sample_num]
+    return data_infos

+ 46 - 0
paddlex/modules/3d_bev_detection/evaluator.py

@@ -0,0 +1,46 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..base import BaseEvaluator
+from .model_list import MODELS
+
+
+class BEVFusionEvaluator(BaseEvaluator):
+    """Object Detection Model Evaluator"""
+
+    entities = MODELS
+
+    def update_config(self):
+        """update evalution config"""
+        if self.eval_config.batch_size is not None:
+            self.pdx_config.update_batch_size(self.eval_config.batch_size)
+        self.pdx_config.update_dataset(
+            self.global_config.dataset_dir,
+            self.global_config.datart_prefix,
+            "NuscenesMMDataset",
+            version=self.global_config.version,
+        )
+        self.pdx_config.update_weights(self.eval_config.weight_path)
+
+    def get_eval_kwargs(self) -> dict:
+        """get key-value arguments of model evalution function
+
+        Returns:
+            dict: the arguments of evaluation function.
+        """
+        return {
+            "weight_path": self.eval_config.weight_path,
+            "device": self.get_device(using_device_number=1),
+        }

+ 22 - 0
paddlex/modules/3d_bev_detection/exportor.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class BEVFusionExportor(BaseExportor):
+    """3D BEV Detection Model Exportor"""
+
+    entities = MODELS

+ 18 - 0
paddlex/modules/3d_bev_detection/model_list.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+MODELS = [
+    "BEVFusion",
+]

+ 70 - 0
paddlex/modules/3d_bev_detection/trainer.py

@@ -0,0 +1,70 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+
+from ..base import BaseTrainer
+from ...utils.config import AttrDict
+from ...utils import logging
+from .model_list import MODELS
+
+
+class BEVFusionTrainer(BaseTrainer):
+    """3D BEV Detection Model Trainer"""
+
+    entities = MODELS
+
+    def _update_dataset(self):
+        """update dataset settings"""
+        self.pdx_config.update_dataset(
+            self.global_config.dataset_dir,
+            self.global_config.get("datart_prefix", True),
+            "NuscenesMMDataset",
+            version=self.global_config.get("version", "mini"),
+        )
+
+    def _update_pretrained_model(self):
+        self.pdx_config.update_pretrained_model(
+            self.global_config.load_cam_from, self.global_config.load_lidar_from
+        )
+
+    def update_config(self):
+        """update training config"""
+        self._update_dataset()
+        self._update_pretrained_model()
+
+        if self.train_config.batch_size is not None:
+            self.pdx_config.update_batch_size(self.train_config.batch_size)
+        if self.train_config.learning_rate is not None:
+            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+        if self.train_config.epochs_iters is not None:
+            self.pdx_config.update_epochs(self.train_config.epochs_iters)
+            epochs_iters = self.train_config.epochs_iters
+        else:
+            epochs_iters = self.pdx_config.get_epochs_iters()
+        if self.global_config.output is not None:
+            self.pdx_config.update_save_dir(self.global_config.output)
+
+    def get_train_kwargs(self) -> dict:
+        """get key-value arguments of model training function
+
+        Returns:
+            dict: the arguments of training function.
+        """
+        train_args = {"device": self.get_device()}
+        train_args["dy2st"] = self.train_config.get("dy2st", False)
+        if self.global_config.output is not None:
+            train_args["save_dir"] = self.global_config.output
+        return train_args

+ 7 - 2
paddlex/modules/__init__.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
-
+from importlib import import_module
 
 
 from .base import (
 from .base import (
     build_dataset_checker,
     build_dataset_checker,
@@ -104,13 +104,18 @@ from .face_recognition import (
 
 
 from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator
 from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator
 
 
+module_3d_bev_detection = import_module(".3d_bev_detection", "paddlex.modules")
+BEVFusionDatasetChecker = getattr(module_3d_bev_detection, "BEVFusionDatasetChecker")
+BEVFusionTrainer = getattr(module_3d_bev_detection, "BEVFusionTrainer")
+BEVFusionEvaluator = getattr(module_3d_bev_detection, "BEVFusionEvaluator")
+BEVFusionExportor = getattr(module_3d_bev_detection, "BEVFusionExportor")
+
 from .keypoint_detection import (
 from .keypoint_detection import (
     KeypointDatasetChecker,
     KeypointDatasetChecker,
     KeypointTrainer,
     KeypointTrainer,
     KeypointEvaluator,
     KeypointEvaluator,
     KeypointExportor,
     KeypointExportor,
 )
 )
-
 from .video_classification import (
 from .video_classification import (
     VideoClsDatasetChecker,
     VideoClsDatasetChecker,
     VideoClsTrainer,
     VideoClsTrainer,

+ 17 - 0
paddlex/repo_apis/Paddle3D_api/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Register models and suites
+# BEVFusion 3D object detection
+from .bev_fusion import BEVFusionModel, BEVFusionRunner, register

+ 18 - 0
paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from . import register
+from .model import BEVFusionModel
+from .runner import BEVFusionRunner

+ 118 - 0
paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py

@@ -0,0 +1,118 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ....utils.misc import abspath
+from ..pp3d_config import PP3DConfig
+
+
+class BEVFusionConfig(PP3DConfig):
+    def update_dataset(
+        self, dataset_dir, datart_prefix=True, dataset_type=None, *, version=None
+    ):
+        dataset_dir = abspath(dataset_dir)
+        if dataset_type is None:
+            dataset_type = "NuscenesMMDataset"
+        if dataset_type == "NuscenesMMDataset":
+            ds_cfg = self._make_nuscenes_mm_dataset_config(
+                dataset_dir, datart_prefix, version=version
+            )
+        else:
+            raise ValueError(f"{dataset_type} is not supported.")
+        # Prune old config
+        keys_to_keep = ("transforms", "mode", "class_names", "modality")
+        if "train_dataset" in self:
+            for key in list(k for k in self.train_dataset if k not in keys_to_keep):
+                self.train_dataset.pop(key)
+        if "val_dataset" in self:
+            for key in list(k for k in self.val_dataset if k not in keys_to_keep):
+                self.val_dataset.pop(key)
+        self.update(ds_cfg)
+
+    def _make_nuscenes_mm_dataset_config(
+        self, dataset_root_path, datart_prefix, version
+    ):
+        if version is None:
+            # Default version
+            version = "trainval"
+        if version == "trainval":
+            train_mode = "train"
+            val_mode = "val"
+        elif version == "mini":
+            train_mode = "mini_train"
+            val_mode = "mini_val"
+        else:
+            raise ValueError("Unsupported version.")
+        return {
+            "train_dataset": {
+                "type": "NuscenesMMDataset",
+                "data_root": dataset_root_path,
+                "ann_file": f"{dataset_root_path}/nuscenes_infos_train.pkl",
+                "mode": train_mode,
+                "datart_prefix": datart_prefix,
+            },
+            "val_dataset": {
+                "type": "NuscenesMMDataset",
+                "data_root": dataset_root_path,
+                "ann_file": f"{dataset_root_path}/nuscenes_infos_val.pkl",
+                "mode": val_mode,
+                "datart_prefix": datart_prefix,
+            },
+        }
+
+    def _update_amp(self, amp):
+        # XXX: Currently, we hard-code the AMP settings according to
+        # https://github.com/PaddlePaddle/Paddle3D/blob/3cf884ecbc94330be0e2db780434bb60b9b4fe8c/configs/smoke/smoke_dla34_no_dcn_kitti_amp.yml#L6
+        amp_cfg = {
+            "amp_cfg": {
+                "use_amp": False,
+                "enable": False,
+                "level": amp,
+                "scaler": {"init_loss_scaling": 512.0},
+                "custom_black_list": ["matmul_v2", "elementwise_mul"],
+            }
+        }
+        self.update(amp_cfg)
+
+    def update_class_names(self, class_names):
+        if "train_dataset" in self and "transforms" in getattr(self, "train_dataset"):
+            self.train_dataset["class_names"] = class_names
+            # TODO: Provide another method to customize `SampleNameFilter` classes names
+            # TODO: Give an explicit warning for the implicit behavior
+            tf_cfg_list = self.train_dataset["transforms"]
+            for tf_cfg in tf_cfg_list:
+                if tf_cfg["type"] == "SampleNameFilter":
+                    tf_cfg["classes"] = class_names
+                    # We assume that there is at most one `SampleNameFilter` in `tf_cfg_list`
+                    break
+        if "val_dataset" in self:
+            self.val_dataset["class_names"] = class_names
+
+    def update_pretrained_model(self, load_cam_from: str, load_lidar_from: str):
+        """update model pretrained weight
+
+        Args:
+            load_cam_from (str): the path to cam weight file of model.
+            load_lidar_from (str): the path to lidar weight file of model.
+        """
+        self.model["load_cam_from"] = load_cam_from
+        self.model["load_lidar_from"] = load_lidar_from
+
+    def update_weights(self, weight_path: str):
+        """update model weight
+
+        Args:
+            weight_path (str): the path to weight file of model.
+        """
+        self["weights"] = weight_path

+ 227 - 0
paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py

@@ -0,0 +1,227 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os.path as osp
+
+from .runner import raise_unsupported_api_error
+from ...base import BaseModel
+from ....utils import logging
+from ...base.utils.arg import CLIArgument
+from ....utils.misc import abspath
+
+
+class BEVFusionModel(BaseModel):
+
+    def train(
+        self,
+        batch_size=None,
+        learning_rate=None,
+        epochs_iters=None,
+        pretrained=None,
+        ips=None,
+        device="gpu",
+        resume_path=None,
+        dy2st=False,
+        amp="OFF",
+        num_workers=None,
+        use_vdl=True,
+        save_dir=None,
+        **kwargs,
+    ):
+        if resume_path is not None:
+            resume_path = abspath(resume_path)
+        if not use_vdl:
+            logging.warning("Currently, VisualDL cannot be disabled during training.")
+        if save_dir is not None:
+            save_dir = abspath(save_dir)
+        else:
+            # `save_dir` is None
+            save_dir = abspath(osp.join("output", "train"))
+
+        if dy2st:
+            raise ValueError(f"`dy2st`={dy2st} is not supported.")
+        if device in ("cpu", "gpu"):
+            logging.warning(
+                f"The device type to use will be automatically determined, which may differ from the sepcified type: {repr(device)}."
+            )
+
+        # Update YAML config file
+        config = self.config.copy()
+        if epochs_iters is not None:
+            config.update_iters(epochs_iters)
+        if amp is not None:
+            if amp != "OFF":
+                config._update_amp(amp)
+
+        # Parse CLI arguments
+        cli_args = []
+        if batch_size is not None:
+            cli_args.append(CLIArgument("--batch_size", batch_size))
+        if learning_rate is not None:
+            cli_args.append(CLIArgument("--learning_rate", learning_rate))
+        if num_workers is not None:
+            cli_args.append(CLIArgument("--num_workers", num_workers))
+        if resume_path is not None:
+            if save_dir is not None:
+                raise ValueError(
+                    "When `resume_path` is not None, `save_dir` must be set to None."
+                )
+            model_dir = osp.dirname(resume_path)
+            cli_args.append(CLIArgument("--resume"))
+            cli_args.append(CLIArgument("--save_dir", model_dir))
+        if save_dir is not None:
+            cli_args.append(CLIArgument("--save_dir", save_dir))
+        if pretrained is not None:
+            cli_args.append(CLIArgument("--model", abspath(pretrained)))
+
+        do_eval = kwargs.pop("do_eval", True)
+
+        profile = kwargs.pop("profile", None)
+        if profile is not None:
+            cli_args.append(CLIArgument("--profiler_options", profile))
+
+        log_interval = kwargs.pop("log_interval", 1)
+        if log_interval is not None:
+            cli_args.append(CLIArgument("--log_interval", log_interval))
+
+        save_interval = kwargs.pop("save_interval", 1)
+        if save_interval is not None:
+            cli_args.append(CLIArgument("--save_interval", save_interval))
+
+        seed = kwargs.pop("seed", None)
+        if seed is not None:
+            cli_args.append(CLIArgument("--seed", seed))
+
+        self._assert_empty_kwargs(kwargs)
+
+        # PDX related settings
+        uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
+        config.update({"uniform_output_enabled": uniform_output_enabled})
+        config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            return self.runner.train(
+                config_path, cli_args, device, ips, save_dir, do_eval=do_eval
+            )
+
+    def evaluate(
+        self,
+        weight_path,
+        batch_size=None,
+        ips=None,
+        device="gpu",
+        amp="OFF",
+        num_workers=None,
+        **kwargs,
+    ):
+        weight_path = abspath(weight_path)
+
+        if device in ("cpu", "gpu"):
+            logging.warning(
+                f"The device type to use will be automatically determined, which may differ from the sepcified type: {repr(device)}."
+            )
+
+        # Update YAML config file
+        config = self.config.copy()
+
+        if amp is not None:
+            if amp != "OFF":
+                raise ValueError("AMP evaluation is not supported.")
+
+        # Parse CLI arguments
+        cli_args = []
+        if weight_path is not None:
+            cli_args.append(CLIArgument("--model", weight_path))
+        if batch_size is not None:
+            cli_args.append(CLIArgument("--batch_size", batch_size))
+            if batch_size != 1:
+                raise ValueError("Batch size other than 1 is not supported.")
+        if num_workers is not None:
+            cli_args.append(CLIArgument("--num_workers", num_workers))
+
+        self._assert_empty_kwargs(kwargs)
+
+        # PDX related settings
+        uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
+        config.update({"uniform_output_enabled": uniform_output_enabled})
+        config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            cp = self.runner.evaluate(config_path, cli_args, device, ips)
+            return cp
+
+    def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
+        raise_unsupported_api_error("predict", self.__class__)
+
+    def export(self, weight_path, save_dir, **kwargs):
+        if not weight_path.startswith("http"):
+            weight_path = abspath(weight_path)
+        save_dir = abspath(save_dir)
+
+        # Update YAML config file
+        config = self.config.copy()
+
+        # Parse CLI arguments
+        cli_args = []
+        if weight_path is not None:
+            cli_args.append(CLIArgument("--model", weight_path))
+        if save_dir is not None:
+            cli_args.append(CLIArgument("--save_dir", save_dir))
+
+        self._assert_empty_kwargs(kwargs)
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            return self.runner.export(config_path, cli_args, None)
+
+    def infer(self, model_dir, device="gpu", **kwargs):
+        model_dir = abspath(model_dir)
+
+        # Parse CLI arguments
+        cli_args = []
+        model_file_path = osp.join(model_dir, ".pdmodel")
+        params_file_path = osp.join(model_dir, ".pdiparams")
+        cli_args.append(CLIArgument("--model_file", model_file_path))
+        cli_args.append(CLIArgument("--params_file", params_file_path))
+        if device is not None:
+            device_type, _ = self.runner.parse_device(device)
+            if device_type not in ("cpu", "gpu"):
+                raise ValueError(f"`device`={repr(device)} is not supported.")
+        infer_dir = osp.join(self.runner.runner_root_path, self.model_info["infer_dir"])
+        self._assert_empty_kwargs(kwargs)
+        # The inference script does not require a config file
+        return self.runner.infer(None, cli_args, device, infer_dir, None)
+
+    def compression(
+        self,
+        weight_path,
+        ann_file=None,
+        class_names=None,
+        batch_size=None,
+        learning_rate=None,
+        epochs_iters=None,
+        device="gpu",
+        use_vdl=True,
+        save_dir=None,
+        **kwargs,
+    ):
+        raise_unsupported_api_error("compression", self.__class__)

+ 55 - 0
paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py

@@ -0,0 +1,55 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+
+from ...base.register import register_model_info, register_suite_info
+from .model import BEVFusionModel
+from .runner import BEVFusionRunner
+from .config import BEVFusionConfig
+
+REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLE3D_PATH")
+PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
+
+register_suite_info(
+    {
+        "suite_name": "BEVFusion",
+        "model": BEVFusionModel,
+        "runner": BEVFusionRunner,
+        "config": BEVFusionConfig,
+        "runner_root_path": REPO_ROOT_PATH,
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "BEVFusion",
+        "suite": "BEVFusion",
+        "config_path": osp.join(PDX_CONFIG_DIR, "BEVFusion.yaml"),
+        "auto_compression_config_path": osp.join(PDX_CONFIG_DIR, "None"),
+        "supported_apis": ["train", "evaluate", "export", "infer"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx"],
+            "dy2st": False,
+            "amp": ["O1", "O2"],
+        },
+        "supported_evaluate_opts": {"device": ["cpu", "gpu_nxcx"]},
+        "supported_infer_opts": {"device": ["cpu", "gpu"]},
+        "supported_dataset_types": [],
+        # Additional info
+        "infer_dir": "deploy/bevfusion",
+    }
+)

+ 104 - 0
paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py

@@ -0,0 +1,104 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+from ...base import BaseRunner
+
+
+def raise_unsupported_api_error(api_name, cls=None):
+    # TODO: Automatically extract `api_name` and `cls` from stack frame
+    if cls is not None:
+        name = f"{cls.__name__}.{api_name}"
+    else:
+        name = api_name
+    raise UnsupportedAPIError(f"The API `{name}` is not supported.")
+
+
+class UnsupportedAPIError(Exception):
+    pass
+
+
+class BEVFusionRunner(BaseRunner):
+    def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
+        args, env = self.distributed(device, ips, log_dir=save_dir)
+        cmd = [*args, "tools/train.py"]
+        if do_eval:
+            cmd.append("--do_eval")
+        cmd.extend(["--config", config_path, *cli_args])
+        return self.run_cmd(
+            cmd,
+            env=env,
+            switch_wdir=True,
+            echo=True,
+            silent=False,
+            capture_output=True,
+            log_path=self._get_train_log_path(save_dir),
+        )
+
+    def evaluate(self, config_path, cli_args, device, ips):
+        args, env = self.distributed(device, ips)
+        cmd = [*args, "tools/evaluate.py", "--config", config_path, *cli_args]
+        cp = self.run_cmd(
+            cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
+        )
+        if cp.returncode == 0:
+            metric_dict = _extract_eval_metrics(cp.stdout)
+            cp.metrics = metric_dict
+        return cp
+
+    def predict(self, config_path, cli_args, device):
+        raise_unsupported_api_error("predict", self.__class__)
+
+    def export(self, config_path, cli_args, device):
+        # `device` unused
+        cmd = [self.python, "tools/export.py", "--config", config_path, *cli_args]
+        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+    def infer(self, config_path, cli_args, device, infer_dir, save_dir=None):
+        # `config_path` and `device` unused
+        cmd = [self.python, "infer.py", *cli_args]
+        python_infer_dir = os.path.join(infer_dir, "python")
+        cp = self.run_cmd(cmd, switch_wdir=python_infer_dir, echo=True, silent=False)
+        return cp
+
+    def compression(
+        self, config_path, train_cli_args, export_cli_args, device, train_save_dir
+    ):
+        raise_unsupported_api_error("compression", self.__class__)
+
+
+def _extract_eval_metrics(stdout):
+    import re
+
+    _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
+    metrics = ["mAP", "NDS"]
+    patterns = {}
+    for metric in metrics:
+        pattern = f"{metric}: (_dp)".replace("_dp", _DP)
+        patterns[metric] = pattern
+
+    metric_dict = dict()
+
+    # TODO: Use lazy version to make it more efficient
+    lines = stdout.splitlines()
+    for line in lines:
+        for m in patterns:
+            p = re.compile(patterns[m])
+            match = p.search(line)
+            if match:
+                metric_dict[m] = float(match.groups()[0])
+
+    return metric_dict

+ 215 - 0
paddlex/repo_apis/Paddle3D_api/configs/BEVFusion.yaml

@@ -0,0 +1,215 @@
+batch_size: 2 # 8 gpu, total bs=16
+epochs: 12
+
+train_dataset:
+  type: NuscenesMMDataset
+  ann_file: ./data/nuscenes/nuscenes_infos_train.pkl
+  data_root: ./data/nuscenes
+  class_names: [
+            'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
+            'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
+        ]
+  modality: multimodal
+  transforms:
+    - type: LoadPointsFromFile
+      load_dim: 5
+      use_dim: 5
+    - type: LoadPointsFromMultiSweeps
+      sweeps_num: 10
+    - type: LoadAnnotations3D
+      with_bbox_3d: true
+      with_label_3d: true
+    - type: LoadMultiViewImageFromFiles
+      project_pts_to_img_depth: true
+    - type: PointsRangeFilter
+      point_cloud_range: [-50, -50, -5, 50, 50, 3]
+    - type: SampleRangeFilter
+      point_cloud_range: [-50, -50, -5, 50, 50, 3]
+    - type: SampleNameFilter
+      classes: ['car', 'truck', 'trailer', 'bus', 'construction_vehicle',
+                'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone','barrier'
+                ]
+    - type: PointShuffle
+    - type: ResizeImage
+      img_scale: [[800, 448]]
+      keep_ratio: true
+    - type: NormalizeImage
+      mean: [123.675, 116.28, 103.53]
+      std: [58.395, 57.12, 57.375]
+      to_rgb: true
+    - type: PadImage
+      size_divisor: 32
+    - type: SampleFilterByKey
+      keys: ['img', 'img_depth', 'points', 'gt_bboxes_3d', 'gt_labels_3d']
+  mode: train
+
+val_dataset:
+  type: NuscenesMMDataset
+  ann_file: ./data/nuscenes/nuscenes_infos_val.pkl
+  data_root: ./data/nuscenes
+  class_names: [
+            'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
+            'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
+        ]
+  modality: multimodal
+  transforms:
+    - type: LoadPointsFromFile
+      load_dim: 5
+      use_dim: 5
+    - type: LoadPointsFromMultiSweeps
+      sweeps_num: 10
+    - type: LoadMultiViewImageFromFiles
+    - type: ResizeImage
+      img_scale: [[800, 448]]
+      keep_ratio: true
+    - type: NormalizeImage
+      mean: [123.675, 116.28, 103.53]
+      std: [58.395, 57.12, 57.375]
+      to_rgb: true
+    - type: PadImage
+      size_divisor: 32
+    - type: SampleFilterByKey
+      keys: ['points', 'img']
+  mode: val
+
+model:
+  type: BEVFFasterRCNN
+  se: True
+  lc_fusion: True
+  camera_stream: True
+  lss: False
+  grid: 0.5
+  num_views: 6
+  final_dim: [900, 1600]
+  downsample: 8
+  pts_voxel_layer:
+    max_num_points_in_voxel: 64
+    point_cloud_range: [-50., -50., -5., 50., 50., 3.]
+    voxel_size: [0.25, 0.25, 8.]
+    max_num_voxels: [30000, 40000]
+  pts_voxel_encoder:
+    type: HardVFE
+    in_channels: 4
+    feat_channels: [64, 64]
+    with_distance: False
+    voxel_size: [0.25, 0.25, 8]
+    with_cluster_center: True
+    with_voxel_center: True
+    point_cloud_range: [-50, -50, -5, 50, 50, 3]
+  pts_middle_encoder:
+    type: PointPillarsScatter
+    in_channels: 64
+    point_cloud_range: [-50, -50, -5, 50, 50, 3]
+    voxel_size: [0.25, 0.25, 8]
+  pts_backbone:
+    type: SecondBackbone
+    in_channels: 64
+    layer_nums: [3, 5, 5]
+    downsample_strides: [2, 2, 2]
+    out_channels: [64, 128, 256]
+  pts_neck:
+    type: SecondFPN
+    in_channels: [64, 128, 256]
+    upsample_strides: [1, 2, 4]
+    out_channels: [128, 128, 128]
+  img_backbone:
+      type: CBSwinTransformer
+      embed_dim: 96
+      depths: [2, 2, 6, 2]
+      num_heads: [3, 6, 12, 24]
+      window_size: 7
+      mlp_ratio: 4.0
+      qkv_bias: true
+      qk_scale: null
+      drop_rate: 0.0
+      attn_drop_rate: 0.0
+      drop_path_rate: 0.2
+      ape: false
+      patch_norm: true
+      out_indices: [0, 1, 2, 3]
+  img_neck:
+      type: FPNC
+      final_dim: [900, 1600]
+      downsample: 8
+      in_channels: [96, 192, 384, 768]
+      out_channels: 256
+      outC: 256
+      use_adp: true
+      num_outs: 5
+  pts_bbox_head:
+        type: Anchor3DHead
+        num_classes: 10
+        in_channels: 384
+        feat_channels: 384
+        use_direction_classifier: true
+        anchor_generator:
+          type: AlignedAnchor3DRangeGenerator
+          ranges: [[-49.6, -49.6, -1.80032795, 49.6, 49.6, -1.80032795],
+                    [-49.6, -49.6, -1.74440365, 49.6, 49.6, -1.74440365],
+                    [-49.6, -49.6, -1.68526504, 49.6, 49.6, -1.68526504],
+                    [-49.6, -49.6, -1.67339111, 49.6, 49.6, -1.67339111],
+                    [-49.6, -49.6, -1.61785072, 49.6, 49.6, -1.61785072],
+                    [-49.6, -49.6, -1.80984986, 49.6, 49.6, -1.80984986],
+                    [-49.6, -49.6, -1.763965, 49.6, 49.6, -1.763965]]
+          sizes: [[1.95017717, 4.60718145, 1.72270761],
+                   [2.4560939, 6.73778078, 2.73004906],
+                   [2.87427237, 12.01320693, 3.81509561],
+                   [0.60058911, 1.68452161, 1.27192197],
+                   [0.66344886, 0.7256437, 1.75748069],
+                   [0.39694519, 0.40359262, 1.06232151],
+                   [2.49008838, 0.48578221, 0.98297065]]
+          custom_values: [0, 0]
+          rotations: [0, 1.57]
+          reshape_out: true
+        assigner_per_size: false
+        diff_rad_by_sin: true
+        dir_offset: 0.7854  # pi/4
+        dir_limit_offset: 0
+        bbox_coder:
+          type: DeltaXYZWLHRBBoxCoder
+          code_size: 9
+        loss_cls:
+            type: WeightedFocalLoss
+            use_sigmoid: true
+            gamma: 2.0
+            alpha: 0.25
+            loss_weight: 1.0
+        loss_bbox:
+            type: SmoothL1Loss
+            beta: 0.1111111111111111
+            loss_weight: 1.0
+        loss_dir:
+            type: CrossEntropyLoss
+            use_sigmoid: false
+            loss_weight: 0.2
+        use_sigmoid_cls: true
+        train_cfg:
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
+          pos_weight: -1
+        test_cfg:
+          use_rotate_nms: true
+          nms_across_levels: false
+          nms_pre: 1000
+          nms_thr: 0.2
+          score_thr: 0.05
+          min_bbox_size: 0
+          max_num: 500
+
+optimizer:
+  type: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.05
+  grad_clip:
+    type: ClipGradByGlobalNorm
+    clip_norm: 35
+
+lr_scheduler:
+  type: LinearWarmup
+  learning_rate:
+    type: MultiStepDecay
+    milestones: [6032, 8669] # [879*8e-1000, 879*11e-1000]
+    learning_rate: 0.001
+  warmup_steps: 1000
+  start_lr: 1.0e-6
+  end_lr: 0.001

+ 144 - 0
paddlex/repo_apis/Paddle3D_api/pp3d_config.py

@@ -0,0 +1,144 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import yaml
+from ...utils.misc import abspath
+
+from ..base import BaseConfig
+
+
+class PP3DConfig(BaseConfig):
+    # Refer to https://github.com/PaddlePaddle/Paddle3D/blob/release/1.0/paddle3d/apis/config.py
+    def update(self, dict_like_obj):
+        def _merge_config_dicts(dict_from, dict_to):
+            # According to
+            # https://github.com/PaddlePaddle/Paddle3D/blob/3cf884ecbc94330be0e2db780434bb60b9b4fe8c/paddle3d/apis/config.py#L90
+            for key, val in dict_from.items():
+                if isinstance(val, dict) and key in dict_to:
+                    dict_to[key] = _merge_config_dicts(val, dict_to[key])
+                else:
+                    dict_to[key] = val
+            return dict_to
+
+        dict_ = _merge_config_dicts(dict_like_obj, self.dict)
+        self.reset_from_dict(dict_)
+
+    def load(self, config_path):
+        with codecs.open(config_path, "r", "utf-8") as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+        dict_ = dic
+        self.reset_from_dict(dict_)
+
+    def dump(self, config_path):
+        with open(config_path, "w", encoding="utf-8") as f:
+            yaml.dump(self.dict, f)
+
+    def update_learning_rate(self, learning_rate):
+        if "lr_scheduler" not in self:
+            raise RuntimeError(
+                "Not able to update learning rate, because no LR scheduler config was found."
+            )
+
+        # Some lr_scheduler in Paddle3D has not learning_rate parameter.
+        if self.lr_scheduler["type"] == "OneCycle":
+            self.lr_scheduler["lr_max"] = learning_rate
+        elif self.lr_scheduler["type"] == "OneCycleWarmupDecayLr":
+            self.lr_scheduler["base_learning_rate"] = learning_rate
+        else:
+            self.lr_scheduler["learning_rate"] = learning_rate
+
+    def update_batch_size(self, batch_size, mode="train"):
+        if mode == "train":
+            self.set_val("batch_size", batch_size)
+        else:
+            raise ValueError(
+                f"Setting `batch_size` in {repr(mode)} mode is not supported."
+            )
+
+    def update_epochs(self, epochs, mode="train"):
+        if mode == "train":
+            self.set_val("epochs", epochs)
+        else:
+            raise ValueError(f"Setting `epochs` in {repr(mode)} mode is not supported.")
+
+    def update_pretrained_weights(self, weight_path, is_backbone=False):
+        raise NotImplementedError
+
+    def get_epochs_iters(self):
+        if "iters" in self:
+            return self.iters
+        else:
+            assert "epochs" in self
+            return self.epochs
+
+    def get_learning_rate(self):
+        if "lr_scheduler" not in self or "learning_rate" not in self.lr_scheduler:
+            # Default lr
+            return 0.0001
+        else:
+            lr = self.lr_scheduler["learning_rate"]
+            while isinstance(lr, dict):
+                lr = lr["learning_rate"]
+            return lr
+
+    def get_batch_size(self, mode="train"):
+        if "batch_size" in self:
+            return self.batch_size
+        else:
+            # Default batch size
+            return 1
+
+    def get_qat_epochs_iters(self):
+        assert (
+            "finetune_config" in self
+        ), "QAT training yaml should contain finetune_config key"
+        if "iters" in self.finetune_config:
+            return self.finetune_config["iters"]
+        else:
+            assert "epochs" in self.finetune_config
+            return self.finetune_config["epochs"]
+
+    def get_qat_learning_rate(self):
+        assert (
+            "finetune_config" in self
+        ), "QAT training yaml should contain finetune_config key"
+        cfg = self.finetune_config
+        if "lr_scheduler" in cfg or "learning_rate" not in cfg.lr_scheduler:
+            # Default lr
+            return 1.25e-4
+        else:
+            lr = cfg.lr_scheduler["learning_rate"]
+            while isinstance(lr, dict):
+                lr = lr["learning_rate"]
+            return lr
+
+    def update_warmup_steps(self, steps):
+        self.lr_scheduler["warmup_steps"] = steps
+
+    def update_end_lr(self, learning_rate):
+        self.lr_scheduler["end_lr"] = learning_rate
+
+    def update_iters(self, iters):
+        self.set_val("iters", iters)
+        if "epochs" in self:
+            self.set_val("epochs", None)
+
+    def update_finetune_iters(self, iters):
+        self.finetune_config["iters"] = iters
+        if "epochs" in self.finetune_config:
+            self.finetune_config["epochs"] = None
+
+    def update_save_dir(self, save_dir: str):
+        self["save_dir"] = abspath(save_dir)

+ 13 - 0
paddlex/repo_manager/meta.py

@@ -25,6 +25,7 @@ REPO_NAMES = [
     "PaddleSeg",
     "PaddleSeg",
     "PaddleNLP",
     "PaddleNLP",
     "PaddleTS",
     "PaddleTS",
+    "Paddle3D",
     "PaddleVideo",
     "PaddleVideo",
 ]
 ]
 
 
@@ -131,6 +132,18 @@ REPO_META = {
         "path_env": "PADDLE_PDX_PADDLEMIX_PATH",
         "path_env": "PADDLE_PDX_PADDLEMIX_PATH",
         "requires": ["PaddleNLP"],
         "requires": ["PaddleNLP"],
     },
     },
+    "Paddle3D": {
+        "git_path": "/PaddlePaddle/Paddle3D.git",
+        "platform": "github",
+        "branch": "develop",
+        "pkg_name": "paddle3d",
+        "lib_name": "paddle3d",
+        "pdx_pkg_name": "Paddle3D_api",
+        "editable": False,
+        "path_env": "PADDLE_PDX_PADDLE3D_PATH",
+        "requires": ["PaddleSeg", "PaddleDetection"],
+        "main_req_file": "requirements_pdx.txt",
+    },
     "PaddleVideo": {
     "PaddleVideo": {
         "git_path": "/PaddlePaddle/PaddleVideo.git",
         "git_path": "/PaddlePaddle/PaddleVideo.git",
         "platform": "github",
         "platform": "github",

+ 4 - 0
requirements.txt

@@ -38,6 +38,7 @@ erniebot-agent == 0.5.2
 unstructured
 unstructured
 networkx
 networkx
 faiss-cpu
 faiss-cpu
+######## For Vidio #######
 decord==0.6.0; platform_machine == 'x86_64'
 decord==0.6.0; platform_machine == 'x86_64'
 ######## For NLP Tokenizer #######
 ######## For NLP Tokenizer #######
 jieba
 jieba
@@ -46,3 +47,6 @@ jinja2
 regex
 regex
 ######## For Speech #######
 ######## For Speech #######
 soundfile
 soundfile
+######## For 3D BEVFusion #######
+nuscenes-devkit
+pyquaternion