
add meter_reader and batch_predict
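The headline addition on the Python side is a batched segmentation entry point,
batch_predict. A minimal usage sketch, mirroring the call made in reader_infer.py
below (the model path and image names are hypothetical placeholders):

    import paddlex as pdx
    from paddlex.seg import transforms

    # Load a trained segmentation model (placeholder path).
    model = pdx.load_model('output/meter_seg_model')
    seg_transforms = transforms.Compose([transforms.Normalize()])

    # Predict a whole batch in one call, with multithreaded preprocessing,
    # instead of looping over model.predict one image at a time.
    results = model.batch_predict(
        transforms=seg_transforms,
        img_file_list=['meter_0.jpg', 'meter_1.jpg'],
        thread_num=2)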

FlyingQianMM, 5 years ago
Parent commit: fbbdaf3bb0

+ 22 - 18
deploy/cpp/src/paddlex.cpp

@@ -30,13 +30,14 @@ void Model::create_predictor(const std::string& model_dir,
   std::string model_file = model_dir + OS_PATH_SEP + "__model__";
   std::string params_file = model_dir + OS_PATH_SEP + "__params__";
 #ifdef WITH_ENCRYPTION
-  if (key != ""){
+  if (key != "") {
     model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
     params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
-    paddle_security_load_model(&config, key.c_str(), model_file.c_str(), params_file.c_str());
+    paddle_security_load_model(
+      &config, key.c_str(), model_file.c_str(), params_file.c_str());
   }
 #endif
-  if (key == ""){
+  if (key == "") {
     config.SetModel(model_file, params_file);
   }
   if (use_gpu) {
@@ -67,9 +68,10 @@ bool Model::load_config(const std::string& model_dir) {
   name = config["Model"].as<std::string>();
   std::string version = config["version"].as<std::string>();
   if (version[0] == '0') {
-    std::cerr << "[Init] Version of the loaded model is lower than 1.0.0, deployment "
-              << "cannot be done, please refer to "
-              << "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/tutorials/deploy/upgrade_version.md "
+    std::cerr << "[Init] Version of the loaded model is lower than 1.0.0,"
+              << " deployment cannot be done, please refer to "
+              << "https://github.com/PaddlePaddle/PaddleX/blob/develop/"
+              << "docs/tutorials/deploy/upgrade_version.md "
               << "to transfer version."
               << std::endl;
     return false;
@@ -332,18 +334,20 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
       inputs_.im_size_before_resize_.pop_back();
       auto resize_w = before_shape[0];
       auto resize_h = before_shape[1];
-      cv::resize(mask_label,
-                 mask_label,
-                 cv::Size(resize_h, resize_w),
-                 0,
-                 0,
-                 cv::INTER_NEAREST);
-      cv::resize(mask_score,
-                 mask_score,
-                 cv::Size(resize_h, resize_w),
-                 0,
-                 0,
-                 cv::INTER_NEAREST);
+      if (mask_label.rows != resize_h || mask_label.cols != resize_w) {
+        cv::resize(mask_label,
+                   mask_label,
+                   cv::Size(resize_h, resize_w),
+                   0,
+                   0,
+                   cv::INTER_NEAREST);
+        cv::resize(mask_score,
+                   mask_score,
+                   cv::Size(resize_h, resize_w),
+                   0,
+                   0,
+                   cv::INTER_NEAREST);
+      }
     }
     ++idx;
   }

+ 4 - 2
deploy/cpp/src/transforms.cpp

@@ -150,8 +150,10 @@ bool Resize::Run(cv::Mat* im, ImageBlob* data) {
   data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("resize");
 
-  cv::resize(
-      *im, *im, cv::Size(width_, height_), 0, 0, interpolations[interp_]);
+  if (im->rows != height_ || im->cols != width_) {
+    cv::resize(
+        *im, *im, cv::Size(width_, height_), 0, 0, interpolations[interp_]);
+  }
   data->new_im_size_[0] = im->rows;
   data->new_im_size_[1] = im->cols;
   return true;
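Both hunks above add the same guard: call cv::resize only when the current size
differs from the target, avoiding a redundant reallocation-and-copy when the
image is already the right shape. A sketch of the same pattern in Python/OpenCV
(maybe_resize is a hypothetical helper, not part of this commit):

    import cv2

    def maybe_resize(im, width, height, interp=cv2.INTER_LINEAR):
        # Resizing to the current size would still copy the data, so skip it.
        if im.shape[0] != height or im.shape[1] != width:
            im = cv2.resize(im, (width, height), interpolation=interp)
        return im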

+ 309 - 0
examples/meter_reader/deploy/cpp/CMakeLists.txt

@@ -0,0 +1,309 @@
+cmake_minimum_required(VERSION 3.0)
+project(PaddleX CXX C)
+
+option(WITH_MKL        "Compile meter with MKL/OpenBlas support, default use MKL."       ON)
+option(WITH_GPU        "Compile meter with GPU/CPU, default use CPU."                    ON)
+option(WITH_STATIC_LIB "Compile meter with static/shared library, default use static."   OFF)
+option(WITH_TENSORRT "Compile meter with TensorRT."   OFF)
+option(WITH_ENCRYPTION "Compile meter with encryption tool."   OFF)
+
+SET(TENSORRT_DIR "" CACHE PATH "Location of libraries")
+SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
+SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
+SET(ENCRYPTION_DIR"" CACHE PATH "Location of libraries")
+SET(CUDA_LIB "" CACHE PATH "Location of libraries")
+
+if (NOT WIN32)
+    set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/meter)
+else()
+    set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/paddlex_inference)
+    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/paddlex_inference)
+    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/paddlex_inference)
+endif()
+
+if (NOT WIN32)
+    SET(YAML_BUILD_TYPE ON CACHE BOOL "yaml build shared library.")
+else()
+    SET(YAML_BUILD_TYPE OFF CACHE BOOL "yaml build shared library.")
+endif()
+include(cmake/yaml-cpp.cmake)
+
+include_directories("${CMAKE_SOURCE_DIR}/")
+include_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/src/ext-yaml-cpp/include")
+link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib")
+
+macro(safe_set_static_flag)
+    foreach(flag_var
+        CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
+        CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
+      if(${flag_var} MATCHES "/MD")
+        string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
+      endif(${flag_var} MATCHES "/MD")
+    endforeach(flag_var)
+endmacro()
+
+
+if (WITH_ENCRYPTION)
+    add_definitions(-DWITH_ENCRYPTION=${WITH_ENCRYPTION})
+endif()
+
+if (WITH_MKL)
+    ADD_DEFINITIONS(-DUSE_MKL)
+endif()
+
+if (NOT DEFINED PADDLE_DIR OR ${PADDLE_DIR} STREQUAL "")
+    message(FATAL_ERROR "please set PADDLE_DIR with -DPADDLE_DIR=/path/paddle_influence_dir")
+endif()
+
+if (NOT DEFINED OPENCV_DIR OR ${OPENCV_DIR} STREQUAL "")
+    message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv")
+endif()
+
+include_directories("${CMAKE_SOURCE_DIR}/")
+include_directories("${PADDLE_DIR}/")
+include_directories("${PADDLE_DIR}/third_party/install/protobuf/include")
+include_directories("${PADDLE_DIR}/third_party/install/glog/include")
+include_directories("${PADDLE_DIR}/third_party/install/gflags/include")
+include_directories("${PADDLE_DIR}/third_party/install/xxhash/include")
+if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/include")
+    include_directories("${PADDLE_DIR}/third_party/install/snappy/include")
+endif()
+if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/include")
+    include_directories("${PADDLE_DIR}/third_party/install/snappystream/include")
+endif()
+# zlib does not exist in 1.8.1
+if (EXISTS "${PADDLE_DIR}/third_party/install/zlib/include")
+    include_directories("${PADDLE_DIR}/third_party/install/zlib/include")
+endif()
+
+include_directories("${PADDLE_DIR}/third_party/boost")
+include_directories("${PADDLE_DIR}/third_party/eigen3")
+
+if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
+    link_directories("${PADDLE_DIR}/third_party/install/snappy/lib")
+endif()
+if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
+    link_directories("${PADDLE_DIR}/third_party/install/snappystream/lib")
+endif()
+
+if (EXISTS "${PADDLE_DIR}/third_party/install/zlib/lib")
+    link_directories("${PADDLE_DIR}/third_party/install/zlib/lib")
+endif()
+
+link_directories("${PADDLE_DIR}/third_party/install/protobuf/lib")
+link_directories("${PADDLE_DIR}/third_party/install/glog/lib")
+link_directories("${PADDLE_DIR}/third_party/install/gflags/lib")
+link_directories("${PADDLE_DIR}/third_party/install/xxhash/lib")
+link_directories("${PADDLE_DIR}/paddle/lib/")
+link_directories("${CMAKE_CURRENT_BINARY_DIR}")
+
+if (WIN32)
+  include_directories("${PADDLE_DIR}/paddle/fluid/inference")
+  include_directories("${PADDLE_DIR}/paddle/include")
+  link_directories("${PADDLE_DIR}/paddle/fluid/inference")
+  find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
+  unset(OpenCV_DIR CACHE)
+else ()
+  find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH)
+  include_directories("${PADDLE_DIR}/paddle/include")
+  link_directories("${PADDLE_DIR}/paddle/lib")
+endif ()
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+if (WIN32)
+    add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
+    find_package(OpenMP REQUIRED)
+    if (OPENMP_FOUND)
+        message("OPENMP FOUND")
+        set(CMAKE_C_FLAGS_DEBUG   "${CMAKE_C_FLAGS_DEBUG} ${OpenMP_C_FLAGS}")
+        set(CMAKE_C_FLAGS_RELEASE  "${CMAKE_C_FLAGS_RELEASE} ${OpenMP_C_FLAGS}")
+        set(CMAKE_CXX_FLAGS_DEBUG  "${CMAKE_CXX_FLAGS_DEBUG} ${OpenMP_CXX_FLAGS}")
+        set(CMAKE_CXX_FLAGS_RELEASE   "${CMAKE_CXX_FLAGS_RELEASE} ${OpenMP_CXX_FLAGS}")
+    endif()
+    set(CMAKE_C_FLAGS_DEBUG   "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd")
+    set(CMAKE_C_FLAGS_RELEASE  "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT")
+    set(CMAKE_CXX_FLAGS_DEBUG  "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd")
+    set(CMAKE_CXX_FLAGS_RELEASE   "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT")
+    if (WITH_STATIC_LIB)
+        safe_set_static_flag()
+        add_definitions(-DSTATIC_LIB)
+    endif()
+else()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O2 -fopenmp -std=c++11")
+    set(CMAKE_STATIC_LIBRARY_PREFIX "")
+endif()
+
+if (WITH_GPU)
+    if (NOT DEFINED CUDA_LIB OR ${CUDA_LIB} STREQUAL "")
+        message(FATAL_ERROR "please set CUDA_LIB with -DCUDA_LIB=/path/cuda/lib64")
+    endif()
+    if (NOT WIN32)
+        if (NOT DEFINED CUDNN_LIB)
+            message(FATAL_ERROR "please set CUDNN_LIB with -DCUDNN_LIB=/path/cudnn/")
+        endif()
+    endif(NOT WIN32)
+endif()
+
+
+if (NOT WIN32)
+  if (WITH_TENSORRT AND WITH_GPU)
+      include_directories("${TENSORRT_DIR}/include")
+      link_directories("${TENSORRT_DIR}/lib")
+  endif()
+endif(NOT WIN32)
+
+if (NOT WIN32)
+    set(NGRAPH_PATH "${PADDLE_DIR}/third_party/install/ngraph")
+    if(EXISTS ${NGRAPH_PATH})
+        include(GNUInstallDirs)
+        include_directories("${NGRAPH_PATH}/include")
+        link_directories("${NGRAPH_PATH}/${CMAKE_INSTALL_LIBDIR}")
+        set(NGRAPH_LIB ${NGRAPH_PATH}/${CMAKE_INSTALL_LIBDIR}/libngraph${CMAKE_SHARED_LIBRARY_SUFFIX})
+    endif()
+endif()
+
+if(WITH_MKL)
+  include_directories("${PADDLE_DIR}/third_party/install/mklml/include")
+  if (WIN32)
+    set(MATH_LIB ${PADDLE_DIR}/third_party/install/mklml/lib/mklml.lib
+            ${PADDLE_DIR}/third_party/install/mklml/lib/libiomp5md.lib)
+  else ()
+    set(MATH_LIB ${PADDLE_DIR}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+            ${PADDLE_DIR}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+    execute_process(COMMAND cp -r ${PADDLE_DIR}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} /usr/lib)
+  endif ()
+  set(MKLDNN_PATH "${PADDLE_DIR}/third_party/install/mkldnn")
+  if(EXISTS ${MKLDNN_PATH})
+    include_directories("${MKLDNN_PATH}/include")
+    if (WIN32)
+      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
+    else ()
+      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+    endif ()
+  endif()
+else()
+  set(MATH_LIB ${PADDLE_DIR}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
+endif()
+
+if (WIN32)
+    if(EXISTS "${PADDLE_DIR}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}")
+        set(DEPS
+            ${PADDLE_DIR}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
+    else()
+        set(DEPS
+            ${PADDLE_DIR}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
+    endif()
+endif()
+
+if(WITH_STATIC_LIB)
+    set(DEPS
+        ${PADDLE_DIR}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
+else()
+    if (NOT WIN32)
+      set(DEPS
+          ${PADDLE_DIR}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
+    else()
+      set(DEPS
+          ${PADDLE_DIR}/paddle/lib/paddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
+    endif()
+endif()
+
+if (NOT WIN32)
+    set(DEPS ${DEPS}
+        ${MATH_LIB} ${MKLDNN_LIB}
+        glog gflags protobuf z xxhash yaml-cpp
+        )
+    if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
+        set(DEPS ${DEPS} snappystream)
+    endif()
+    if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
+        set(DEPS ${DEPS} snappy)
+    endif()
+else()
+    set(DEPS ${DEPS}
+        ${MATH_LIB} ${MKLDNN_LIB}
+        glog gflags_static libprotobuf xxhash libyaml-cppmt)
+
+    if (EXISTS "${PADDLE_DIR}/third_party/install/zlib/lib")
+      set(DEPS ${DEPS} zlibstatic)
+    endif()
+    set(DEPS ${DEPS} libcmt shlwapi)
+    if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
+        set(DEPS ${DEPS} snappy)
+    endif()
+    if (EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
+        set(DEPS ${DEPS} snappystream)
+    endif()
+endif(NOT WIN32)
+
+if(WITH_GPU)
+  if(NOT WIN32)
+    if (WITH_TENSORRT)
+      set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
+      set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
+    endif()
+    set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
+    set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX})
+  else()
+    set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} )
+    set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} )
+    set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX})
+  endif()
+endif()
+
+if(WITH_ENCRYPTION)
+  if(NOT WIN32)
+      include_directories("${ENCRYPTION_DIR}/include")
+      link_directories("${ENCRYPTION_DIR}/lib")
+      set(DEPS ${DEPS} ${ENCRYPTION_DIR}/lib/libpmodel-decrypt${CMAKE_SHARED_LIBRARY_SUFFIX})
+  else()
+      include_directories("${ENCRYPTION_DIR}/include")
+      link_directories("${ENCRYPTION_DIR}/lib")
+      set(DEPS ${DEPS} ${ENCRYPTION_DIR}/lib/pmodel-decrypt${CMAKE_STATIC_LIBRARY_SUFFIX})
+  endif()
+endif()
+
+if (NOT WIN32)
+    set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread")
+    set(DEPS ${DEPS} ${EXTERNAL_LIB})
+endif()
+
+set(DEPS ${DEPS} ${OpenCV_LIBS})
+add_library(paddlex_inference SHARED src/visualize.cpp src/transforms.cpp src/paddlex.cpp)
+ADD_DEPENDENCIES(paddlex_inference ext-yaml-cpp)
+target_link_libraries(paddlex_inference ${DEPS})
+
+add_executable(meter meter/meter.cpp meter/global.cpp meter/readvalue.cpp src/transforms.cpp src/paddlex.cpp src/visualize.cpp)
+ADD_DEPENDENCIES(meter ext-yaml-cpp)
+target_link_libraries(meter ${DEPS})
+
+if (WIN32 AND WITH_MKL)
+    add_custom_command(TARGET meter POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_DIR}/third_party/install/mklml/lib/mklml.dll ./mklml.dll
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_DIR}/third_party/install/mklml/lib/libiomp5md.dll ./libiomp5md.dll
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_DIR}/third_party/install/mkldnn/lib/mkldnn.dll ./mkldnn.dll
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_DIR}/third_party/install/mklml/lib/mklml.dll ./release/mklml.dll
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_DIR}/third_party/install/mklml/lib/libiomp5md.dll ./release/libiomp5md.dll
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_DIR}/third_party/install/mkldnn/lib/mkldnn.dll ./release/mkldnn.dll
+    )
+    # for encryption
+    if (EXISTS "${ENCRYPTION_DIR}/lib/pmodel-decrypt.dll")
+        add_custom_command(TARGET meter POST_BUILD
+            COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ENCRYPTION_DIR}/lib/pmodel-decrypt.dll ./pmodel-decrypt.dll
+            COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ENCRYPTION_DIR}/lib/pmodel-decrypt.dll ./release/pmodel-decrypt.dll
+        )
+    endif()
+endif()
+
+file(COPY "${CMAKE_SOURCE_DIR}/include/paddlex/visualize.h"
+     DESTINATION "${CMAKE_BINARY_DIR}/include/")
+file(COPY "${CMAKE_SOURCE_DIR}/include/paddlex/config_parser.h"
+     DESTINATION "${CMAKE_BINARY_DIR}/include/")
+file(COPY "${CMAKE_SOURCE_DIR}/include/paddlex/transforms.h"
+     DESTINATION "${CMAKE_BINARY_DIR}/include/")
+file(COPY "${CMAKE_SOURCE_DIR}/include/paddlex/results.h"
+     DESTINATION "${CMAKE_BINARY_DIR}/include/")
+file(COPY "${CMAKE_SOURCE_DIR}/include/paddlex/paddlex.h"
+     DESTINATION "${CMAKE_BINARY_DIR}/include/")

+ 34 - 0
examples/meter_reader/deploy/cpp/meter/global.cpp

@@ -0,0 +1,34 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <iostream>
+#include <vector>
+#include <limits>
+
+#include <opencv2/opencv.hpp>
+#include <opencv2/highgui.hpp>
+#include <opencv2/core/core.hpp>
+
+#include "meter/global.h"
+
+std::vector<int> IMAGE_SHAPE = {1920, 1080};
+std::vector<int> RESULT_SHAPE = {1280, 720};
+std::vector<int> METER_SHAPE = {512, 512};
+
+#define METER_TYPE_NUM 2
+MeterConfig_T meter_config[METER_TYPE_NUM] = {
+  {25.0f / 50.0f, 25.0f, "(MPa)"},
+  {1.6f / 32.0f,  1.6f,  "(MPa)"}
+};

+ 30 - 0
examples/meter_reader/deploy/cpp/meter/global.h

@@ -0,0 +1,30 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+typedef struct MeterConfig {
+  float scale_value;
+  float range;
+  char  str[10];
+} MeterConfig_T;
+
+extern std::vector<int> IMAGE_SHAPE;
+extern std::vector<int> RESULT_SHAPE;
+extern std::vector<int> METER_SHAPE;
+extern MeterConfig_T meter_config[];
+
+#define TYPE_THRESHOLD 40
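meter_config and TYPE_THRESHOLD together turn the segmentation output into a
physical reading: the dial type is chosen by how many scale marks were found,
and scale_value converts the interpolated mark position into units. A worked
example with hypothetical numbers (the selection logic mirrors meter.cpp below):

    scale_num, scales = 50, 13.4   # hypothetical get_meter_reader output
    TYPE_THRESHOLD = 40

    if scale_num > TYPE_THRESHOLD:       # 50-mark dial with a 25 MPa range
        value = scales * (25.0 / 50.0)   # 0.5 MPa per mark
    else:                                # 32-mark dial with a 1.6 MPa range
        value = scales * (1.6 / 32.0)    # 0.05 MPa per mark
    print("%.2f (MPa)" % value)          # -> 6.70 (MPa)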

+ 297 - 0
examples/meter_reader/deploy/cpp/meter/meter.cpp

@@ -0,0 +1,297 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <glog/logging.h>
+#include <omp.h>
+
+#include <algorithm>
+#include <chrono>  // NOLINT
+#include <iostream>
+#include <vector>
+#include <utility>
+#include <limits>
+
+#include <opencv2/opencv.hpp>
+#include <opencv2/highgui.hpp>
+#include <opencv2/core/core.hpp>
+
+#include "meter/global.h"
+#include "meter/readvalue.h"
+#include "include/paddlex/paddlex.h"
+#include "include/paddlex/visualize.h"
+
+using namespace std::chrono;  // NOLINT
+
+DEFINE_string(det_model_dir, "", "Path of detection inference model");
+DEFINE_string(seg_model_dir, "", "Path of segmentation inference model");
+DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
+DEFINE_bool(use_trt, false, "Inferring with TensorRT");
+DEFINE_bool(use_camera, false, "Inferring with camera");
+DEFINE_bool(use_erode, true, "Eroding predicted label map");
+DEFINE_int32(gpu_id, 0, "GPU card id");
+DEFINE_int32(camera_id, 0, "Camera id");
+DEFINE_int32(thread_num,
+             omp_get_num_procs(),
+             "Number of preprocessing threads");
+DEFINE_int32(erode_kernel, 4, "Eroding kernel size");
+DEFINE_int32(seg_batch_size, 2, "Batch size of segmentation inference");
+DEFINE_string(det_key, "", "Detector key of encryption");
+DEFINE_string(seg_key, "", "Segmenter model key of encryption");
+DEFINE_string(image, "", "Path of test image file");
+DEFINE_string(image_list, "", "Path of test image list file");
+DEFINE_string(save_dir, "output", "Path to save visualized image");
+
+void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
+             PaddleX::Model *seg_model, const std::string save_dir,
+             const std::string image_path, const bool use_erode,
+             const int erode_kernel, const int thread_num,
+             const int seg_batch_size) {
+  PaddleX::DetResult det_result;
+  det_model->predict(input_image, &det_result);
+
+  int meter_num = det_result.boxes.size();
+  if (!meter_num) {
+    std::cout << "No meter is detected." << std::endl;
+    return;
+  }
+
+  std::vector<std::vector<int64_t>> seg_result(meter_num);
+  for (int i = 0; i < meter_num; i += seg_batch_size) {
+    int im_vec_size =
+      std::min(static_cast<int>(meter_num), i + seg_batch_size);
+    std::vector<cv::Mat> meters_image(im_vec_size - i);
+    int batch_thread_num = std::min(thread_num, im_vec_size - i);
+    #pragma omp parallel for num_threads(batch_thread_num)
+    for (int j = i; j < im_vec_size; ++j) {
+      int left = static_cast<int>(det_result.boxes[j].coordinate[0]);
+      int top = static_cast<int>(det_result.boxes[j].coordinate[1]);
+      int width = static_cast<int>(det_result.boxes[j].coordinate[2]);
+      int height = static_cast<int>(det_result.boxes[j].coordinate[3]);
+      int right = left + width - 1;
+      int bottom = top + height - 1;
+
+      cv::Mat sub_image = input_image(
+        cv::Range(top, bottom + 1), cv::Range(left, right + 1));
+      float scale_x =
+        static_cast<float>(METER_SHAPE[0]) / static_cast<float>(sub_image.cols);
+      float scale_y =
+        static_cast<float>(METER_SHAPE[1]) / static_cast<float>(sub_image.rows);
+      cv::resize(sub_image,
+                 sub_image,
+                 cv::Size(),
+                 scale_x,
+                 scale_y,
+                 cv::INTER_LINEAR);
+      meters_image[j - i] = std::move(sub_image);
+    }
+    std::vector<PaddleX::SegResult> batch_result(im_vec_size - i);
+    seg_model->predict(meters_image, &batch_result, batch_thread_num);
+    #pragma omp parallel for num_threads(batch_thread_num)
+    for (int j = i; j < im_vec_size; ++j) {
+      if (use_erode) {
+        cv::Mat kernel(erode_kernel, erode_kernel, CV_8U, cv::Scalar(1));
+        std::vector<uint8_t> label_map(
+          batch_result[j - i].label_map.data.begin(),
+          batch_result[j - i].label_map.data.end());
+        cv::Mat mask(batch_result[j - i].label_map.shape[0],
+                     batch_result[j - i].label_map.shape[1],
+                     CV_8UC1,
+                     label_map.data());
+        cv::erode(mask, mask, kernel);
+        std::vector<int64_t> map;
+        if (mask.isContinuous()) {
+            map.assign(mask.data, mask.data + mask.total() * mask.channels());
+        } else {
+          for (int r = 0; r < mask.rows; r++) {
+            map.insert(map.end(),
+                       mask.ptr<uchar>(r),  // mask is CV_8UC1
+                       mask.ptr<uchar>(r) + mask.cols * mask.channels());
+          }
+        }
+        seg_result[j] = std::move(map);
+      } else {
+        seg_result[j] = std::move(batch_result[j - i].label_map.data);
+      }
+    }
+  }
+
+  std::vector<READ_RESULT> read_results(meter_num);
+  int all_thread_num = std::min(thread_num, meter_num);
+  read_process(seg_result, &read_results, all_thread_num);
+
+  cv::Mat output_image = input_image.clone();
+  for (int i = 0; i < meter_num; i++) {
+    float result = 0;
+    if (read_results[i].scale_num > TYPE_THRESHOLD) {
+      result = read_results[i].scales * meter_config[0].scale_value;
+    } else {
+      result = read_results[i].scales * meter_config[1].scale_value;
+    }
+    std::cout << "-- Meter " << i
+              << " -- result: " << result
+              << " --" << std::endl;
+
+    int lx = static_cast<int>(det_result.boxes[i].coordinate[0]);
+    int ly = static_cast<int>(det_result.boxes[i].coordinate[1]);
+    int w = static_cast<int>(det_result.boxes[i].coordinate[2]);
+    int h = static_cast<int>(det_result.boxes[i].coordinate[3]);
+
+    cv::Rect bounding_box = cv::Rect(lx, ly, w, h) &
+        cv::Rect(0, 0, output_image.cols, output_image.rows);
+    if (w > 0 && h > 0) {
+      cv::Scalar color = cv::Scalar(237, 189, 101);
+      cv::rectangle(output_image, bounding_box, color);
+      cv::rectangle(output_image,
+                    cv::Point2d(lx, ly),
+                    cv::Point2d(lx + w, ly - 30),
+                    color, -1);
+
+      std::string class_name = "Meter";
+      cv::putText(output_image,
+                  class_name + " " + std::to_string(result),
+                  cv::Point2d(lx, ly-5),
+                  cv::FONT_HERSHEY_SIMPLEX,
+                  1, cv::Scalar(255, 255, 255), 2);
+    }
+  }
+
+  cv::Mat result_image;
+  cv::Size resize_size(RESULT_SHAPE[0], RESULT_SHAPE[1]);
+  cv::resize(output_image, result_image, resize_size, 0, 0, cv::INTER_LINEAR);
+  std::string save_path = PaddleX::generate_save_path(save_dir, image_path);
+  cv::imwrite(save_path, result_image);
+
+  return;
+}
+
+
+int main(int argc, char **argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  if (FLAGS_det_model_dir == "") {
+    std::cerr << "--det_model_dir need to be defined" << std::endl;
+    return -1;
+  }
+  if (FLAGS_seg_model_dir == "") {
+    std::cerr << "--seg_model_dir need to be defined" << std::endl;
+    return -1;
+  }
+  if (FLAGS_image == "" & FLAGS_image_list == "" & FLAGS_use_camera == false) {
+    std::cerr << "--image or --image_list need to be defined "
+              << "when the camera is not been used" << std::endl;
+    return -1;
+  }
+
+  // Load the detection and segmentation models
+  PaddleX::Model det_model;
+  det_model.Init(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_use_trt,
+                 FLAGS_gpu_id, FLAGS_det_key);
+  PaddleX::Model seg_model;
+  seg_model.Init(FLAGS_seg_model_dir, FLAGS_use_gpu, FLAGS_use_trt,
+                 FLAGS_gpu_id, FLAGS_seg_key);
+
+  double total_running_time_s = 0.0;
+  double total_imread_time_s = 0.0;
+  int imgs = 0;  // number of processed images/frames
+  if (FLAGS_use_camera) {
+    cv::VideoCapture cap(FLAGS_camera_id);
+    cap.set(CV_CAP_PROP_FRAME_WIDTH, IMAGE_SHAPE[0]);
+    cap.set(CV_CAP_PROP_FRAME_HEIGHT, IMAGE_SHAPE[1]);
+    if (!cap.isOpened()) {
+      std::cout << "Open the camera unsuccessfully." << std::endl;
+      return -1;
+    }
+    std::cout << "Open the camera successfully." << std::endl;
+
+    while (1) {
+      auto start = system_clock::now();
+      cv::Mat im;
+      cap >> im;
+      auto imread_end = system_clock::now();
+      std::cout << "-------------------------" << std::endl;
+      std::cout << "Got a camera image." << std::endl;
+      std::string ext_name = ".jpg";
+      predict(im, &det_model, &seg_model, FLAGS_save_dir,
+              std::to_string(imgs) + ext_name, FLAGS_use_erode,
+              FLAGS_erode_kernel, FLAGS_thread_num, FLAGS_seg_batch_size);
+      imgs++;
+      auto imread_duration = duration_cast<microseconds>(imread_end - start);
+      total_imread_time_s += static_cast<double>(imread_duration.count()) *
+                             microseconds::period::num /
+                             microseconds::period::den;
+
+      auto end = system_clock::now();
+      auto duration = duration_cast<microseconds>(end - start);
+      total_running_time_s += static_cast<double>(duration.count()) *
+                              microseconds::period::num /
+                              microseconds::period::den;
+    }
+    cap.release();
+    cv::destroyAllWindows();
+  } else {
+    if (FLAGS_image_list != "") {
+      std::ifstream inf(FLAGS_image_list);
+      if (!inf) {
+        std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
+        return -1;
+      }
+      std::string image_path;
+      while (getline(inf, image_path)) {
+        auto start = system_clock::now();
+        cv::Mat im = cv::imread(image_path, 1);
+        imgs++;
+        auto imread_end = system_clock::now();
+
+        predict(im, &det_model, &seg_model, FLAGS_save_dir,
+                image_path, FLAGS_use_erode, FLAGS_erode_kernel,
+                FLAGS_thread_num, FLAGS_seg_batch_size);
+
+        auto imread_duration = duration_cast<microseconds>(imread_end - start);
+        total_imread_time_s += static_cast<double>(imread_duration.count()) *
+                               microseconds::period::num /
+                               microseconds::period::den;
+
+        auto end = system_clock::now();
+        auto duration = duration_cast<microseconds>(end - start);
+        total_running_time_s += static_cast<double>(duration.count()) *
+                                microseconds::period::num /
+                                microseconds::period::den;
+      }
+    } else {
+      auto start = system_clock::now();
+      cv::Mat im = cv::imread(FLAGS_image, 1);
+      imgs = 1;
+      auto imread_end = system_clock::now();
+
+      predict(im, &det_model, &seg_model, FLAGS_save_dir,
+              FLAGS_image, FLAGS_use_erode, FLAGS_erode_kernel,
+              FLAGS_thread_num, FLAGS_seg_batch_size);
+
+      auto imread_duration = duration_cast<microseconds>(imread_end - start);
+      total_imread_time_s += static_cast<double>(imread_duration.count()) *
+                             microseconds::period::num /
+                             microseconds::period::den;
+
+      auto end = system_clock::now();
+      auto duration = duration_cast<microseconds>(end - start);
+      total_running_time_s += static_cast<double>(duration.count()) *
+                              microseconds::period::num /
+                              microseconds::period::den;
+    }
+  }
+  std::cout << "Total running time: " << total_running_time_s
+            << " s, average running time: " << total_running_time_s / imgs
+            << " s/img, total read img time: " << total_imread_time_s
+            << " s, average read time: " << total_imread_time_s / imgs
+            << " s/img" << std::endl;
+  return 0;
+}
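A hypothetical invocation of the resulting meter binary (all paths are
placeholders; the flags match the DEFINE_* declarations above):

    ./build/meter/meter \
        --det_model_dir=/path/to/det_inference_model \
        --seg_model_dir=/path/to/seg_inference_model \
        --image=/path/to/meter_test.jpg \
        --use_erode \
        --save_dir=output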

+ 190 - 0
examples/meter_reader/deploy/cpp/meter/readvalue.cpp

@@ -0,0 +1,190 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <iostream>
+#include <vector>
+#include <utility>
+#include <limits>
+#include <cmath>
+#include <chrono>  // NOLINT
+
+#include <opencv2/opencv.hpp>
+#include <opencv2/highgui.hpp>
+#include <opencv2/core/core.hpp>
+
+#include "meter/global.h"
+#include "meter/readvalue.h"
+
+using namespace std::chrono;  // NOLINT
+
+#define SEG_IMAGE_SIZE 512
+#define LINE_HEIGHT 120
+#define LINE_WIDTH 1570
+#define CIRCLE_RADIUS 250
+
+const float pi = 3.1415926536f;
+const int circle_center[] = {256, 256};
+
+
+void creat_line_image(const std::vector<int64_t> &seg_image,
+                      std::vector<unsigned char> *output) {
+  float theta;
+  int rho;
+  int image_x;
+  int image_y;
+
+  for (int row = 0; row < LINE_HEIGHT; row++) {
+    for (int col = 0; col < LINE_WIDTH; col++) {
+      theta = pi * 2 / LINE_WIDTH * (col + 1);
+      rho = CIRCLE_RADIUS - row - 1;
+      image_x = static_cast<int>(circle_center[0] + rho * cos(theta) + 0.5);
+      image_y = static_cast<int>(circle_center[1] - rho * sin(theta) + 0.5);
+      (*output)[row * LINE_WIDTH + col] =
+        seg_image[image_x * SEG_IMAGE_SIZE + image_y];
+    }
+  }
+
+  return;
+}
+
+void convert_1D_data(const std::vector<unsigned char> &line_image,
+                     std::vector<unsigned int> *scale_data,
+                     std::vector<unsigned int> *pointer_data) {
+  for (int col = 0; col < LINE_WIDTH; col++) {
+    (*scale_data)[col] = 0;
+    (*pointer_data)[col] = 0;
+    for (int row = 0; row < LINE_HEIGHT; row++) {
+        if (line_image[row * LINE_WIDTH + col] == 1) {
+            (*pointer_data)[col]++;
+        } else if (line_image[row * LINE_WIDTH + col] == 2) {
+            (*scale_data)[col]++;
+        }
+    }
+  }
+  return;
+}
+
+void scale_mean_filtration(const std::vector<unsigned int> &scale_data,
+                           std::vector<unsigned int> *scale_mean_data) {
+  int sum = 0;
+  float mean = 0;
+  int size = scale_data.size();
+  for (int i = 0; i < size; i++) {
+      sum = sum + scale_data[i];
+  }
+  mean = static_cast<float>(sum) / static_cast<float>(size);
+
+  for (int i = 0; i < size; i++) {
+    if (static_cast<float>(scale_data[i]) >= mean) {
+        (*scale_mean_data)[i] = scale_data[i];
+    }
+  }
+
+  return;
+}
+
+void get_meter_reader(const std::vector<unsigned int> &scale,
+                      const std::vector<unsigned int> &pointer,
+                      READ_RESULT *result) {
+  std::vector<float> scale_location;
+  float one_scale_location = 0;
+  bool scale_flag = 0;
+  unsigned int one_scale_start = 0;
+  unsigned int one_scale_end = 0;
+
+  float pointer_location = 0;
+  bool pointer_flag = 0;
+  unsigned int one_pointer_start = 0;
+  unsigned int one_pointer_end = 0;
+
+  // Stop at LINE_WIDTH - 1: the loop body reads index i + 1.
+  for (int i = 0; i < LINE_WIDTH - 1; i++) {
+    // scale location
+    if (scale[i] > 0 && scale[i+1] > 0) {
+      if (scale_flag == 0) {
+        one_scale_start = i;
+        scale_flag = 1;
+      }
+    }
+    if (scale_flag == 1) {
+      if (scale[i] == 0 && scale[i+1] == 0) {
+          one_scale_end = i - 1;
+          one_scale_location = (one_scale_start + one_scale_end) / 2.;
+          scale_location.push_back(one_scale_location);
+          one_scale_start = 0;
+          one_scale_end = 0;
+          scale_flag = 0;
+      }
+    }
+
+    // pointer location
+    if (pointer[i] > 0 && pointer[i+1] > 0) {
+      if (pointer_flag == 0) {
+        one_pointer_start = i;
+        pointer_flag = 1;
+      }
+    }
+    if (pointer_flag == 1) {
+      if ((pointer[i] == 0) && (pointer[i+1] == 0)) {
+        one_pointer_end = i - 1;
+        pointer_location = (one_pointer_start + one_pointer_end) / 2.;
+        one_pointer_start = 0;
+        one_pointer_end = 0;
+        pointer_flag = 0;
+      }
+    }
+  }
+
+  int scale_num = scale_location.size();
+  result->scale_num = scale_num;
+  result->scales = -1;
+  result->ratio = -1;
+  if (scale_num > 0) {
+    for (int i = 0; i < scale_num - 1; i++) {
+      if (scale_location[i] <= pointer_location &&
+            pointer_location < scale_location[i + 1]) {
+        result->scales = i + 1 +
+          (pointer_location-scale_location[i]) /
+          (scale_location[i+1]-scale_location[i] + 1e-05);
+      }
+    }
+    result->ratio =
+      (pointer_location - scale_location[0]) /
+      (scale_location[scale_num - 1] - scale_location[0] + 1e-05);
+  }
+  return;
+}
+
+void read_process(const std::vector<std::vector<int64_t>> &seg_image,
+                  std::vector<READ_RESULT> *read_results,
+                  const int thread_num) {
+    int read_num = seg_image.size();
+    #pragma omp parallel for num_threads(thread_num)
+    for (int i_read = 0; i_read < read_num; i_read++) {
+        std::vector<unsigned char> line_result(LINE_WIDTH*LINE_HEIGHT, 0);
+        creat_line_image(seg_image[i_read], &line_result);
+
+        std::vector<unsigned int> scale_data(LINE_WIDTH);
+        std::vector<unsigned int> pointer_data(LINE_WIDTH);
+        convert_1D_data(line_result, &scale_data, &pointer_data);
+        std::vector<unsigned int> scale_mean_data(LINE_WIDTH);
+        scale_mean_filtration(scale_data, &scale_mean_data);
+
+        READ_RESULT result;
+        get_meter_reader(scale_mean_data, pointer_data, &result);
+
+        (*read_results)[i_read] = std::move(result);
+    }
+    return;
+}
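get_meter_reader locates each scale mark as the midpoint of a run of non-zero
columns, then linearly interpolates the pointer position between the two marks
that bracket it. A worked example with hypothetical column positions:

    # Marks found at these column midpoints; pointer midpoint at 145.75,
    # which falls between marks 1 and 2 (0-based).
    scale_location = [100.0, 130.5, 161.0]
    pointer_location = 145.75

    i = 1  # index of the mark just left of the pointer
    scales = i + 1 + (pointer_location - scale_location[i]) / (
        scale_location[i + 1] - scale_location[i] + 1e-05)
    ratio = (pointer_location - scale_location[0]) / (
        scale_location[-1] - scale_location[0] + 1e-05)
    # scales ~= 2.5 (in 1-based mark units), ratio ~= 0.75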

+ 42 - 0
examples/meter_reader/deploy/cpp/meter/readvalue.h

@@ -0,0 +1,42 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include <vector>
+
+struct READ_RESULT {
+  int scale_num;
+  float scales;
+  float ratio;
+};
+
+void creat_line_image(const std::vector<int64_t> &seg_image,
+                      std::vector<unsigned char> *output);
+
+void convert_1D_data(const std::vector<unsigned char> &line_image,
+                     std::vector<unsigned int> *scale_data,
+                     std::vector<unsigned int> *pointer_data);
+
+void scale_mean_filtration(const std::vector<unsigned int> &scale_data,
+                           std::vector<unsigned int> *scale_mean_data);
+
+void get_meter_reader(const std::vector<unsigned int> &scale,
+                      const std::vector<unsigned int> &pointer,
+                      READ_RESULT *result);
+
+void read_process(const std::vector<std::vector<int64_t>> &seg_image,
+                  std::vector<READ_RESULT> *read_results,
+                  const int thread_num);

+ 354 - 0
examples/meter_reader/deploy/python/reader_deploy.py

@@ -0,0 +1,354 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import numpy as np
+import math
+import cv2
+import argparse
+
+from paddlex.seg import transforms
+import paddlex as pdx
+
+METER_SHAPE = 512
+CIRCLE_CENTER = [256, 256]
+CIRCLE_RADIUS = 250
+PI = 3.1415926536
+LINE_HEIGHT = 120
+LINE_WIDTH = 1570
+TYPE_THRESHOLD = 40
+METER_CONFIG = [{
+    'scale_value': 25.0 / 50.0,
+    'range': 25.0,
+    'unit': "(MPa)"
+}, {
+    'scale_value': 1.6 / 32.0,
+    'range': 1.6,
+    'unit': "(MPa)"
+}]
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Meter Reader Inference')
+    parser.add_argument(
+        '--detector_dir',
+        dest='detector_dir',
+        help='The directory of models to do detection',
+        type=str)
+    parser.add_argument(
+        '--segmenter_dir',
+        dest='segmenter_dir',
+        help='The directory of models to do segmentation',
+        type=str)
+    parser.add_argument(
+        '--image_dir',
+        dest='image_dir',
+        help='The directory of images to be inferred',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--image',
+        dest='image',
+        help='The image to be inferred',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--use_camera',
+        dest='use_camera',
+        help='Whether to use the camera or not',
+        action='store_true')
+    parser.add_argument(
+        '--use_erode',
+        dest='use_erode',
+        help='Whether to erode the predicted label map',
+        action='store_true')
+    parser.add_argument(
+        '--erode_kernel',
+        dest='erode_kernel',
+        help='Erode kernel size',
+        type=int,
+        default=4)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the inference results',
+        type=str,
+        default='./output/result')
+    parser.add_argument(
+        '--score_threshold',
+        dest='score_threshold',
+        help="Detected bbox whose score is lower than this threshlod is filtered",
+        type=float,
+        default=0.5)
+    parser.add_argument(
+        '--seg_batch_size',
+        dest='seg_batch_size',
+        help="Segmentation batch size",
+        type=int,
+        default=2)
+    parser.add_argument(
+        '--seg_thread_num',
+        dest='seg_thread_num',
+        help="Thread number of segmentation preprocess",
+        type=int,
+        default=2)
+
+    return parser.parse_args()
+
+
+def is_pic(img_name):
+    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
+    suffix = img_name.split('.')[-1]
+    if suffix not in valid_suffix:
+        return False
+    return True
+
+
+class MeterReader:
+    def __init__(self, detector_dir, segmenter_dir):
+        if not osp.exists(detector_dir):
+            raise Exception("Model path {} does not exist".format(
+                detector_dir))
+        if not osp.exists(segmenter_dir):
+            raise Exception("Model path {} does not exist".format(
+                segmenter_dir))
+        self.detector = pdx.deploy.Predictor(detector_dir)
+        self.segmenter = pdx.deploy.Predictor(segmenter_dir)
+        # Because images are resized to (METER_SHAPE, METER_SHAPE) before
+        # being fed into the segmenter, the transform is composed of
+        # normalization only.
+        self.seg_transforms = transforms.Compose([transforms.Normalize()])
+
+    def predict(self,
+                im_file,
+                save_dir='./',
+                use_erode=True,
+                erode_kernel=4,
+                score_threshold=0.5,
+                seg_batch_size=2,
+                seg_thread_num=2):
+        if isinstance(im_file, str):
+            im = cv2.imread(im_file).astype('float32')
+        else:
+            im = im_file.copy()
+        # Get detection results
+        det_results = self.detector.predict(im)
+        # Filter bbox whose score is lower than score_threshold
+        filtered_results = list()
+        for res in det_results:
+            if res['score'] > score_threshold:
+                filtered_results.append(res)
+
+        resized_meters = list()
+        for res in filtered_results:
+            # Crop the bbox area
+            xmin, ymin, w, h = res['bbox']
+            xmin = max(0, int(xmin))
+            ymin = max(0, int(ymin))
+            xmax = min(im.shape[1], int(xmin + w - 1))
+            ymax = min(im.shape[0], int(ymin + h - 1))
+            sub_image = im[ymin:(ymax + 1), xmin:(xmax + 1), :]
+
+            # Resize the image with shape (METER_SHAPE, METER_SHAPE)
+            meter_shape = sub_image.shape
+            scale_x = float(METER_SHAPE) / float(meter_shape[1])
+            scale_y = float(METER_SHAPE) / float(meter_shape[0])
+            meter_meter = cv2.resize(
+                sub_image,
+                None,
+                None,
+                fx=scale_x,
+                fy=scale_y,
+                interpolation=cv2.INTER_LINEAR)
+            meter_meter = meter_meter.astype('float32')
+            resized_meters.append(meter_meter)
+
+        meter_num = len(resized_meters)
+        seg_results = list()
+        for i in range(0, meter_num, seg_batch_size):
+            im_size = min(meter_num, i + seg_batch_size)
+            meter_images = list()
+            for j in range(i, im_size):
+                meter_images.append(resized_meters[j])
+            result = self.segmenter.batch_predict(
+                transforms=self.seg_transforms,
+                img_file_list=meter_images,
+                thread_num=seg_thread_num)
+            if use_erode:
+                kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
+                # Erode the label maps of the current batch; seg_results has
+                # not been extended with this batch yet.
+                for res in result:
+                    res['label_map'] = cv2.erode(res['label_map'], kernel)
+            seg_results.extend(result)
+
+        results = list()
+        for i, seg_result in enumerate(seg_results):
+            result = self.read_process(seg_result['label_map'])
+            results.append(result)
+
+        meter_values = list()
+        for i, result in enumerate(results):
+            if result['scale_num'] > TYPE_THRESHOLD:
+                value = result['scales'] * METER_CONFIG[0]['scale_value']
+            else:
+                value = result['scales'] * METER_CONFIG[1]['scale_value']
+            meter_values.append(value)
+            print("-- Meter {} -- result: {} --\n".format(i, value))
+
+        # visualize the results
+        visual_results = list()
+        for i, res in enumerate(filtered_results):
+            # Use `score` to represent the meter value
+            res['score'] = meter_values[i]
+            visual_results.append(res)
+        pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)
+
+    def read_process(self, label_maps):
+        # Convert the circular meter into rectangular meter
+        line_images = self.creat_line_image(label_maps)
+        # Convert the 2d meter into 1d meter
+        scale_data, pointer_data = self.convert_1d_data(line_images)
+        # Filter out scale data whose value is lower than the mean value
+        self.scale_mean_filtration(scale_data)
+        # Get scale_num, scales and ratio of meters
+        result = self.get_meter_reader(scale_data, pointer_data)
+        return result
+
+    def creat_line_image(self, meter_image):
+        line_image = np.zeros((LINE_HEIGHT, LINE_WIDTH), dtype=np.uint8)
+        for row in range(LINE_HEIGHT):
+            for col in range(LINE_WIDTH):
+                theta = PI * 2 / LINE_WIDTH * (col + 1)
+                rho = CIRCLE_RADIUS - row - 1
+                x = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
+                y = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
+                line_image[row, col] = meter_image[x, y]
+        return line_image
+
+    def convert_1d_data(self, meter_image):
+        scale_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
+        pointer_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
+        for col in range(LINE_WIDTH):
+            for row in range(LINE_HEIGHT):
+                if meter_image[row, col] == 1:
+                    pointer_data[col] += 1
+                elif meter_image[row, col] == 2:
+                    scale_data[col] += 1
+        return scale_data, pointer_data
+
+    def scale_mean_filtration(self, scale_data):
+        mean_data = np.mean(scale_data)
+        for col in range(LINE_WIDTH):
+            if scale_data[col] < mean_data:
+                scale_data[col] = 0
+
+    def get_meter_reader(self, scale_data, pointer_data):
+        scale_flag = False
+        pointer_flag = False
+        one_scale_start = 0
+        one_scale_end = 0
+        one_pointer_start = 0
+        one_pointer_end = 0
+        scale_location = list()
+        pointer_location = 0
+        for i in range(LINE_WIDTH - 1):
+            if scale_data[i] > 0 and scale_data[i + 1] > 0:
+                if scale_flag == False:
+                    one_scale_start = i
+                    scale_flag = True
+            if scale_flag:
+                if scale_data[i] == 0 and scale_data[i + 1] == 0:
+                    one_scale_end = i - 1
+                    one_scale_location = (one_scale_start + one_scale_end) / 2
+                    scale_location.append(one_scale_location)
+                    one_scale_start = 0
+                    one_scale_end = 0
+                    scale_flag = False
+            if pointer_data[i] > 0 and pointer_data[i + 1] > 0:
+                if pointer_flag == False:
+                    one_pointer_start = i
+                    pointer_flag = True
+            if pointer_flag:
+                if pointer_data[i] == 0 and pointer_data[i + 1] == 0:
+                    one_pointer_end = i - 1
+                    pointer_location = (
+                        one_pointer_start + one_pointer_end) / 2
+                    one_pointer_start = 0
+                    one_pointer_end = 0
+                    pointer_flag = False
+
+        scale_num = len(scale_location)
+        scales = -1
+        ratio = -1
+        if scale_num > 0:
+            for i in range(scale_num - 1):
+                if scale_location[i] <= pointer_location < scale_location[i + 1]:
+                    scales = i + (pointer_location - scale_location[i]) / (
+                        scale_location[i + 1] - scale_location[i] + 1e-05) + 1
+            ratio = (pointer_location - scale_location[0]) / (
+                scale_location[scale_num - 1] - scale_location[0] + 1e-05)
+        result = {'scale_num': scale_num, 'scales': scales, 'ratio': ratio}
+        return result
+
+
+def infer(args):
+    image_lists = list()
+    if args.image is not None:
+        if not osp.exists(args.image):
+            raise Exception("Image {} does not exist.".format(args.image))
+        if not is_pic(args.image):
+            raise Exception("{} is not a picture.".format(args.image))
+        image_lists.append(args.image)
+    elif args.image_dir is not None:
+        if not osp.exists(args.image_dir):
+            raise Exception("Directory {} does not exist.".format(
+                args.image_dir))
+        for im_file in os.listdir(args.image_dir):
+            if not is_pic(im_file):
+                continue
+            im_file = osp.join(args.image_dir, im_file)
+            image_lists.append(im_file)
+
+    meter_reader = MeterReader(args.detector_dir, args.segmenter_dir)
+    if len(image_lists) > 0:
+        for im_file in image_lists:
+            meter_reader.predict(im_file, args.save_dir, args.use_erode,
+                                 args.erode_kernel, args.score_threshold,
+                                 args.seg_batch_size, args.seg_thread_num)
+    elif args.use_camera:
+        cap_video = cv2.VideoCapture(0)
+        if not cap_video.isOpened():
+            raise Exception(
+                "Error opening video stream, please make sure the camera is working"
+            )
+
+        while cap_video.isOpened():
+            ret, frame = cap_video.read()
+            if ret:
+                meter_reader.predict(frame, args.save_dir, args.use_erode,
+                                     args.erode_kernel, args.score_threshold,
+                                     args.seg_batch_size, args.seg_thread_num)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    break
+            else:
+                break
+        cap_video.release()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    infer(args)

+ 354 - 0
examples/meter_reader/reader_infer.py

@@ -0,0 +1,354 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import numpy as np
+import math
+import cv2
+import argparse
+
+from paddlex.seg import transforms
+import paddlex as pdx
+
+METER_SHAPE = 512
+CIRCLE_CENTER = [256, 256]
+CIRCLE_RADIUS = 250
+PI = 3.1415926536
+LINE_HEIGHT = 120
+LINE_WIDTH = 1570
+TYPE_THRESHOLD = 40
+METER_CONFIG = [{
+    'scale_value': 25.0 / 50.0,
+    'range': 25.0,
+    'unit': "(MPa)"
+}, {
+    'scale_value': 1.6 / 32.0,
+    'range': 1.6,
+    'unit': "(MPa)"
+}]
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Meter Reader Inference')
+    parser.add_argument(
+        '--detector_dir',
+        dest='detector_dir',
+        help='The directory of models to do detection',
+        type=str)
+    parser.add_argument(
+        '--segmenter_dir',
+        dest='segmenter_dir',
+        help='The directory of models to do segmentation',
+        type=str)
+    parser.add_argument(
+        '--image_dir',
+        dest='image_dir',
+        help='The directory of images to be inferred',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--image',
+        dest='image',
+        help='The image to be inferred',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--use_camera',
+        dest='use_camera',
+        help='Whether to use the camera or not',
+        action='store_true')
+    parser.add_argument(
+        '--use_erode',
+        dest='use_erode',
+        help='Whether to erode the predicted label map',
+        action='store_true')
+    parser.add_argument(
+        '--erode_kernel',
+        dest='erode_kernel',
+        help='Erode kernel size',
+        type=int,
+        default=4)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the inference results',
+        type=str,
+        default='./output/result')
+    parser.add_argument(
+        '--score_threshold',
+        dest='score_threshold',
+        help="Detected bbox whose score is lower than this threshlod is filtered",
+        type=float,
+        default=0.5)
+    parser.add_argument(
+        '--seg_batch_size',
+        dest='seg_batch_size',
+        help="Segmentation batch size",
+        type=int,
+        default=2)
+    parser.add_argument(
+        '--seg_thread_num',
+        dest='seg_thread_num',
+        help="Thread number of segmentation preprocess",
+        type=int,
+        default=2)
+
+    return parser.parse_args()
+
+
+def is_pic(img_name):
+    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
+    suffix = img_name.split('.')[-1]
+    if suffix not in valid_suffix:
+        return False
+    return True
+
+
+class MeterReader:
+    def __init__(self, detector_dir, segmenter_dir):
+        if not osp.exists(detector_dir):
+            raise Exception("Model path {} does not exist".format(
+                detector_dir))
+        if not osp.exists(segmenter_dir):
+            raise Exception("Model path {} does not exist".format(
+                segmenter_dir))
+        self.detector = pdx.load_model(detector_dir)
+        self.segmenter = pdx.load_model(segmenter_dir)
+        # Because images are resized to (METER_SHAPE, METER_SHAPE) before
+        # being fed into the segmenter, the transform is composed of
+        # normalization only.
+        self.seg_transforms = transforms.Compose([transforms.Normalize()])
+
+    def predict(self,
+                im_file,
+                save_dir='./',
+                use_erode=True,
+                erode_kernel=4,
+                score_threshold=0.5,
+                seg_batch_size=2,
+                seg_thread_num=2):
+        if isinstance(im_file, str):
+            im = cv2.imread(im_file).astype('float32')
+        else:
+            im = im_file.copy()
+        # Get detection results
+        det_results = self.detector.predict(im)
+        # Filter bbox whose score is lower than score_threshold
+        filtered_results = list()
+        for res in det_results:
+            if res['score'] > score_threshold:
+                filtered_results.append(res)
+
+        resized_meters = list()
+        for res in filtered_results:
+            # Crop the bbox area
+            xmin, ymin, w, h = res['bbox']
+            xmin = max(0, int(xmin))
+            ymin = max(0, int(ymin))
+            xmax = min(im.shape[1], int(xmin + w - 1))
+            ymax = min(im.shape[0], int(ymin + h - 1))
+            sub_image = im[ymin:(ymax + 1), xmin:(xmax + 1), :]
+
+            # Resize the image with shape (METER_SHAPE, METER_SHAPE)
+            meter_shape = sub_image.shape
+            scale_x = float(METER_SHAPE) / float(meter_shape[1])
+            scale_y = float(METER_SHAPE) / float(meter_shape[0])
+            resized_meter = cv2.resize(
+                sub_image,
+                None,
+                None,
+                fx=scale_x,
+                fy=scale_y,
+                interpolation=cv2.INTER_LINEAR)
+            resized_meter = resized_meter.astype('float32')
+            resized_meters.append(resized_meter)
+
+        meter_num = len(resized_meters)
+        seg_results = list()
+        for i in range(0, meter_num, seg_batch_size):
+            batch_end = min(meter_num, i + seg_batch_size)
+            meter_images = list()
+            for j in range(i, batch_end):
+                meter_images.append(resized_meters[j])
+            result = self.segmenter.batch_predict(
+                transforms=self.seg_transforms,
+                img_file_list=meter_images,
+                thread_num=seg_thread_num)
+            if use_erode:
+                kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
+                for j in range(len(result)):
+                    result[j]['label_map'] = cv2.erode(
+                        result[j]['label_map'], kernel)
+            seg_results.extend(result)
+
+        results = list()
+        for i, seg_result in enumerate(seg_results):
+            result = self.read_process(seg_result['label_map'])
+            results.append(result)
+
+        meter_values = list()
+        for i, result in enumerate(results):
+            if result['scale_num'] > TYPE_THRESHOLD:
+                value = result['scales'] * METER_CONFIG[0]['scale_value']
+            else:
+                value = result['scales'] * METER_CONFIG[1]['scale_value']
+            meter_values.append(value)
+            print("-- Meter {} -- result: {} --\n".format(i, value))
+
+        # visualize the results
+        visual_results = list()
+        for i, res in enumerate(filtered_results):
+            # Use `score` to represent the meter value
+            res['score'] = meter_values[i]
+            visual_results.append(res)
+        pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)
+
+    def read_process(self, label_maps):
+        # Convert the circular meter into rectangular meter
+        line_images = self.create_line_image(label_maps)
+        # Convert the 2d meter into 1d meter
+        scale_data, pointer_data = self.convert_1d_data(line_images)
+        # Filter out scale data whose values are lower than the mean value
+        self.scale_mean_filtration(scale_data)
+        # Get scale_num, scales and ratio of meters
+        result = self.get_meter_reader(scale_data, pointer_data)
+        return result
+
+    def create_line_image(self, meter_image):
+        line_image = np.zeros((LINE_HEIGHT, LINE_WIDTH), dtype=np.uint8)
+        for row in range(LINE_HEIGHT):
+            for col in range(LINE_WIDTH):
+                theta = PI * 2 / LINE_WIDTH * (col + 1)
+                rho = CIRCLE_RADIUS - row - 1
+                x = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
+                y = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
+                line_image[row, col] = meter_image[x, y]
+        return line_image
+
+    def convert_1d_data(self, meter_image):
+        scale_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
+        pointer_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
+        for col in range(LINE_WIDTH):
+            for row in range(LINE_HEIGHT):
+                if meter_image[row, col] == 1:
+                    pointer_data[col] += 1
+                elif meter_image[row, col] == 2:
+                    scale_data[col] += 1
+        return scale_data, pointer_data
+
+    def scale_mean_filtration(self, scale_data):
+        mean_data = np.mean(scale_data)
+        for col in range(LINE_WIDTH):
+            if scale_data[col] < mean_data:
+                scale_data[col] = 0
+
+    def get_meter_reader(self, scale_data, pointer_data):
+        scale_flag = False
+        pointer_flag = False
+        one_scale_start = 0
+        one_scale_end = 0
+        one_pointer_start = 0
+        one_pointer_end = 0
+        scale_location = list()
+        pointer_location = 0
+        for i in range(LINE_WIDTH - 1):
+            if scale_data[i] > 0 and scale_data[i + 1] > 0:
+                if not scale_flag:
+                    one_scale_start = i
+                    scale_flag = True
+            if scale_flag:
+                if scale_data[i] == 0 and scale_data[i + 1] == 0:
+                    one_scale_end = i - 1
+                    one_scale_location = (one_scale_start + one_scale_end) / 2
+                    scale_location.append(one_scale_location)
+                    one_scale_start = 0
+                    one_scale_end = 0
+                    scale_flag = False
+            if pointer_data[i] > 0 and pointer_data[i + 1] > 0:
+                if not pointer_flag:
+                    one_pointer_start = i
+                    pointer_flag = True
+            if pointer_flag:
+                if pointer_data[i] == 0 and pointer_data[i + 1] == 0:
+                    one_pointer_end = i - 1
+                    pointer_location = (
+                        one_pointer_start + one_pointer_end) / 2
+                    one_pointer_start = 0
+                    one_pointer_end = 0
+                    pointer_flag = False
+
+        scale_num = len(scale_location)
+        scales = -1
+        ratio = -1
+        if scale_num > 0:
+            for i in range(scale_num - 1):
+                if (scale_location[i] <= pointer_location <
+                        scale_location[i + 1]):
+                    scales = i + (pointer_location - scale_location[i]) / (
+                        scale_location[i + 1] - scale_location[i] + 1e-05) + 1
+            ratio = (pointer_location - scale_location[0]) / (
+                scale_location[scale_num - 1] - scale_location[0] + 1e-05)
+        result = {'scale_num': scale_num, 'scales': scales, 'ratio': ratio}
+        return result
+
+
+def infer(args):
+    image_lists = list()
+    if args.image is not None:
+        if not osp.exists(args.image):
+            raise Exception("Image {} does not exist.".format(args.image))
+        if not is_pic(args.image):
+            raise Exception("{} is not a picture.".format(args.image))
+        image_lists.append(args.image)
+    elif args.image_dir is not None:
+        if not osp.exists(args.image_dir):
+            raise Exception("Directory {} does not exist.".format(
+                args.image_dir))
+        for im_file in os.listdir(args.image_dir):
+            if not is_pic(im_file):
+                continue
+            im_file = osp.join(args.image_dir, im_file)
+            image_lists.append(im_file)
+
+    meter_reader = MeterReader(args.detector_dir, args.segmenter_dir)
+    if len(image_lists) > 0:
+        for im_file in image_lists:
+            meter_reader.predict(im_file, args.save_dir, args.use_erode,
+                                 args.erode_kernel, args.score_threshold,
+                                 args.seg_batch_size, args.seg_thread_num)
+    elif args.use_camera:
+        cap_video = cv2.VideoCapture(0)
+        if not cap_video.isOpened():
+            raise Exception(
+                "Error opening video stream, please make sure the camera is working"
+            )
+
+        while cap_video.isOpened():
+            ret, frame = cap_video.read()
+            if ret:
+                meter_reader.predict(frame, args.save_dir, args.use_erode,
+                                     args.erode_kernel, args.score_threshold,
+                                     args.seg_batch_size, args.seg_thread_num)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    break
+            else:
+                break
+        cap_video.release()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    infer(args)
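
A minimal usage sketch for the MeterReader class above, driving it from Python rather than the CLI (the model directories and the image path are hypothetical placeholders for exported detector/segmenter models):

    # Sketch only; assumes this module's globals (METER_SHAPE, METER_CONFIG, ...)
    reader = MeterReader('output/meter_det/best_model',   # hypothetical paths
                         'output/meter_seg/best_model')
    reader.predict('meter_test.jpg',                      # hypothetical image
                   save_dir='./output/result',
                   use_erode=True,
                   erode_kernel=4,
                   score_threshold=0.5,
                   seg_batch_size=2,
                   seg_thread_num=2)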

+ 50 - 0
examples/meter_reader/train_detection.py

@@ -0,0 +1,50 @@
+import os
+# Use GPU card 0
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+# Download and decompress the meter detection dataset
+meter_det_dataset = 'https://bj.bcebos.com/paddlex/datasets/meter_det.tar.gz'
+pdx.utils.download_and_decompress(meter_det_dataset, path='./')
+
+# Define the transforms for training and evaluation
+# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedyolotransforms
+train_transforms = transforms.ComposedYOLOv3Transforms(
+    mode='train', shape=[608, 608])
+eval_transforms = transforms.ComposedYOLOv3Transforms(
+    mode='eval', shape=[608, 608])
+
+# Define the datasets used for training and evaluation
+# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#cocodetection
+train_dataset = pdx.datasets.CocoDetection(
+    data_dir='meter_det/train/',
+    ann_file='meter_det/annotations/instance_train.json',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.CocoDetection(
+    data_dir='meter_det/test/',
+    ann_file='meter_det/annotations/instance_test.json',
+    transforms=eval_transforms)
+
+# Initialize the model and start training
+# Training metrics can be inspected with VisualDL
+# Launch VisualDL with: visualdl --logdir output/meter_det/vdl_log --port 8001
+# Then open https://0.0.0.0:8001 in a browser
+# (0.0.0.0 works for local access; for a remote server, use that machine's IP)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/detection.html#yolov3
+num_classes = len(train_dataset.labels)
+model = pdx.det.YOLOv3(
+    num_classes=num_classes, backbone='DarkNet53', label_smooth=True)
+model.train(
+    num_epochs=270,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    learning_rate=0.001,
+    warmup_steps=4000,
+    lr_decay_epochs=[210, 240],
+    save_dir='output/meter_det',
+    use_vdl=True)
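
As a quick sanity check after training, the saved detector can be loaded back for single-image inference; a sketch, assuming the default best-model layout under save_dir (paths and image name are hypothetical):

    import paddlex as pdx

    model = pdx.load_model('output/meter_det/best_model')
    result = model.predict('meter_det/test/example.jpg')  # hypothetical image
    for box in result:
        print(box['category'], box['bbox'], box['score'])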

+ 57 - 0
examples/meter_reader/train_segmentation.py

@@ -0,0 +1,57 @@
+import os
+# Use GPU card 0
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+# Download and decompress the meter-dial segmentation dataset
+meter_seg_dataset = 'https://bj.bcebos.com/paddlex/datasets/meter_seg.tar.gz'
+pdx.utils.download_and_decompress(meter_seg_dataset, path='./')
+
+# Define the transforms for training and evaluation
+train_transforms = transforms.Compose([
+    transforms.Resize([512, 512]),
+    transforms.RandomHorizontalFlip(prob=0.5),
+    transforms.Normalize(),
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Resize([512, 512]),
+    transforms.Normalize(),
+])
+# Define the datasets used for training and evaluation
+# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
+train_dataset = pdx.datasets.SegDataset(
+    data_dir='meter_seg/',
+    file_list='meter_seg/train.txt',
+    label_list='meter_seg/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='meter_seg/',
+    file_list='meter_seg/val.txt',
+    label_list='meter_seg/labels.txt',
+    transforms=eval_transforms)
+
+# Initialize the model and start training
+# Training metrics can be inspected with VisualDL
+# Launch VisualDL with: visualdl --logdir output/meter_seg/vdl_log --port 8001
+# Then open https://0.0.0.0:8001 in a browser
+# (0.0.0.0 works for local access; for a remote server, use that machine's IP)
+#
+# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#deeplabv3p
+model = pdx.seg.DeepLabv3p(
+    num_classes=len(train_dataset.labels),
+    backbone='Xception65',
+    aspp_with_sep_conv=False)
+model.train(
+    num_epochs=20,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.1,
+    pretrain_weights='COCO',
+    save_interval_epochs=5,
+    save_dir='output/meter_seg',
+    use_vdl=True)
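
The trained segmenter can likewise be loaded back, and the newly added batch_predict API runs several dial crops in one pass; a sketch with hypothetical paths:

    import paddlex as pdx

    model = pdx.load_model('output/meter_seg/best_model')
    results = model.batch_predict(['dial_0.jpg', 'dial_1.jpg'], thread_num=2)
    for res in results:
        print(res['label_map'].shape, res['score_map'].shape)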

+ 2 - 1
paddlex/cv/datasets/__init__.py

@@ -18,4 +18,5 @@ from .coco import CocoDetection
 from .seg_dataset import SegDataset
 from .easydata_cls import EasyDataCls
 from .easydata_det import EasyDataDet
-from .easydata_seg import EasyDataSeg
+from .easydata_seg import EasyDataSeg
+from .dataset import GenerateMiniBatch

+ 19 - 1
paddlex/cv/datasets/dataset.py

@@ -212,7 +212,25 @@ def GenerateMiniBatch(batch_data):
         padding_im = np.zeros(
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
-        padding_batch.append((padding_im, ) + data[1:])
+        if len(data) > 1:
+            if isinstance(data[1], np.ndarray):
+                padding_label = np.zeros(
+                    (1, max_shape[1], max_shape[2])).astype('int64')
+                _, label_h, label_w = data[1].shape
+                padding_label[:, :label_h, :label_w] = data[1]
+                padding_batch.append((padding_im, padding_label))
+            elif len(data[1]) == 0 or (
+                    isinstance(data[1][0], tuple) and
+                    data[1][0][0] in ['resize', 'padding']):
+                # Record the pre-padding size so postprocessing can crop back
+                if not any(info[0] == 'padding' for info in data[1]):
+                    data[1].append(('padding', [im_h, im_w]))
+                padding_batch.append((padding_im, ) + tuple(data[1:]))
+            else:
+                padding_batch.append((padding_im, ) + tuple(data[1:]))
+        else:
+            padding_batch.append((padding_im, ))
     return padding_batch
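
An illustration of what GenerateMiniBatch does for two segmentation samples of different sizes, based on the hunk above: both images are zero-padded to the per-batch maximum, and a ('padding', [h, w]) record is appended to each sample's info list so postprocessing can crop the prediction back (shapes here are illustrative):

    import numpy as np
    from paddlex.cv.datasets import GenerateMiniBatch

    batch = [
        (np.ones((3, 300, 400), np.float32), [('resize', [300, 400])]),
        (np.ones((3, 260, 320), np.float32), [('resize', [260, 320])]),
    ]
    padded = GenerateMiniBatch(batch)
    assert padded[0][0].shape == (3, 300, 400)
    assert padded[1][0].shape == (3, 300, 400)
    assert padded[1][1][-1] == ('padding', [260, 320])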
 
 

+ 11 - 20
paddlex/cv/models/base.py

@@ -26,6 +26,7 @@ import functools
 import paddlex.utils.logging as logging
 from paddlex.utils import seconds_to_hms
 from paddlex.utils.utils import EarlyStop
+from paddlex.cv.transforms import arrange_transforms
 import paddlex
 from collections import OrderedDict
 from os import path as osp
@@ -97,23 +98,6 @@ class BaseAPI:
                     mode='test')
         self.test_prog = self.test_prog.clone(for_test=True)
 
-    def arrange_transforms(self, transforms, mode='train'):
-        # Add an arrange operation to the transforms
-        if self.model_type == 'classifier':
-            arrange_transform = paddlex.cls.transforms.ArrangeClassifier
-        elif self.model_type == 'segmenter':
-            arrange_transform = paddlex.seg.transforms.ArrangeSegmenter
-        elif self.model_type == 'detector':
-            arrange_name = 'Arrange{}'.format(self.__class__.__name__)
-            arrange_transform = getattr(paddlex.det.transforms, arrange_name)
-        else:
-            raise Exception("Unrecognized model type: {}".format(
-                self.model_type))
-        if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
-            transforms.transforms[-1] = arrange_transform(mode=mode)
-        else:
-            transforms.transforms.append(arrange_transform(mode=mode))
-
     def build_train_data_loader(self, dataset, batch_size):
         # Initialize the data_loader
         if self.train_data_loader is None:
@@ -135,7 +119,11 @@ class BaseAPI:
                            batch_size=1,
                            batch_num=10,
                            cache_dir="./temp"):
-        self.arrange_transforms(transforms=dataset.transforms, mode='quant')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=dataset.transforms,
+            mode='quant')
         dataset.num_samples = batch_size * batch_num
         try:
             from .slim.post_quantization import PaddleXPostTrainingQuantization
@@ -413,8 +401,11 @@ class BaseAPI:
             from visualdl import LogWriter
             vdl_logdir = osp.join(save_dir, 'vdl_log')
         # Add an arrange operation to the transforms
-        self.arrange_transforms(
-            transforms=train_dataset.transforms, mode='train')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=train_dataset.transforms,
+            mode='train')
         # Build the train_data_loader
         self.build_train_data_loader(
             dataset=train_dataset, batch_size=train_batch_size)
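
With the method removed from BaseAPI, the arrange step is now a module-level helper; every call site passes the model type and class name explicitly. A sketch of the new call pattern, mirroring the evaluate/train call sites in this hunk (model and eval_transforms are assumed names):

    from paddlex.cv.transforms import arrange_transforms

    arrange_transforms(
        model_type=model.model_type,          # 'classifier'/'detector'/'segmenter'
        class_name=model.__class__.__name__,  # e.g. 'DeepLabv3p'
        transforms=eval_transforms,           # a paddlex transforms Compose
        mode='eval')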

+ 94 - 21
paddlex/cv/models/classifier.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -17,10 +17,13 @@ import numpy as np
 import time
 import math
 import tqdm
+from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
 import paddlex.utils.logging as logging
 from paddlex.utils import seconds_to_hms
 import paddlex
+from paddlex.cv.transforms import arrange_transforms
+from paddlex.cv.datasets import GenerateMiniBatch
 from collections import OrderedDict
 from .base import BaseAPI
 
@@ -54,7 +57,8 @@ class BaseClassifier(BaseAPI):
             input_shape = [
                 None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
             ]
-            image = fluid.data(dtype='float32', shape=input_shape, name='image')
+            image = fluid.data(
+                dtype='float32', shape=input_shape, name='image')
         else:
             image = fluid.data(
                 dtype='float32', shape=[None, 3, None, None], name='image')
@@ -219,7 +223,8 @@ class BaseClassifier(BaseAPI):
           tuple (metrics, eval_details): when return_details is True, an extra dict is returned,
               with keys 'true_labels' and 'pred_scores': the ground-truth class ids and per-class prediction scores.
         """
-        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
         k = min(5, self.num_classes)
@@ -231,8 +236,9 @@ class BaseClassifier(BaseAPI):
                 self.test_prog).with_data_parallel(
                     share_vars_from=self.parallel_train_prog)
         batch_size_each_gpu = self._get_single_card_bs(batch_size)
-        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, total_steps))
+        logging.info(
+            "Start to evaluating(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data]).astype('float32')
@@ -266,37 +272,104 @@ class BaseClassifier(BaseAPI):
             return metrics, eval_details
         return metrics
 
+    @staticmethod
+    def _preprocess(images, transforms, model_type, class_name, thread_num=1):
+        arrange_transforms(
+            model_type=model_type,
+            class_name=class_name,
+            transforms=transforms,
+            mode='test')
+        pool = ThreadPool(thread_num)
+        batch_data = pool.map(transforms, images)
+        pool.close()
+        pool.join()
+        padding_batch = GenerateMiniBatch(batch_data)
+        im = np.array([data[0] for data in padding_batch])
+
+        return im
+
+    @staticmethod
+    def _postprocess(results, true_topk, labels):
+        preds = list()
+        for i, pred in enumerate(results[0]):
+            pred_label = np.argsort(pred)[::-1][:true_topk]
+            preds.append([{
+                'category_id': l,
+                'category': labels[l],
+                'score': results[0][i][l]
+            } for l in pred_label])
+
+        return preds
+
     def predict(self, img_file, transforms=None, topk=1):
         """预测。
         Args:
-            img_file (str): 预测图像路径。
+            img_file (str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组
             transforms (paddlex.cls.transforms): 数据预处理操作。
             topk (int): 预测时前k个最大值。
         Returns:
             list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
             分别对应预测类别id、预测类别标签、预测得分。
         """
+
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")
         true_topk = min(self.num_classes, topk)
-        if transforms is not None:
-            self.arrange_transforms(transforms=transforms, mode='test')
-            im = transforms(img_file)
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
         else:
-            self.arrange_transforms(
-                transforms=self.test_transforms, mode='test')
-            im = self.test_transforms(img_file)
+            raise Exception("img_file must be str/np.ndarray")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im = BaseClassifier._preprocess(images, transforms, self.model_type,
+                                        self.__class__.__name__)
+
+        result = self.exe.run(self.test_prog,
+                              feed={'image': im},
+                              fetch_list=list(self.test_outputs.values()),
+                              use_program_cache=True)
+
+        preds = BaseClassifier._postprocess(result, true_topk, self.labels)
+
+        return preds[0]
+
+    def batch_predict(self,
+                      img_file_list,
+                      transforms=None,
+                      topk=1,
+                      thread_num=2):
+        """预测。
+        Args:
+            img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
+                也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
+            transforms (paddlex.cls.transforms): 数据预处理操作。
+            topk (int): 预测时前k个最大值。
+            thread_num (int): 并发执行各图像预处理时的线程数。
+        Returns:
+            list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测列表中,其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
+            分别对应预测类别id、预测类别标签、预测得分。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+        true_topk = min(self.num_classes, topk)
+        if not isinstance(img_file_list, (list, tuple)):
+            raise Exception("im_file must be list/tuple")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im = BaseClassifier._preprocess(img_file_list, transforms,
+                                        self.model_type,
+                                        self.__class__.__name__, thread_num)
+
         result = self.exe.run(self.test_prog,
                               feed={'image': im},
                               fetch_list=list(self.test_outputs.values()),
                               use_program_cache=True)
-        pred_label = np.argsort(result[0][0])[::-1][:true_topk]
-        res = [{
-            'category_id': l,
-            'category': self.labels[l],
-            'score': result[0][0][l]
-        } for l in pred_label]
-        return res
+
+        preds = BaseClassifier._postprocess(result, true_topk, self.labels)
+
+        return preds
 
 
 class ResNet18(BaseClassifier):
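
A sketch of the new classifier batch API (model directory and image paths are hypothetical):

    import paddlex as pdx

    model = pdx.load_model('output/mobilenetv2/best_model')
    preds = model.batch_predict(['a.jpg', 'b.jpg'], topk=1, thread_num=2)
    # preds[i] is the result list for the i-th image
    print(preds[0][0]['category'], preds[0][0]['score'])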

+ 101 - 32
paddlex/cv/models/deeplabv3p.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -18,9 +18,12 @@ import numpy as np
 import tqdm
 import math
 import cv2
+from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
 import paddlex.utils.logging as logging
 import paddlex
+from paddlex.cv.transforms import arrange_transforms
+from paddlex.cv.datasets import GenerateMiniBatch
 from collections import OrderedDict
 from .base import BaseAPI
 from .utils.seg_eval import ConfusionMatrix
@@ -317,8 +320,11 @@ class DeepLabv3p(BaseAPI):
             tuple (metrics, eval_details): when return_details is True, an extra dict (eval_details)
                 is returned, with key 'confusion_matrix' holding the evaluation confusion matrix.
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
         conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
         data_generator = eval_dataset.generator(
@@ -378,10 +384,56 @@ class DeepLabv3p(BaseAPI):
             return metrics, eval_details
         return metrics
 
-    def predict(self, im_file, transforms=None):
+    @staticmethod
+    def _preprocess(images, transforms, model_type, class_name, thread_num=1):
+        arrange_transforms(
+            model_type=model_type,
+            class_name=class_name,
+            transforms=transforms,
+            mode='test')
+        pool = ThreadPool(thread_num)
+        batch_data = pool.map(transforms, images)
+        pool.close()
+        pool.join()
+        padding_batch = GenerateMiniBatch(batch_data)
+        im = np.array(
+            [data[0] for data in padding_batch],
+            dtype=padding_batch[0][0].dtype)
+        im_info = [data[1] for data in padding_batch]
+        return im, im_info
+
+    @staticmethod
+    def _postprocess(results, im_info):
+        pred_list = list()
+        logit_list = list()
+        for i, (pred, logit) in enumerate(zip(results[0], results[1])):
+            pred = np.squeeze(pred).astype('uint8')
+            logit = np.transpose(logit, (1, 2, 0))
+            for info in im_info[i][::-1]:
+                if info[0] == 'resize':
+                    w, h = info[1][1], info[1][0]
+                    pred = cv2.resize(
+                        pred, (w, h), interpolation=cv2.INTER_NEAREST)
+                    logit = cv2.resize(
+                        logit, (w, h), interpolation=cv2.INTER_LINEAR)
+                elif info[0] == 'padding':
+                    w, h = info[1][1], info[1][0]
+                    pred = pred[0:h, 0:w]
+                    logit = logit[0:h, 0:w, :]
+                else:
+                    raise Exception("Unexpected info '{}' in im_info".format(
+                        info[0]))
+            pred_list.append(pred)
+            logit_list.append(logit)
+
+        preds = list()
+        for pred, logit in zip(pred_list, logit_list):
+            preds.append({'label_map': pred, 'score_map': logit})
+        return preds
+
+    def predict(self, img_file, transforms=None):
         """预测。
         Args:
-            img_file(str): 预测图像路径。
+            img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组
             transforms(paddlex.cv.transforms): 数据预处理操作。
 
         Returns:
@@ -391,33 +443,50 @@ class DeepLabv3p(BaseAPI):
 
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")
-        if transforms is not None:
-            self.arrange_transforms(transforms=transforms, mode='test')
-            im, im_info = transforms(im_file)
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
         else:
-            self.arrange_transforms(
-                transforms=self.test_transforms, mode='test')
-            im, im_info = self.test_transforms(im_file)
-        im = np.expand_dims(im, axis=0)
+            raise Exception("img_file must be str/np.ndarray")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_info = DeepLabv3p._preprocess(
+            images, transforms, self.model_type, self.__class__.__name__)
+
         result = self.exe.run(self.test_prog,
                               feed={'image': im},
                               fetch_list=list(self.test_outputs.values()),
                               use_program_cache=True)
-        pred = result[0]
-        pred = np.squeeze(pred).astype('uint8')
-        logit = result[1]
-        logit = np.squeeze(logit)
-        logit = np.transpose(logit, (1, 2, 0))
-        for info in im_info[::-1]:
-            if info[0] == 'resize':
-                w, h = info[1][1], info[1][0]
-                pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
-                logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
-            elif info[0] == 'padding':
-                w, h = info[1][1], info[1][0]
-                pred = pred[0:h, 0:w]
-                logit = logit[0:h, 0:w, :]
-            else:
-                raise Exception("Unexpected info '{}' in im_info".format(info[
-                    0]))
-        return {'label_map': pred, 'score_map': logit}
+
+        preds = DeepLabv3p._postprocess(result, im_info)
+        return preds[0]
+
+    def batch_predict(self, img_file_list, transforms=None, thread_num=2):
+        """预测。
+        Args:
+            img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
+                也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
+            transforms(paddlex.cv.transforms): 数据预处理操作。
+
+        Returns:
+            list: 每个元素都为列表,表示各图像的预测结果。各图像的预测结果用字典表示,包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
+                像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)
+        """
+
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+        if not isinstance(img_file_list, (list, tuple)):
+            raise Exception("im_file must be list/tuple")
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_info = DeepLabv3p._preprocess(
+            img_file_list, transforms, self.model_type,
+            self.__class__.__name__, thread_num)
+
+        result = self.exe.run(self.test_prog,
+                              feed={'image': im},
+                              fetch_list=list(self.test_outputs.values()),
+                              use_program_cache=True)
+
+        preds = DeepLabv3p._postprocess(result, im_info)
+        return preds
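
Each returned dict maps back to the original image size: 'padding' records are undone by cropping and 'resize' records by resampling, as _postprocess above shows. A sketch of the result layout (model directory and image path are hypothetical):

    import paddlex as pdx

    model = pdx.load_model('output/meter_seg/best_model')
    pred = model.predict('meter.jpg')
    label_map = pred['label_map']   # (H, W) uint8, pixel values are class ids
    score_map = pred['score_map']   # (H, W, num_classes) per-class scores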

+ 138 - 39
paddlex/cv/models/faster_rcnn.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -16,11 +16,14 @@ from __future__ import absolute_import
 import math
 import tqdm
 import numpy as np
+from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
 import paddlex.utils.logging as logging
 import paddlex
 import os.path as osp
 import copy
+from paddlex.cv.transforms import arrange_transforms
+from paddlex.cv.datasets import GenerateMiniBatch
 from .base import BaseAPI
 from collections import OrderedDict
 from .utils.detection_eval import eval_results, bbox2out
@@ -106,6 +109,24 @@ class FasterRCNN(BaseAPI):
                 freeze_at=2)
         return backbone
 
+    def _generate_mini_batch(self, batch_data):
+        if len(batch_data) == 1:
+            return batch_data
+        width = [data[0].shape[2] for data in batch_data]
+        height = [data[0].shape[1] for data in batch_data]
+        if len(set(width)) == 1 and len(set(height)) == 1:
+            return batch_data
+        max_shape = np.array([data[0].shape for data in batch_data]).max(
+            axis=0)
+        padding_batch = []
+        for data in batch_data:
+            im_c, im_h, im_w = data[0].shape[:]
+            padding_im = np.zeros(
+                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
+            padding_im[:, :im_h, :im_w] = data[0]
+            padding_batch.append((padding_im, ) + data[1:])
+        return padding_batch
+
     def build_net(self, mode='train'):
         train_pre_nms_top_n = 2000 if self.with_fpn else 12000
         test_pre_nms_top_n = 1000 if self.with_fpn else 6000
@@ -291,7 +312,8 @@ class FasterRCNN(BaseAPI):
                 eval_details is a dict whose key 'bbox' holds the list of predictions, each consisting of image id,
                 predicted box class id, box coordinates and box score; key 'gt' holds the ground-truth annotations.
         """
-        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         if metric is None:
             if hasattr(self, 'metric') and self.metric is not None:
                 metric = self.metric
@@ -310,12 +332,14 @@ class FasterRCNN(BaseAPI):
             logging.warning(
                 "Faster RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
             )
-        dataset = eval_dataset.generator(batch_size=batch_size, drop_last=False)
+        dataset = eval_dataset.generator(
+            batch_size=batch_size, drop_last=False)
 
         total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
         results = list()
-        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, total_steps))
+        logging.info(
+            "Start to evaluating(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(enumerate(dataset()), total=total_steps):
             images = np.array([d[0] for d in data]).astype('float32')
             im_infos = np.array([d[1] for d in data]).astype('float32')
@@ -364,11 +388,49 @@ class FasterRCNN(BaseAPI):
             return metrics, eval_details
         return metrics
 
+    @staticmethod
+    def _preprocess(images, transforms, model_type, class_name, thread_num=1):
+        arrange_transforms(
+            model_type=model_type,
+            class_name=class_name,
+            transforms=transforms,
+            mode='test')
+        pool = ThreadPool(thread_num)
+        batch_data = pool.map(transforms, images)
+        pool.close()
+        pool.join()
+        padding_batch = GenerateMiniBatch(batch_data)
+        im = np.array([data[0] for data in padding_batch])
+        im_resize_info = np.array([data[1] for data in padding_batch])
+        im_shape = np.array([data[2] for data in padding_batch])
+
+        return im, im_resize_info, im_shape
+
+    @staticmethod
+    def _postprocess(results, test_outputs_keys, batch_size, num_classes,
+                     labels):
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [])
+        clsid2catid = dict({i: i for i in range(num_classes)})
+        xywh_results = bbox2out([res], clsid2catid)
+        preds = [[] for i in range(batch_size)]
+        for xywh_res in xywh_results:
+            image_id = xywh_res['image_id']
+            del xywh_res['image_id']
+            xywh_res['category'] = labels[xywh_res['category_id']]
+            preds[image_id].append(xywh_res)
+
+        return preds
+
     def predict(self, img_file, transforms=None):
         """预测。
 
         Args:
-            img_file (str): 预测图像路径。
+            img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组
             transforms (paddlex.det.transforms): 数据预处理操作。
 
         Returns:
@@ -378,35 +440,72 @@ class FasterRCNN(BaseAPI):
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")
-        if transforms is not None:
-            self.arrange_transforms(transforms=transforms, mode='test')
-            im, im_resize_info, im_shape = transforms(img_file)
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
         else:
-            self.arrange_transforms(
-                transforms=self.test_transforms, mode='test')
-            im, im_resize_info, im_shape = self.test_transforms(img_file)
-        im = np.expand_dims(im, axis=0)
-        im_resize_info = np.expand_dims(im_resize_info, axis=0)
-        im_shape = np.expand_dims(im_shape, axis=0)
-        outputs = self.exe.run(self.test_prog,
-                               feed={
-                                   'image': im,
-                                   'im_info': im_resize_info,
-                                   'im_shape': im_shape
-                               },
-                               fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False,
-                               use_program_cache=True)
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(self.test_outputs.keys()), outputs)
-        }
-        res['im_id'] = (np.array([[0]]).astype('int32'), [])
-        clsid2catid = dict({i: i for i in range(self.num_classes)})
-        xywh_results = bbox2out([res], clsid2catid)
-        results = list()
-        for xywh_res in xywh_results:
-            del xywh_res['image_id']
-            xywh_res['category'] = self.labels[xywh_res['category_id']]
-            results.append(xywh_res)
-        return results
+            raise Exception("img_file must be str/np.ndarray")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_resize_info, im_shape = FasterRCNN._preprocess(
+            images, transforms, self.model_type, self.__class__.__name__)
+
+        result = self.exe.run(self.test_prog,
+                              feed={
+                                  'image': im,
+                                  'im_info': im_resize_info,
+                                  'im_shape': im_shape
+                              },
+                              fetch_list=list(self.test_outputs.values()),
+                              return_numpy=False,
+                              use_program_cache=True)
+
+        preds = FasterRCNN._postprocess(result,
+                                        list(self.test_outputs.keys()),
+                                        len(images), self.num_classes,
+                                        self.labels)
+
+        return preds[0]
+
+    def batch_predict(self, img_file_list, transforms=None, thread_num=2):
+        """预测。
+
+        Args:
+            img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
+                也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
+            transforms (paddlex.det.transforms): 数据预处理操作。
+            thread_num (int): 并发执行各图像预处理时的线程数。
+
+        Returns:
+            list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
+              预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
+              预测框得分组成。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+
+        if not isinstance(img_file_list, (list, tuple)):
+            raise Exception("im_file must be list/tuple")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_resize_info, im_shape = FasterRCNN._preprocess(
+            img_file_list, transforms, self.model_type,
+            self.__class__.__name__, thread_num)
+
+        result = self.exe.run(self.test_prog,
+                              feed={
+                                  'image': im,
+                                  'im_info': im_resize_info,
+                                  'im_shape': im_shape
+                              },
+                              fetch_list=list(self.test_outputs.values()),
+                              return_numpy=False,
+                              use_program_cache=True)
+
+        preds = FasterRCNN._postprocess(result,
+                                        list(self.test_outputs.keys()),
+                                        len(img_file_list), self.num_classes,
+                                        self.labels)
+
+        return preds
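
In the batch path, _postprocess routes each decoded box to its image through the synthetic 'im_id' field, so preds[i] holds the boxes of img_file_list[i]. A sketch (model directory and image paths are hypothetical):

    import paddlex as pdx

    model = pdx.load_model('output/faster_rcnn/best_model')
    preds = model.batch_predict(['a.jpg', 'b.jpg'], thread_num=2)
    for i, boxes in enumerate(preds):
        for box in boxes:
            print(i, box['category'], box['bbox'], box['score'])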

+ 3 - 66
paddlex/cv/models/load_model.py

@@ -21,6 +21,7 @@ import paddle.fluid as fluid
 from paddle.fluid.framework import Parameter
 import paddlex
 import paddlex.utils.logging as logging
+from paddlex.cv.transforms import build_transforms, build_transforms_v1
 
 
 def load_model(model_dir, fixed_input_shape=None):
@@ -97,8 +98,8 @@ def load_model(model_dir, fixed_input_shape=None):
                 model.model_type, info['Transforms'], info['BatchTransforms'])
             model.eval_transforms = copy.deepcopy(model.test_transforms)
         else:
-            model.test_transforms = build_transforms(model.model_type,
-                                                     info['Transforms'], to_rgb)
+            model.test_transforms = build_transforms(
+                model.model_type, info['Transforms'], to_rgb)
             model.eval_transforms = copy.deepcopy(model.test_transforms)
 
     if '_Attributes' in info:
@@ -124,67 +125,3 @@ def fix_input_shape(info, fixed_input_shape=None):
             padding['Padding']['target_size'] = list(fixed_input_shape)
             info['Transforms'].append(resize)
             info['Transforms'].append(padding)
-
-
-def build_transforms(model_type, transforms_info, to_rgb=True):
-    if model_type == "classifier":
-        import paddlex.cv.transforms.cls_transforms as T
-    elif model_type == "detector":
-        import paddlex.cv.transforms.det_transforms as T
-    elif model_type == "segmenter":
-        import paddlex.cv.transforms.seg_transforms as T
-    transforms = list()
-    for op_info in transforms_info:
-        op_name = list(op_info.keys())[0]
-        op_attr = op_info[op_name]
-        if not hasattr(T, op_name):
-            raise Exception(
-                "There's no operator named '{}' in transforms of {}".format(
-                    op_name, model_type))
-        transforms.append(getattr(T, op_name)(**op_attr))
-    eval_transforms = T.Compose(transforms)
-    eval_transforms.to_rgb = to_rgb
-    return eval_transforms
-
-
-def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
-    """ 老版本模型加载,仅支持PaddleX前端导出的模型
-    """
-    logging.debug("Use build_transforms_v1 to reconstruct transforms")
-    if model_type == "classifier":
-        import paddlex.cv.transforms.cls_transforms as T
-    elif model_type == "detector":
-        import paddlex.cv.transforms.det_transforms as T
-    elif model_type == "segmenter":
-        import paddlex.cv.transforms.seg_transforms as T
-    transforms = list()
-    for op_info in transforms_info:
-        op_name = op_info[0]
-        op_attr = op_info[1]
-        if op_name == 'DecodeImage':
-            continue
-        if op_name == 'Permute':
-            continue
-        if op_name == 'ResizeByShort':
-            op_attr_new = dict()
-            if 'short_size' in op_attr:
-                op_attr_new['short_size'] = op_attr['short_size']
-            else:
-                op_attr_new['short_size'] = op_attr['target_size']
-            op_attr_new['max_size'] = op_attr.get('max_size', -1)
-            op_attr = op_attr_new
-        if op_name.startswith('Arrange'):
-            continue
-        if not hasattr(T, op_name):
-            raise Exception(
-                "There's no operator named '{}' in transforms of {}".format(
-                    op_name, model_type))
-        transforms.append(getattr(T, op_name)(**op_attr))
-    if model_type == "detector" and len(batch_transforms_info) > 0:
-        op_name = batch_transforms_info[0][0]
-        op_attr = batch_transforms_info[0][1]
-        assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
-        padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
-        transforms.append(padding)
-    eval_transforms = T.Compose(transforms)
-    return eval_transforms
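
build_transforms and build_transforms_v1 move to paddlex.cv.transforms unchanged; only the import site differs. A sketch of the new import path (the op list below is an illustrative assumption):

    from paddlex.cv.transforms import build_transforms

    test_transforms = build_transforms(
        'segmenter',
        [{'Resize': {'target_size': 512}}, {'Normalize': {}}],
        to_rgb=True)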

+ 106 - 45
paddlex/cv/models/mask_rcnn.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -16,6 +16,8 @@ from __future__ import absolute_import
 import math
 import tqdm
 import numpy as np
+from multiprocessing.pool import ThreadPool
+from paddlex.cv.transforms import arrange_transforms
 import paddle.fluid as fluid
 import paddlex.utils.logging as logging
 import paddlex
@@ -253,7 +254,8 @@ class MaskRCNN(FasterRCNN):
                 预测框坐标、预测框得分;'mask',对应元素预测区域结果列表,每个预测结果由图像id、
                 预测区域类别id、预测区域坐标、预测区域得分;’gt‘:真实标注框和标注区域相关信息。
         """
-        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         if metric is None:
             if hasattr(self, 'metric') and self.metric is not None:
                 metric = self.metric
@@ -274,8 +276,9 @@ class MaskRCNN(FasterRCNN):
 
         total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
         results = list()
-        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, total_steps))
+        logging.info(
+            "Start to evaluating(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data]).astype('float32')
@@ -317,7 +320,8 @@ class MaskRCNN(FasterRCNN):
                     zip(['bbox_map', 'segm_map'],
                         [ap_stats[0][1], ap_stats[1][1]]))
             else:
-                metrics = OrderedDict(zip(['bbox_map', 'segm_map'], [0.0, 0.0]))
+                metrics = OrderedDict(
+                    zip(['bbox_map', 'segm_map'], [0.0, 0.0]))
         elif metric == 'COCO':
             if isinstance(ap_stats[0], np.ndarray) and isinstance(ap_stats[1],
                                                                   np.ndarray):
@@ -331,55 +335,112 @@ class MaskRCNN(FasterRCNN):
             return metrics, eval_details
         return metrics
 
+    @staticmethod
+    def _postprocess(results, im_shape, test_outputs_keys, batch_size,
+                     num_classes, mask_head_resolution, labels):
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [])
+        res['im_shape'] = (np.array(im_shape), [])
+        clsid2catid = dict({i: i for i in range(num_classes)})
+        xywh_results = bbox2out([res], clsid2catid)
+        segm_results = mask2out([res], clsid2catid, mask_head_resolution)
+        preds = [[] for i in range(batch_size)]
+        import pycocotools.mask as mask_util
+        for index, xywh_res in enumerate(xywh_results):
+            image_id = xywh_res['image_id']
+            del xywh_res['image_id']
+            xywh_res['mask'] = mask_util.decode(segm_results[index][
+                'segmentation'])
+            xywh_res['category'] = labels[xywh_res['category_id']]
+            preds[image_id].append(xywh_res)
+
+        return preds
+
     def predict(self, img_file, transforms=None):
         """预测。
 
         Args:
-            img_file (str): 预测图像路径。
+            img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组
             transforms (paddlex.det.transforms): 数据预处理操作。
 
         Returns:
-            dict: 预测结果列表,每个预测结果由预测框类别标签、预测框类别名称、
+            lict: 预测结果列表,每个预测结果由预测框类别标签、预测框类别名称、
                   预测框坐标(坐标格式为[xmin, ymin, w, h])、
                   原图大小的预测二值图(1表示预测框类别,0表示背景类)、
                   预测框得分组成。
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")
-        if transforms is not None:
-            self.arrange_transforms(transforms=transforms, mode='test')
-            im, im_resize_info, im_shape = transforms(img_file)
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
         else:
-            self.arrange_transforms(
-                transforms=self.test_transforms, mode='test')
-            im, im_resize_info, im_shape = self.test_transforms(img_file)
-        im = np.expand_dims(im, axis=0)
-        im_resize_info = np.expand_dims(im_resize_info, axis=0)
-        im_shape = np.expand_dims(im_shape, axis=0)
-        outputs = self.exe.run(self.test_prog,
-                               feed={
-                                   'image': im,
-                                   'im_info': im_resize_info,
-                                   'im_shape': im_shape
-                               },
-                               fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False,
-                               use_program_cache=True)
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(self.test_outputs.keys()), outputs)
-        }
-        res['im_id'] = (np.array([[0]]).astype('int32'), [])
-        res['im_shape'] = (np.array(im_shape), [])
-        clsid2catid = dict({i: i for i in range(self.num_classes)})
-        xywh_results = bbox2out([res], clsid2catid)
-        segm_results = mask2out([res], clsid2catid, self.mask_head_resolution)
-        results = list()
-        import pycocotools.mask as mask_util
-        for index, xywh_res in enumerate(xywh_results):
-            del xywh_res['image_id']
-            xywh_res['mask'] = mask_util.decode(segm_results[index][
-                'segmentation'])
-            xywh_res['category'] = self.labels[xywh_res['category_id']]
-            results.append(xywh_res)
-        return results
+            raise Exception("img_file must be str/np.ndarray")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_resize_info, im_shape = FasterRCNN._preprocess(
+            images, transforms, self.model_type, self.__class__.__name__)
+
+        result = self.exe.run(self.test_prog,
+                              feed={
+                                  'image': im,
+                                  'im_info': im_resize_info,
+                                  'im_shape': im_shape
+                              },
+                              fetch_list=list(self.test_outputs.values()),
+                              return_numpy=False,
+                              use_program_cache=True)
+
+        preds = MaskRCNN._postprocess(result, im_shape,
+                                      list(self.test_outputs.keys()),
+                                      len(images), self.num_classes,
+                                      self.mask_head_resolution, self.labels)
+
+        return preds[0]
+
+    def batch_predict(self, img_file_list, transforms=None, thread_num=2):
+        """预测。
+
+        Args:
+            img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
+                也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
+            transforms (paddlex.det.transforms): 数据预处理操作。
+            thread_num (int): 并发执行各图像预处理时的线程数。
+        Returns:
+            dict: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、预测框类别名称、
+                  预测框坐标(坐标格式为[xmin, ymin, w, h])、
+                  原图大小的预测二值图(1表示预测框类别,0表示背景类)、
+                  预测框得分组成。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+
+        if not isinstance(img_file_list, (list, tuple)):
+            raise Exception("im_file must be list/tuple")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_resize_info, im_shape = FasterRCNN._preprocess(
+            img_file_list, transforms, self.model_type,
+            self.__class__.__name__, thread_num)
+
+        result = self.exe.run(self.test_prog,
+                              feed={
+                                  'image': im,
+                                  'im_info': im_resize_info,
+                                  'im_shape': im_shape
+                              },
+                              fetch_list=list(self.test_outputs.values()),
+                              return_numpy=False,
+                              use_program_cache=True)
+
+        preds = MaskRCNN._postprocess(result, im_shape,
+                                      list(self.test_outputs.keys()),
+                                      len(img_file_list), self.num_classes,
+                                      self.mask_head_resolution, self.labels)
+
+        return preds
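
Compared with FasterRCNN, each MaskRCNN prediction additionally carries a decoded binary mask at the original image size. A sketch (model directory and image path are hypothetical):

    import paddlex as pdx

    model = pdx.load_model('output/mask_rcnn/best_model')
    for det in model.predict('meter.jpg'):
        print(det['category'], det['bbox'], det['score'], det['mask'].shape)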

+ 2 - 0
paddlex/cv/models/utils/pretrain_weights.py

@@ -65,6 +65,8 @@ image_pretrain = {
     'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
     'HRNet_W40':
     'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
+    'HRNet_W44':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W44_C_pretrained.tar',
     'HRNet_W48':
     'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
     'HRNet_W60':

+ 110 - 35
paddlex/cv/models/yolo_v3.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -17,13 +17,16 @@ import math
 import tqdm
 import os.path as osp
 import numpy as np
+from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
 import paddlex.utils.logging as logging
 import paddlex
+import copy
+from paddlex.cv.transforms import arrange_transforms
+from paddlex.cv.datasets import GenerateMiniBatch
 from .base import BaseAPI
 from collections import OrderedDict
 from .utils.detection_eval import eval_results, bbox2out
-import copy
 
 
 class YOLOv3(BaseAPI):
@@ -286,7 +289,8 @@ class YOLOv3(BaseAPI):
                 eval_details is a dict whose key 'bbox' holds the list of predictions, each consisting of image id,
                 predicted box class id, box coordinates and box score; key 'gt' holds the ground-truth annotations.
         """
-        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
         if metric is None:
             if hasattr(self, 'metric') and self.metric is not None:
                 metric = self.metric
@@ -306,8 +310,9 @@ class YOLOv3(BaseAPI):
 
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
-        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, total_steps))
+        logging.info(
+            "Start to evaluating(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data])
@@ -343,11 +348,50 @@ class YOLOv3(BaseAPI):
             return evaluate_metrics, eval_details
         return evaluate_metrics
 
+    @staticmethod
+    def _preprocess(images, transforms, model_type, class_name, thread_num=1):
+        arrange_transforms(
+            model_type=model_type,
+            class_name=class_name,
+            transforms=transforms,
+            mode='test')
+        pool = ThreadPool(thread_num)
+        batch_data = pool.map(transforms, images)
+        pool.close()
+        pool.join()
+        padding_batch = GenerateMiniBatch(batch_data)
+        im = np.array(
+            [data[0] for data in padding_batch],
+            dtype=padding_batch[0][0].dtype)
+        im_size = np.array([data[1] for data in padding_batch], dtype=np.int32)
+
+        return im, im_size
+
+    @staticmethod
+    def _postprocess(results, test_outputs_keys, batch_size, num_classes,
+                     labels):
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [[]])
+        clsid2catid = dict({i: i for i in range(num_classes)})
+        xywh_results = bbox2out([res], clsid2catid)
+        preds = [[] for i in range(batch_size)]
+        for xywh_res in xywh_results:
+            image_id = xywh_res['image_id']
+            del xywh_res['image_id']
+            xywh_res['category'] = labels[xywh_res['category_id']]
+            preds[image_id].append(xywh_res)
+
+        return preds
+
     def predict(self, img_file, transforms=None):
         """预测。
 
         Args:
-            img_file (str): Path of the image to predict.
+            img_file (str|np.ndarray): Path of the image to predict, or a
+                decoded array in (H, W, C) layout, float32 type and BGR
+                channel order.
             transforms (paddlex.det.transforms): Data preprocessing operators.
 
         Returns:
@@ -357,31 +401,62 @@ class YOLOv3(BaseAPI):
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")
-        if transforms is not None:
-            self.arrange_transforms(transforms=transforms, mode='test')
-            im, im_size = transforms(img_file)
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
         else:
-            self.arrange_transforms(
-                transforms=self.test_transforms, mode='test')
-            im, im_size = self.test_transforms(img_file)
-        im = np.expand_dims(im, axis=0)
-        im_size = np.expand_dims(im_size, axis=0)
-        outputs = self.exe.run(self.test_prog,
-                               feed={'image': im,
-                                     'im_size': im_size},
-                               fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False,
-                               use_program_cache=True)
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(self.test_outputs.keys()), outputs)
-        }
-        res['im_id'] = (np.array([[0]]).astype('int32'), [])
-        clsid2catid = dict({i: i for i in range(self.num_classes)})
-        xywh_results = bbox2out([res], clsid2catid)
-        results = list()
-        for xywh_res in xywh_results:
-            del xywh_res['image_id']
-            xywh_res['category'] = self.labels[xywh_res['category_id']]
-            results.append(xywh_res)
-        return results
+            raise Exception("img_file must be str/np.ndarray")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_size = YOLOv3._preprocess(images, transforms, self.model_type,
+                                         self.__class__.__name__)
+
+        result = self.exe.run(self.test_prog,
+                              feed={'image': im,
+                                    'im_size': im_size},
+                              fetch_list=list(self.test_outputs.values()),
+                              return_numpy=False,
+                              use_program_cache=True)
+
+        preds = YOLOv3._postprocess(result,
+                                    list(self.test_outputs.keys()),
+                                    len(images), self.num_classes, self.labels)
+        return preds[0]
+
+    def batch_predict(self, img_file_list, transforms=None, thread_num=2):
+        """预测。
+
+        Args:
+            img_file_list (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径,也可以是解码后的排列格式为(H,W,C)
+                且类型为float32且为BGR格式的数组。
+            transforms (paddlex.det.transforms): 数据预处理操作。
+            thread_num (int): 并发执行各图像预处理时的线程数。
+        Returns:
+            list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
+              预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
+              预测框得分组成。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+
+        if not isinstance(img_file_list, (list, tuple)):
+            raise Exception("im_file must be list/tuple")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_size = YOLOv3._preprocess(img_file_list, transforms,
+                                         self.model_type,
+                                         self.__class__.__name__, thread_num)
+
+        result = self.exe.run(self.test_prog,
+                              feed={'image': im,
+                                    'im_size': im_size},
+                              fetch_list=list(self.test_outputs.values()),
+                              return_numpy=False,
+                              use_program_cache=True)
+
+        preds = YOLOv3._postprocess(result,
+                                    list(self.test_outputs.keys()),
+                                    len(img_file_list), self.num_classes,
+                                    self.labels)
+        return preds

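A minimal sketch of the reworked YOLOv3 inference paths; the model directory and image files are hypothetical:

    import cv2
    import paddlex as pdx

    model = pdx.load_model('./output/yolov3/best_model')  # hypothetical path

    # Single image: a path or a float32 BGR (H, W, C) array both work now.
    im = cv2.imread('demo.jpg').astype('float32')
    for det in model.predict(im):
        print(det['category'], det['bbox'], det['score'])

    # Batch: transforms run in a ThreadPool(thread_num), then one forward pass.
    batch_preds = model.batch_predict(['demo1.jpg', 'demo2.jpg'], thread_num=2)
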
+ 81 - 0
paddlex/cv/transforms/__init__.py

@@ -15,3 +15,84 @@
 from . import cls_transforms
 from . import det_transforms
 from . import seg_transforms
+import paddlex.utils.logging as logging
+
+
+def build_transforms(model_type, transforms_info, to_rgb=True):
+    if model_type == "classifier":
+        from . import cls_transforms as T
+    elif model_type == "detector":
+        from . import det_transforms as T
+    elif model_type == "segmenter":
+        from . import seg_transforms as T
+    transforms = list()
+    for op_info in transforms_info:
+        op_name = list(op_info.keys())[0]
+        op_attr = op_info[op_name]
+        if not hasattr(T, op_name):
+            raise Exception(
+                "There's no operator named '{}' in transforms of {}".format(
+                    op_name, model_type))
+        transforms.append(getattr(T, op_name)(**op_attr))
+    eval_transforms = T.Compose(transforms)
+    eval_transforms.to_rgb = to_rgb
+    return eval_transforms
+
+
+def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
+    """ 老版本模型加载,仅支持PaddleX前端导出的模型
+    """
+    logging.debug("Use build_transforms_v1 to reconstruct transforms")
+    if model_type == "classifier":
+        from . import cls_transforms as T
+    elif model_type == "detector":
+        from . import det_transforms as T
+    elif model_type == "segmenter":
+        from . import seg_transforms as T
+    transforms = list()
+    for op_info in transforms_info:
+        op_name = op_info[0]
+        op_attr = op_info[1]
+        if op_name == 'DecodeImage':
+            continue
+        if op_name == 'Permute':
+            continue
+        if op_name == 'ResizeByShort':
+            op_attr_new = dict()
+            if 'short_size' in op_attr:
+                op_attr_new['short_size'] = op_attr['short_size']
+            else:
+                op_attr_new['short_size'] = op_attr['target_size']
+            op_attr_new['max_size'] = op_attr.get('max_size', -1)
+            op_attr = op_attr_new
+        if op_name.startswith('Arrange'):
+            continue
+        if not hasattr(T, op_name):
+            raise Exception(
+                "There's no operator named '{}' in transforms of {}".format(
+                    op_name, model_type))
+        transforms.append(getattr(T, op_name)(**op_attr))
+    if model_type == "detector" and len(batch_transforms_info) > 0:
+        op_name = batch_transforms_info[0][0]
+        op_attr = batch_transforms_info[0][1]
+        assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
+        padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
+        transforms.append(padding)
+    eval_transforms = T.Compose(transforms)
+    return eval_transforms
+
+
+def arrange_transforms(model_type, class_name, transforms, mode='train'):
+    # Append the Arrange operation to transforms
+    if model_type == 'classifier':
+        arrange_transform = cls_transforms.ArrangeClassifier
+    elif model_type == 'segmenter':
+        arrange_transform = seg_transforms.ArrangeSegmenter
+    elif model_type == 'detector':
+        arrange_name = 'Arrange{}'.format(class_name)
+        arrange_transform = getattr(det_transforms, arrange_name)
+    else:
+        raise Exception("Unrecognized model type: {}".format(self.model_type))
+    if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
+        transforms.transforms[-1] = arrange_transform(mode=mode)
+    else:
+        transforms.transforms.append(arrange_transform(mode=mode))

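build_transforms consumes the list-of-single-key-dicts layout that PaddleX serializes into model.yml; a minimal sketch with illustrative operator attributes:

    from paddlex.cv.transforms import build_transforms

    # Each entry is {OperatorName: {attr: value, ...}}, as stored in model.yml.
    transforms_info = [
        {'ResizeByShort': {'short_size': 800, 'max_size': 1333}},
        {'Normalize': {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}},
    ]
    eval_transforms = build_transforms('detector', transforms_info, to_rgb=True)
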
+ 11 - 5
paddlex/cv/transforms/det_transforms.py

@@ -222,13 +222,15 @@ class ResizeByShort(DetTransform):
         im_short_size = min(im.shape[0], im.shape[1])
         im_long_size = max(im.shape[0], im.shape[1])
         scale = float(self.short_size) / im_short_size
-        if self.max_size > 0 and np.round(scale * im_long_size) > self.max_size:
+        if self.max_size > 0 and np.round(scale *
+                                          im_long_size) > self.max_size:
             scale = float(self.max_size) / float(im_long_size)
         resized_width = int(round(im.shape[1] * scale))
         resized_height = int(round(im.shape[0] * scale))
         im_resize_info = [resized_height, resized_width, scale]
         im = cv2.resize(
-            im, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
+            im, (resized_width, resized_height),
+            interpolation=cv2.INTER_LINEAR)
         im_info['im_resize_info'] = np.array(im_resize_info).astype(np.float32)
         if label_info is None:
             return (im, im_info)
@@ -268,7 +270,8 @@ class Padding(DetTransform):
                 if not isinstance(target_size, tuple) and not isinstance(
                         target_size, list):
                     raise TypeError(
-                        "Padding: Type of target_size must in (int|list|tuple).")
+                        "Padding: Type of target_size must in (int|list|tuple)."
+                    )
                 elif len(target_size) != 2:
                     raise ValueError(
                         "Padding: Length of target_size must equal 2.")
@@ -453,7 +456,8 @@ class RandomHorizontalFlip(DetTransform):
             ValueError: Data lengths do not match.
         """
         if not isinstance(im, np.ndarray):
-            raise TypeError("RandomHorizontalFlip: image is not a numpy array.")
+            raise TypeError(
+                "RandomHorizontalFlip: image is not a numpy array.")
         if len(im.shape) != 3:
             raise ValueError(
                 "RandomHorizontalFlip: image is not 3-dimensional.")
@@ -783,7 +787,9 @@ class RandomExpand(DetTransform):
         fill_value (list): Initial fill values (0-255) of the expanded image. Defaults to [123.675, 116.28, 103.53].
     """
 
-    def __init__(self, ratio=4., prob=0.5,
+    def __init__(self,
+                 ratio=4.,
+                 prob=0.5,
                  fill_value=[123.675, 116.28, 103.53]):
         super(RandomExpand, self).__init__()
         assert ratio > 1.01, "expand ratio must be larger than 1.01"

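The ResizeByShort reflow above keeps the rule unchanged: scale so the short side reaches short_size, then clamp the scale if the long side would exceed max_size. A worked sketch of the arithmetic with illustrative values:

    import numpy as np

    short_size, max_size = 800, 1333
    h, w = 600, 1600                                # input image size
    scale = float(short_size) / min(h, w)           # 800 / 600 ~= 1.333
    if max_size > 0 and np.round(scale * max(h, w)) > max_size:
        scale = float(max_size) / float(max(h, w))  # 1333 / 1600 ~= 0.833
    resized_w = int(round(w * scale))               # 1333
    resized_h = int(round(h * scale))               # 500
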
+ 75 - 118
paddlex/deploy.py

@@ -18,6 +18,8 @@ import numpy as np
 import yaml
 import paddlex
 import paddle.fluid as fluid
+from paddlex.cv.transforms import build_transforms
+from paddlex.cv.models import (BaseClassifier, YOLOv3, FasterRCNN, MaskRCNN,
+                               DeepLabv3p)
 
 
 class Predictor:
@@ -68,8 +70,8 @@ class Predictor:
             to_rgb = True
         else:
             to_rgb = False
-        self.transforms = self.build_transforms(self.info['Transforms'],
-                                                to_rgb)
+        self.transforms = build_transforms(self.model_type,
+                                           self.info['Transforms'], to_rgb)
         self.predictor = self.create_predictor(
             use_gpu, gpu_id, use_mkl, use_trt, use_glog, memory_optimize)
 
@@ -105,77 +107,73 @@ class Predictor:
         predictor = fluid.core.create_paddle_predictor(config)
         return predictor
 
-    def build_transforms(self, transforms_info, to_rgb=True):
-        if self.model_type == "classifier":
-            from paddlex.cls import transforms
-        elif self.model_type == "detector":
-            from paddlex.det import transforms
-        elif self.model_type == "segmenter":
-            from paddlex.seg import transforms
-        op_list = list()
-        for op_info in transforms_info:
-            op_name = list(op_info.keys())[0]
-            op_attr = op_info[op_name]
-            if not hasattr(transforms, op_name):
-                raise Exception(
-                    "There's no operator named '{}' in transforms of {}".
-                    format(op_name, self.model_type))
-            op_list.append(getattr(transforms, op_name)(**op_attr))
-        eval_transforms = transforms.Compose(op_list)
-        if hasattr(eval_transforms, 'to_rgb'):
-            eval_transforms.to_rgb = to_rgb
-        self.arrange_transforms(eval_transforms)
-        return eval_transforms
-
-    def arrange_transforms(self, transforms):
-        if self.model_type == 'classifier':
-            arrange_transform = paddlex.cls.transforms.ArrangeClassifier
-        elif self.model_type == 'segmenter':
-            arrange_transform = paddlex.seg.transforms.ArrangeSegmenter
-        elif self.model_type == 'detector':
-            arrange_name = 'Arrange{}'.format(self.model_name)
-            arrange_transform = getattr(paddlex.det.transforms, arrange_name)
-        else:
-            raise Exception("Unrecognized model type: {}".format(
-                self.model_type))
-        if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
-            transforms.transforms[-1] = arrange_transform(mode='test')
-        else:
-            transforms.transforms.append(arrange_transform(mode='test'))
-
-    def preprocess(self, image):
+    def preprocess(self, image, thread_num=1):
         """ 对图像做预处理
 
             Args:
-                image(str|np.ndarray): Image path or an np.ndarray; the latter must be in BGR format.
+                image (str|np.ndarray|list|tuple): Image path, or a decoded
+                    array in (H, W, C) layout, float32 type and BGR channel
+                    order; or a list (tuple) of such paths/arrays whose
+                    images are preprocessed together.
+                thread_num (int): Number of threads used to preprocess the
+                    images concurrently.
         """
         res = dict()
         if self.model_type == "classifier":
-            im, = self.transforms(image)
-            im = np.expand_dims(im, axis=0).copy()
+            im = BaseClassifier._preprocess(
+                image,
+                self.transforms,
+                self.model_type,
+                self.model_name,
+                thread_num=thread_num)
             res['image'] = im
         elif self.model_type == "detector":
             if self.model_name == "YOLOv3":
-                im, im_shape = self.transforms(image)
-                im = np.expand_dims(im, axis=0).copy()
-                im_shape = np.expand_dims(im_shape, axis=0).copy()
+                im, im_size = YOLOv3._preprocess(
+                    image,
+                    self.transforms,
+                    self.model_type,
+                    self.model_name,
+                    thread_num=thread_num)
                 res['image'] = im
-                res['im_size'] = im_shape
+                res['im_size'] = im_size
             if self.model_name.count('RCNN') > 0:
-                im, im_resize_info, im_shape = self.transforms(image)
-                im = np.expand_dims(im, axis=0).copy()
-                im_resize_info = np.expand_dims(im_resize_info, axis=0).copy()
-                im_shape = np.expand_dims(im_shape, axis=0).copy()
+                im, im_resize_info, im_shape = FasterRCNN._preprocess(
+                    image,
+                    self.transforms,
+                    self.model_type,
+                    self.model_name,
+                    thread_num=thread_num)
                 res['image'] = im
                 res['im_info'] = im_resize_info
                 res['im_shape'] = im_shape
         elif self.model_type == "segmenter":
-            im, im_info = self.transforms(image)
-            im = np.expand_dims(im, axis=0).copy()
+            im, im_info = DeepLabv3p._preprocess(
+                image,
+                self.transforms,
+                self.model_type,
+                self.model_name,
+                thread_num=thread_num)
             res['image'] = im
             res['im_info'] = im_info
         return res
 
+    def postprocess(self, results, topk=1, batch_size=1, im_shape=None):
+        if self.model_type == "classifier":
+            true_topk = min(self.num_classes, topk)
+            preds = BaseClassifier._postprocess(results, true_topk,
+                                                self.labels)
+        elif self.model_type == "detector":
+            if self.model_name == "YOLOv3":
+                preds = YOLOv3._postprocess(results, ['bbox'], batch_size,
+                                            self.num_classes, self.labels)
+            elif self.model_name == "FasterRCNN":
+                preds = FasterRCNN._postprocess(results, ['bbox'], batch_size,
+                                                self.num_classes, self.labels)
+            elif self.model_name == "MaskRCNN":
+                preds = MaskRCNN._postprocess(
+                    results, im_shape, ['bbox', 'mask'], batch_size,
+                    self.num_classes, self.mask_head_resolution, self.labels)
+
+        return preds
+
     def raw_predict(self, inputs):
         """ 接受预处理过后的数据进行预测
 
@@ -196,79 +194,38 @@ class Predictor:
             output_results.append(output_tensor.copy_to_cpu())
         return output_results
 
-    def classifier_postprocess(self, preds, topk=1):
-        """ 对分类模型的预测结果做后处理
-        """
-        true_topk = min(self.num_classes, topk)
-        pred_label = np.argsort(preds[0][0])[::-1][:true_topk]
-        result = [{
-            'category_id': l,
-            'category': self.labels[l],
-            'score': preds[0][0, l],
-        } for l in pred_label]
-        return result
+    def predict(self, image, topk=1):
+        """ 图片预测
 
-    def segmenter_postprocess(self, preds, preprocessed_inputs):
-        """ 对语义分割结果做后处理
+            Args:
+                image (str|np.ndarray): Image path, or a decoded array in
+                    (H, W, C) layout, float32 type and BGR channel order.
+                topk (int): Used for classification models; the top-k
+                    results are returned.
         """
-        label_map = np.squeeze(preds[0]).astype('uint8')
-        score_map = np.squeeze(preds[1])
-        score_map = np.transpose(score_map, (1, 2, 0))
-        im_info = preprocessed_inputs['im_info']
-        for info in im_info[::-1]:
-            if info[0] == 'resize':
-                w, h = info[1][1], info[1][0]
-                label_map = cv2.resize(label_map, (w, h), cv2.INTER_NEAREST)
-                score_map = cv2.resize(score_map, (w, h), cv2.INTER_LINEAR)
-            elif info[0] == 'padding':
-                w, h = info[1][1], info[1][0]
-                label_map = label_map[0:h, 0:w]
-                score_map = score_map[0:h, 0:w, :]
-            else:
-                raise Exception("Unexpected info '{}' in im_info".format(info[
-                    0]))
-        return {'label_map': label_map, 'score_map': score_map}
+        preprocessed_input = self.preprocess([image])
+        model_pred = self.raw_predict(preprocessed_input)
+        im_shape = preprocessed_input.get('im_shape', None)
+        results = self.postprocess(
+            model_pred, topk=topk, batch_size=1, im_shape=im_shape)
 
-    def detector_postprocess(self, preds, preprocessed_inputs):
-        """ 对目标检测和实例分割结果做后处理
-        """
-        bboxes = {'bbox': (np.array(preds[0]), [[len(preds[0])]])}
-        bboxes['im_id'] = (np.array([[0]]).astype('int32'), [])
-        clsid2catid = dict({i: i for i in range(self.num_classes)})
-        xywh_results = paddlex.cv.models.utils.detection_eval.bbox2out(
-            [bboxes], clsid2catid)
-        results = list()
-        for xywh_res in xywh_results:
-            del xywh_res['image_id']
-            xywh_res['category'] = self.labels[xywh_res['category_id']]
-            results.append(xywh_res)
-        if len(preds) > 1:
-            im_shape = preprocessed_inputs['im_shape']
-            bboxes['im_shape'] = (im_shape, [])
-            bboxes['mask'] = (np.array(preds[1]), [[len(preds[1])]])
-            segm_results = paddlex.cv.models.utils.detection_eval.mask2out(
-                [bboxes], clsid2catid, self.mask_head_resolution)
-            import pycocotools.mask as mask_util
-            for i in range(len(results)):
-                results[i]['mask'] = mask_util.decode(segm_results[i][
-                    'segmentation'])
-        return results
+        return results[0]
 
-    def predict(self, image, topk=1, threshold=0.5):
+    def batch_predict(self, image_list, topk=1, thread_num=2):
         """ 图片预测
 
             Args:
-                image(str|np.ndarray): Image path or an np.ndarray in BGR format.
+                image_list (list|tuple): Images in the list (or tuple) are
+                    predicted together. Each element can be an image path or
+                    a decoded array in (H, W, C) layout, float32 type and
+                    BGR channel order.
+                thread_num (int): Number of threads used to preprocess the
+                    images concurrently.
                topk (int): Used for classification models; the top-k results are returned.
         """
-        preprocessed_input = self.preprocess(image)
+        preprocessed_input = self.preprocess(
+            image_list, thread_num=thread_num)
         model_pred = self.raw_predict(preprocessed_input)
+        im_shape = preprocessed_input.get('im_shape', None)
+        results = self.postprocess(
+            model_pred,
+            topk=topk,
+            batch_size=len(image_list),
+            im_shape=im_shape)
 
-        if self.model_type == "classifier":
-            results = self.classifier_postprocess(model_pred, topk)
-        elif self.model_type == "detector":
-            results = self.detector_postprocess(model_pred, preprocessed_input)
-        elif self.model_type == "segmenter":
-            results = self.segmenter_postprocess(model_pred,
-                                                 preprocessed_input)
         return results
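
With pre- and postprocessing delegated to the model classes, the deployment Predictor exposes the same two entry points as the training-side models; a minimal sketch (the exported inference-model directory is hypothetical):

    from paddlex.deploy import Predictor

    predictor = Predictor('./inference_model', use_gpu=True)
    # Single image.
    result = predictor.predict('demo.jpg')
    # Several images; preprocessing fans out over thread_num threads.
    results = predictor.batch_predict(['demo1.jpg', 'demo2.jpg'], thread_num=2)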