Răsfoiți Sursa

use google style

jack 5 ani în urmă
părinte
comite
01e54a1f42

+ 40 - 31
deploy/cpp/demo/classifier.cpp

@@ -13,18 +13,18 @@
 // limitations under the License.
 
 #include <glog/logging.h>
+#include <omp.h>
 
 #include <algorithm>
-#include <chrono>
+#include <chrono>  // NOLINT
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
 #include <utility>
-#include <omp.h>
 #include "include/paddlex/paddlex.h"
 
-using namespace std::chrono;
+using namespace std::chrono;  // NOLINT
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
@@ -34,7 +34,9 @@ DEFINE_string(key, "", "key of encryption");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
 DEFINE_int32(batch_size, 1, "Batch size of infering");
-DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads");
+DEFINE_int32(thread_num,
+             omp_get_num_procs(),
+             "Number of preprocessing threads");
 
 int main(int argc, char** argv) {
   // Parsing command-line
@@ -51,7 +53,12 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
+  model.Init(FLAGS_model_dir,
+             FLAGS_use_gpu,
+             FLAGS_use_trt,
+             FLAGS_gpu_id,
+             FLAGS_key,
+             FLAGS_batch_size);
 
   // 进行预测
   double total_running_time_s = 0.0;
@@ -70,32 +77,38 @@ int main(int argc, char** argv) {
       image_paths.push_back(image_path);
     }
     imgs = image_paths.size();
-    for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
+    for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
       auto start = system_clock::now();
-        // 读图像
-      int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);      
+      // 读图像
+      int im_vec_size =
+          std::min(static_cat<int>(image_paths.size()), i + FLAGS_batch_size);
       std::vector<cv::Mat> im_vec(im_vec_size - i);
-      std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult());
+      std::vector<PaddleX::ClsResult> results(im_vec_size - i,
+                                              PaddleX::ClsResult());
       int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
       #pragma omp parallel for num_threads(thread_num)
-      for(int j = i; j < im_vec_size; ++j){
+      for (int j = i; j < im_vec_size; ++j) {
         im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
       auto imread_end = system_clock::now();
-      model.predict(im_vec, results, thread_num);
+      model.predict(im_vec, &results, thread_num);
 
       auto imread_duration = duration_cast<microseconds>(imread_end - start);
-      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+      total_imread_time_s += static_cast<double>(imread_duration.count()) *
+                             microseconds::period::num /
+                             microseconds::period::den;
 
       auto end = system_clock::now();
       auto duration = duration_cast<microseconds>(end - start);
-      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
-      for(int j = i; j < im_vec_size; ++j) {
-            std::cout << "Path:" << image_paths[j]
-                      << ", predict label: " << results[j - i].category
-                      << ", label_id:" << results[j - i].category_id
-                      << ", score: " << results[j - i].score << std::endl;
-      }	
+      total_running_time_s += static_cast<double>(duration.count()) *
+                              microseconds::period::num /
+                              microseconds::period::den;
+      for (int j = i; j < im_vec_size; ++j) {
+        std::cout << "Path:" << image_paths[j]
+                  << ", predict label: " << results[j - i].category
+                  << ", label_id:" << results[j - i].category_id
+                  << ", score: " << results[j - i].score << std::endl;
+      }
     }
   } else {
     auto start = system_clock::now();
@@ -104,21 +117,17 @@ int main(int argc, char** argv) {
     model.predict(im, &result);
     auto end = system_clock::now();
     auto duration = duration_cast<microseconds>(end - start);
-    total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
+    total_running_time_s += static_cast<double>(duration.count()) *
+                            microseconds::period::num /
+                            microseconds::period::den;
     std::cout << "Predict label: " << result.category
               << ", label_id:" << result.category_id
               << ", score: " << result.score << std::endl;
   }
-  std::cout << "Total running time: " 
-            << total_running_time_s
-            << " s, average running time: "
-            << total_running_time_s / imgs 
-            << " s/img, total read img time: " 
-            << total_imread_time_s
-            << " s, average read time: "
-            << total_imread_time_s / imgs  
-            << " s/img, batch_size = " 
-            << FLAGS_batch_size 
-            << std::endl;
+  std::cout << "Total running time: " << total_running_time_s
+            << " s, average running time: " << total_running_time_s / imgs
+            << " s/img, total read img time: " << total_imread_time_s
+            << " s, average read time: " << total_imread_time_s / imgs
+            << " s/img, batch_size = " << FLAGS_batch_size << std::endl;
   return 0;
 }

+ 48 - 39
deploy/cpp/demo/detector.cpp

@@ -13,20 +13,20 @@
 // limitations under the License.
 
 #include <glog/logging.h>
+#include <omp.h>
 
 #include <algorithm>
-#include <chrono>
+#include <chrono>  // NOLINT
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
 #include <utility>
-#include <omp.h>
 
 #include "include/paddlex/paddlex.h"
 #include "include/paddlex/visualize.h"
 
-using namespace std::chrono;
+using namespace std::chrono;  // NOLINT
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
@@ -37,13 +37,17 @@ DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
 DEFINE_string(save_dir, "output", "Path to save visualized image");
 DEFINE_int32(batch_size, 1, "Batch size of infering");
-DEFINE_double(threshold, 0.5, "The minimum scores of target boxes which are shown");
-DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads");
+DEFINE_double(threshold,
+              0.5,
+              "The minimum scores of target boxes which are shown");
+DEFINE_int32(thread_num,
+             omp_get_num_procs(),
+             "Number of preprocessing threads");
 
 int main(int argc, char** argv) {
   // 解析命令行参数
   google::ParseCommandLineFlags(&argc, &argv, true);
-  
+
   if (FLAGS_model_dir == "") {
     std::cerr << "--model_dir need to be defined" << std::endl;
     return -1;
@@ -55,7 +59,12 @@ int main(int argc, char** argv) {
   std::cout << "Thread num: " << FLAGS_thread_num << std::endl;
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
+  model.Init(FLAGS_model_dir,
+             FLAGS_use_gpu,
+             FLAGS_use_trt,
+             FLAGS_gpu_id,
+             FLAGS_key,
+             FLAGS_batch_size);
 
   double total_running_time_s = 0.0;
   double total_imread_time_s = 0.0;
@@ -75,41 +84,47 @@ int main(int argc, char** argv) {
       image_paths.push_back(image_path);
     }
     imgs = image_paths.size();
-    for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
+    for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
       auto start = system_clock::now();
-      int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
+      int im_vec_size =
+          std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
       std::vector<cv::Mat> im_vec(im_vec_size - i);
-      std::vector<PaddleX::DetResult> results(im_vec_size - i, PaddleX::DetResult());
+      std::vector<PaddleX::DetResult> results(im_vec_size - i,
+                                              PaddleX::DetResult());
       int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
       #pragma omp parallel for num_threads(thread_num)
-      for(int j = i; j < im_vec_size; ++j){
+      for (int j = i; j < im_vec_size; ++j) {
         im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
       auto imread_end = system_clock::now();
-      model.predict(im_vec, results, thread_num);
+      model.predict(im_vec, &results, thread_num);
       auto imread_duration = duration_cast<microseconds>(imread_end - start);
-      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+      total_imread_time_s += static_cast<double>(imread_duration.count()) *
+                             microseconds::period::num /
+                             microseconds::period::den;
       auto end = system_clock::now();
       auto duration = duration_cast<microseconds>(end - start);
-      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
-      //输出结果目标框
-      for(int j = 0; j < im_vec_size - i; ++j) {
-        for(int k = 0; k < results[j].boxes.size(); ++k) {
-          std::cout << "image file: " << image_paths[i + j] << ", ";// << std::endl;          
+      total_running_time_s += static_cast<double>(duration.count()) *
+                              microseconds::period::num /
+                              microseconds::period::den;
+      // 输出结果目标框
+      for (int j = 0; j < im_vec_size - i; ++j) {
+        for (int k = 0; k < results[j].boxes.size(); ++k) {
+          std::cout << "image file: " << image_paths[i + j] << ", ";
           std::cout << "predict label: " << results[j].boxes[k].category
                     << ", label_id:" << results[j].boxes[k].category_id
-                    << ", score: " << results[j].boxes[k].score << ", box(xmin, ymin, w, h):("
+                    << ", score: " << results[j].boxes[k].score
+                    << ", box(xmin, ymin, w, h):("
                     << results[j].boxes[k].coordinate[0] << ", "
                     << results[j].boxes[k].coordinate[1] << ", "
                     << results[j].boxes[k].coordinate[2] << ", "
                     << results[j].boxes[k].coordinate[3] << ")" << std::endl;
-          
         }
       }
       // 可视化
-      for(int j = 0; j < im_vec_size - i; ++j) {
-        cv::Mat vis_img =
-            PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap, FLAGS_threshold);
+      for (int j = 0; j < im_vec_size - i; ++j) {
+        cv::Mat vis_img = PaddleX::Visualize(
+            im_vec[j], results[j], model.labels, colormap, FLAGS_threshold);
         std::string save_path =
             PaddleX::generate_save_path(FLAGS_save_dir, image_paths[i + j]);
         cv::imwrite(save_path, vis_img);
@@ -121,12 +136,12 @@ int main(int argc, char** argv) {
     cv::Mat im = cv::imread(FLAGS_image, 1);
     model.predict(im, &result);
     for (int i = 0; i < result.boxes.size(); ++i) {
-      std::cout << "image file: " << FLAGS_image << std::endl;          
+      std::cout << "image file: " << FLAGS_image << std::endl;
       std::cout << ", predict label: " << result.boxes[i].category
                 << ", label_id:" << result.boxes[i].category_id
-                << ", score: " << result.boxes[i].score << ", box(xmin, ymin, w, h):("
-                << result.boxes[i].coordinate[0] << ", "
-                << result.boxes[i].coordinate[1] << ", "
+                << ", score: " << result.boxes[i].score
+                << ", box(xmin, ymin, w, h):(" << result.boxes[i].coordinate[0]
+                << ", " << result.boxes[i].coordinate[1] << ", "
                 << result.boxes[i].coordinate[2] << ", "
                 << result.boxes[i].coordinate[3] << ")" << std::endl;
     }
@@ -140,18 +155,12 @@ int main(int argc, char** argv) {
     result.clear();
     std::cout << "Visualized output saved as " << save_path << std::endl;
   }
-  
-  std::cout << "Total running time: " 
-            << total_running_time_s
-            << " s, average running time: "
-            << total_running_time_s / imgs
-            << " s/img, total read img time: " 
-            << total_imread_time_s
-            << " s, average read img time: "
-            << total_imread_time_s / imgs
-            << " s, batch_size = " 
-            << FLAGS_batch_size 
-            << std::endl;
+
+  std::cout << "Total running time: " << total_running_time_s
+            << " s, average running time: " << total_running_time_s / imgs
+            << " s/img, total read img time: " << total_imread_time_s
+            << " s, average read img time: " << total_imread_time_s / imgs
+            << " s, batch_size = " << FLAGS_batch_size << std::endl;
 
   return 0;
 }

+ 34 - 25
deploy/cpp/demo/segmenter.cpp

@@ -13,19 +13,19 @@
 // limitations under the License.
 
 #include <glog/logging.h>
+#include <omp.h>
 
 #include <algorithm>
-#include <chrono>
+#include <chrono>  // NOLINT
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
 #include <utility>
-#include <omp.h>
 #include "include/paddlex/paddlex.h"
 #include "include/paddlex/visualize.h"
 
-using namespace std::chrono;
+using namespace std::chrono;  // NOLINT
 
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
@@ -36,7 +36,9 @@ DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
 DEFINE_string(save_dir, "output", "Path to save visualized image");
 DEFINE_int32(batch_size, 1, "Batch size of infering");
-DEFINE_int32(thread_num, omp_get_num_procs(), "Number of preprocessing threads");
+DEFINE_int32(thread_num,
+             omp_get_num_procs(),
+             "Number of preprocessing threads");
 
 int main(int argc, char** argv) {
   // 解析命令行参数
@@ -53,7 +55,12 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
+  model.Init(FLAGS_model_dir,
+             FLAGS_use_gpu,
+             FLAGS_use_trt,
+             FLAGS_gpu_id,
+             FLAGS_key,
+             FLAGS_batch_size);
 
   double total_running_time_s = 0.0;
   double total_imread_time_s = 0.0;
@@ -72,25 +79,31 @@ int main(int argc, char** argv) {
       image_paths.push_back(image_path);
     }
     imgs = image_paths.size();
-    for(int i = 0; i < image_paths.size(); i += FLAGS_batch_size){
+    for (int i = 0; i < image_paths.size(); i += FLAGS_batch_size) {
       auto start = system_clock::now();
-      int im_vec_size = std::min((int)image_paths.size(), i + FLAGS_batch_size);
+      int im_vec_size =
+          std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
       std::vector<cv::Mat> im_vec(im_vec_size - i);
-      std::vector<PaddleX::SegResult> results(im_vec_size - i, PaddleX::SegResult());
+      std::vector<PaddleX::SegResult> results(im_vec_size - i,
+                                              PaddleX::SegResult());
       int thread_num = std::min(FLAGS_thread_num, im_vec_size - i);
       #pragma omp parallel for num_threads(thread_num)
-      for(int j = i; j < im_vec_size; ++j){
+      for (int j = i; j < im_vec_size; ++j) {
         im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
       auto imread_end = system_clock::now();
-      model.predict(im_vec, results, thread_num);
+      model.predict(im_vec, &results, thread_num);
       auto imread_duration = duration_cast<microseconds>(imread_end - start);
-      total_imread_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+      total_imread_time_s += static_cast<double>(imread_duration.count()) *
+                             microseconds::period::num /
+                             microseconds::period::den;
       auto end = system_clock::now();
       auto duration = duration_cast<microseconds>(end - start);
-      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
+      total_running_time_s += static_cast<double>(duration.count()) *
+                              microseconds::period::num /
+                              microseconds::period::den;
       // 可视化
-      for(int j = 0; j < im_vec_size - i; ++j) {
+      for (int j = 0; j < im_vec_size - i; ++j) {
         cv::Mat vis_img =
             PaddleX::Visualize(im_vec[j], results[j], model.labels, colormap);
         std::string save_path =
@@ -106,7 +119,9 @@ int main(int argc, char** argv) {
     model.predict(im, &result);
     auto end = system_clock::now();
     auto duration = duration_cast<microseconds>(end - start);
-    total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
+    total_running_time_s += static_cast<double>(duration.count()) *
+                            microseconds::period::num /
+                            microseconds::period::den;
     // 可视化
     cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
     std::string save_path =
@@ -115,17 +130,11 @@ int main(int argc, char** argv) {
     result.clear();
     std::cout << "Visualized output saved as " << save_path << std::endl;
   }
-  std::cout << "Total running time: " 
-            << total_running_time_s
-            << " s, average running time: "
-            << total_running_time_s / imgs
-            << " s/img, total read img time: " 
-            << total_imread_time_s
-            << " s, average read img time: "
-            << total_imread_time_s / imgs
-            << " s, batch_size = " 
-            << FLAGS_batch_size
-            << std::endl;
+  std::cout << "Total running time: " << total_running_time_s
+            << " s, average running time: " << total_running_time_s / imgs
+            << " s/img, total read img time: " << total_imread_time_s
+            << " s, average read img time: " << total_imread_time_s / imgs
+            << " s, batch_size = " << FLAGS_batch_size << std::endl;
 
   return 0;
 }

+ 1 - 1
deploy/cpp/include/paddlex/config_parser.h

@@ -54,4 +54,4 @@ class ConfigPaser {
   YAML::Node Transforms_;
 };
 
-}  // namespace PaddleDetection
+}  // namespace PaddleX

+ 39 - 28
deploy/cpp/include/paddlex/paddlex.h

@@ -16,8 +16,11 @@
 
 #include <functional>
 #include <iostream>
+#include <map>
+#include <memory>
 #include <numeric>
-
+#include <string>
+#include <vector>
 #include "yaml-cpp/yaml.h"
 
 #ifdef _WIN32
@@ -28,21 +31,21 @@
 
 #include "paddle_inference_api.h"  // NOLINT
 
-#include "config_parser.h"
-#include "results.h"
-#include "transforms.h"
+#include "config_parser.h"  // NOLINT
+#include "results.h"  // NOLINT
+#include "transforms.h"  // NOLINT
 
 #ifdef WITH_ENCRYPTION
-#include "paddle_model_decrypt.h"
-#include "model_code.h"
+#include "paddle_model_decrypt.h"  // NOLINT
+#include "model_code.h"  // NOLINT
 #endif
 
 namespace PaddleX {
 
 /*
  * @brief
- * This class encapsulates all necessary proccess steps of model infering, which 
- * include image matrix preprocessing, model predicting and results postprocessing. 
+ * This class encapsulates all necessary proccess steps of model infering, which
+ * include image matrix preprocessing, model predicting and results postprocessing.
  * The entire process of model infering can be simplified as below:
  * 1. preprocess image matrix (resize, padding, ......)
  * 2. model infer
@@ -63,11 +66,11 @@ class Model {
   /*
    * @brief
    * This method aims to initialize the model configuration
-   * 
+   *
    * @param model_dir: the directory which contains model.yml
    * @param use_gpu: use gpu or not when infering
    * @param use_trt: use Tensor RT or not when infering
-   * @param gpu_id: the id of gpu when infering with using gpu 
+   * @param gpu_id: the id of gpu when infering with using gpu
    * @param key: the key of encryption when using encrypted model
    * @param batch_size: batch size of infering
    * */
@@ -76,7 +79,7 @@ class Model {
             bool use_trt = false,
             int gpu_id = 0,
             std::string key = "",
-	          int batch_size = 1) {
+            int batch_size = 1) {
     create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, batch_size);
   }
 
@@ -85,11 +88,11 @@ class Model {
                         bool use_trt = false,
                         int gpu_id = 0,
                         std::string key = "",
-			                  int batch_size = 1);
- 
+                        int batch_size = 1);
+
   /*
-   * @brief 
-   * This method aims to load model configurations which include 
+   * @brief
+   * This method aims to load model configurations which include
    * transform steps and label list
    *
    * @param model_dir: the directory which contains model.yml
@@ -107,7 +110,7 @@ class Model {
    * @return true if preprocess image matrix successfully
    * */
   bool preprocess(const cv::Mat& input_im, ImageBlob* blob);
-  
+
   /*
    * @brief
    * This method aims to transform mutiple image matrixs, the result will be
@@ -115,15 +118,17 @@ class Model {
    *
    * @param input_im_batch: a batch of image matrixs to be transformed
    * @param blob_blob: raw data of a batch of image matrixs after transformed
-   * @param thread_num: the number of preprocessing threads, 
+   * @param thread_num: the number of preprocessing threads,
    *                    each thread run preprocess on single image matrix
    * @return true if preprocess a batch of image matrixs successfully
    * */
-  bool preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch, int thread_num = 1);
+  bool preprocess(const std::vector<cv::Mat> &input_im_batch,
+                  std::vector<ImageBlob> *blob_batch,
+                  int thread_num = 1);
 
   /*
    * @brief
-   * This method aims to execute classification model prediction on single image matrix, 
+   * This method aims to execute classification model prediction on single image matrix,
    * the result will be returned at second parameter.
    *
    * @param im: single image matrix to be predicted
@@ -134,7 +139,7 @@ class Model {
 
   /*
    * @brief
-   * This method aims to execute classification model prediction on a batch of image matrixs, 
+   * This method aims to execute classification model prediction on a batch of image matrixs,
    * the result will be returned at second parameter.
    *
    * @param im: a batch of image matrixs to be predicted
@@ -143,7 +148,9 @@ class Model {
    *                    on single image matrix
    * @return true if predict successfully
    * */
-  bool predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results, int thread_num = 1);
+  bool predict(const std::vector<cv::Mat> &im_batch,
+               std::vector<ClsResult> *results,
+               int thread_num = 1);
 
   /*
    * @brief
@@ -167,11 +174,13 @@ class Model {
    *                    on single image matrix
    * @return true if predict successfully
    * */
-  bool predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result, int thread_num = 1);
-  
+  bool predict(const std::vector<cv::Mat> &im_batch,
+               std::vector<DetResult> *result,
+               int thread_num = 1);
+
   /*
    * @brief
-   * This method aims to execute segmentation model prediction on single image matrix, 
+   * This method aims to execute segmentation model prediction on single image matrix,
    * the result will be returned at second parameter.
    *
    * @param im: single image matrix to be predicted
@@ -182,7 +191,7 @@ class Model {
 
   /*
    * @brief
-   * This method aims to execute segmentation model prediction on a batch of image matrix, 
+   * This method aims to execute segmentation model prediction on a batch of image matrix,
    * the result will be returned at second parameter.
    *
    * @param im: a batch of image matrix to be predicted
@@ -191,8 +200,10 @@ class Model {
    *                    on single image matrix
    * @return true if predict successfully
    * */
-  bool predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result, int thread_num = 1);
- 
+  bool predict(const std::vector<cv::Mat> &im_batch,
+               std::vector<SegResult> *result,
+               int thread_num = 1);
+
   // model type, include 3 type: classifier, detector, segmenter
   std::string type;
   // model name, such as FasterRCNN, YOLOV3 and so on.
@@ -209,4 +220,4 @@ class Model {
   // a predictor which run the model predicting
   std::unique_ptr<paddle::PaddlePredictor> predictor_;
 };
-}  // namespce of PaddleX
+}  // namespace PaddleX

+ 5 - 3
deploy/cpp/include/paddlex/transforms.h

@@ -66,7 +66,7 @@ class Transform {
    * This method executes preprocessing operation on image matrix,
    * result will be returned at second parameter.
    * @param im: single image matrix to be preprocessed
-   * @param data: the raw data of single image matrix after preprocessed 
+   * @param data: the raw data of single image matrix after preprocessed
    * @return true if transform successfully
    * */
   virtual bool Run(cv::Mat* im, ImageBlob* data) = 0;
@@ -92,10 +92,10 @@ class Normalize : public Transform {
 
 /*
  * @brief
- * This class execute resize by short operation on image matrix. At first, it resizes 
+ * This class execute resize by short operation on image matrix. At first, it resizes
  * the short side of image matrix to specified length. Accordingly, the long side
  * will be resized in the same proportion. If new length of long side exceeds max
- * size, the long size will be resized to max size, and the short size will be 
+ * size, the long size will be resized to max size, and the short size will be
  * resized in the same proportion
  * */
 class ResizeByShort : public Transform {
@@ -214,6 +214,7 @@ class Padding : public Transform {
     }
   }
   virtual bool Run(cv::Mat* im, ImageBlob* data);
+
  private:
   int coarsest_stride_ = -1;
   int width_ = 0;
@@ -229,6 +230,7 @@ class Transforms {
   void Init(const YAML::Node& node, bool to_rgb = true);
   std::shared_ptr<Transform> CreateTransform(const std::string& name);
   bool Run(cv::Mat* im, ImageBlob* data);
+
  private:
   std::vector<std::shared_ptr<Transform>> transforms_;
   bool to_rgb_ = true;

+ 2 - 2
deploy/cpp/include/paddlex/visualize.h

@@ -47,7 +47,7 @@ namespace PaddleX {
  * @brief
  * Generate visualization colormap for each class
  *
- * @param number of class 
+ * @param number of class
  * @return color map, the size of vector is 3 * num_class
  * */
 std::vector<int> GenerateColorMap(int num_class);
@@ -94,4 +94,4 @@ cv::Mat Visualize(const cv::Mat& img,
  * */
 std::string generate_save_path(const std::string& save_dir,
                                const std::string& file_path);
-}  // namespce of PaddleX
+}  // namespace PaddleX

+ 130 - 115
deploy/cpp/src/paddlex.cpp

@@ -11,10 +11,10 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <algorithm>
 #include <omp.h>
-#include "include/paddlex/paddlex.h"
+#include <algorithm>
 #include <cstring>
+#include "include/paddlex/paddlex.h"
 namespace PaddleX {
 
 void Model::create_predictor(const std::string& model_dir,
@@ -22,7 +22,7 @@ void Model::create_predictor(const std::string& model_dir,
                              bool use_trt,
                              int gpu_id,
                              std::string key,
-			     int batch_size) {
+                             int batch_size) {
   // 读取配置文件
   if (!load_config(model_dir)) {
     std::cerr << "Parse file 'model.yml' failed!" << std::endl;
@@ -32,13 +32,14 @@ void Model::create_predictor(const std::string& model_dir,
   std::string model_file = model_dir + OS_PATH_SEP + "__model__";
   std::string params_file = model_dir + OS_PATH_SEP + "__params__";
 #ifdef WITH_ENCRYPTION
-  if (key != ""){
+  if (key != "") {
     model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
     params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
-    paddle_security_load_model(&config, key.c_str(), model_file.c_str(), params_file.c_str());
+    paddle_security_load_model(
+        &config, key.c_str(), model_file.c_str(), params_file.c_str());
   }
 #endif
-  if (key == ""){
+  if (key == "") {
     config.SetModel(model_file, params_file);
   }
   if (use_gpu) {
@@ -70,11 +71,11 @@ bool Model::load_config(const std::string& model_dir) {
   name = config["Model"].as<std::string>();
   std::string version = config["version"].as<std::string>();
   if (version[0] == '0') {
-    std::cerr << "[Init] Version of the loaded model is lower than 1.0.0, deployment "
-              << "cannot be done, please refer to "
-              << "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/tutorials/deploy/upgrade_version.md "
-              << "to transfer version."
-              << std::endl;
+    std::cerr << "[Init] Version of the loaded model is lower than 1.0.0, "
+              << "deployment cannot be done, please refer to "
+              << "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs"
+              << "/tutorials/deploy/upgrade_version.md "
+              << "to transfer version." << std::endl;
     return false;
   }
   bool to_rgb = true;
@@ -108,14 +109,16 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
 }
 
 // use openmp
-bool Model::preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch, int thread_num) {
+bool Model::preprocess(const std::vector<cv::Mat>& input_im_batch,
+                       std::vector<ImageBlob>* blob_batch,
+                       int thread_num) {
   int batch_size = input_im_batch.size();
   bool success = true;
   thread_num = std::min(thread_num, batch_size);
   #pragma omp parallel for num_threads(thread_num)
-  for(int i = 0; i < input_im_batch.size(); ++i) {
+  for (int i = 0; i < input_im_batch.size(); ++i) {
     cv::Mat im = input_im_batch[i].clone();
-    if(!transforms_.Run(&im, &blob_batch[i])){
+    if (!transforms_.Run(&im, &(*blob_batch)[i])) {
       success = false;
     }
   }
@@ -127,8 +130,7 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   if (type == "detector") {
     std::cerr << "Loading model is a 'detector', DetResult should be passed to "
                  "function predict()!"
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   }
   // 处理输入图像
@@ -161,23 +163,23 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   return true;
 }
 
-bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results, int thread_num) {
-  for(auto &inputs: inputs_batch_) {
+bool Model::predict(const std::vector<cv::Mat>& im_batch,
+                    std::vector<ClsResult>* results,
+                    int thread_num) {
+  for (auto& inputs : inputs_batch_) {
     inputs.clear();
   }
   if (type == "detector") {
     std::cerr << "Loading model is a 'detector', DetResult should be passed to "
-                 "function predict()!"
-              << std::endl;
+                 "function predict()!" << std::endl;
     return false;
   } else if (type == "segmenter") {
     std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   }
   // 处理输入图像
-  if (!preprocess(im_batch, inputs_batch_, thread_num)) {
+  if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
     std::cerr << "Preprocess failed!" << std::endl;
     return false;
   }
@@ -188,11 +190,13 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult>
   int w = inputs_batch_[0].new_im_size_[1];
   in_tensor->Reshape({batch_size, 3, h, w});
   std::vector<float> inputs_data(batch_size * 3 * h * w);
-  for(int i = 0; i < batch_size; ++i) {
-    std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
+  for (int i = 0; i < batch_size; ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(),
+              inputs_batch_[i].im_data_.end(),
+              inputs_data.begin() + i * 3 * h * w);
   }
   in_tensor->copy_from_cpu(inputs_data.data());
-  //in_tensor->copy_from_cpu(inputs_.im_data_.data());
+  // in_tensor->copy_from_cpu(inputs_.im_data_.data());
   predictor_->ZeroCopyRun();
   // 取出模型的输出结果
   auto output_names = predictor_->GetOutputNames();
@@ -206,15 +210,15 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult>
   output_tensor->copy_to_cpu(outputs_.data());
   // 对模型输出结果进行后处理
   int single_batch_size = size / batch_size;
-  for(int i = 0; i < batch_size; ++i) {
+  for (int i = 0; i < batch_size; ++i) {
     auto start_ptr = std::begin(outputs_);
     auto end_ptr = std::begin(outputs_);
     std::advance(start_ptr, i * single_batch_size);
     std::advance(end_ptr, (i + 1) * single_batch_size);
     auto ptr = std::max_element(start_ptr, end_ptr);
-    results[i].category_id = std::distance(start_ptr, ptr);
-    results[i].score = *ptr;
-    results[i].category = labels[results[i].category_id];
+    (*results)[i].category_id = std::distance(start_ptr, ptr);
+    (*results)[i].score = *ptr;
+    (*results)[i].category = labels[(*results)[i].category_id];
   }
   return true;
 }
@@ -224,13 +228,11 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
   result->clear();
   if (type == "classifier") {
     std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   } else if (type == "segmenter") {
     std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   }
 
@@ -324,25 +326,25 @@ bool Model::predict(const cv::Mat& im, DetResult* result) {
   return true;
 }
 
-bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult> &result, int thread_num) {
-  for(auto &inputs: inputs_batch_) {
+bool Model::predict(const std::vector<cv::Mat>& im_batch,
+                    std::vector<DetResult>* result,
+                    int thread_num) {
+  for (auto& inputs : inputs_batch_) {
     inputs.clear();
   }
   if (type == "classifier") {
     std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   } else if (type == "segmenter") {
     std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   }
 
   int batch_size = im_batch.size();
   // 处理输入图像
-  if (!preprocess(im_batch, inputs_batch_, thread_num)) {
+  if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
     std::cerr << "Preprocess failed!" << std::endl;
     return false;
   }
@@ -351,33 +353,34 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
     if (name == "FasterRCNN" || name == "MaskRCNN") {
       int max_h = -1;
       int max_w = -1;
-      for(int i = 0; i < batch_size; ++i) {
+      for (int i = 0; i < batch_size; ++i) {
         max_h = std::max(max_h, inputs_batch_[i].new_im_size_[0]);
         max_w = std::max(max_w, inputs_batch_[i].new_im_size_[1]);
-        //std::cout << "(" << inputs_batch_[i].new_im_size_[0] 
-        //          << ", " << inputs_batch_[i].new_im_size_[1] 
+        // std::cout << "(" << inputs_batch_[i].new_im_size_[0]
+        //          << ", " << inputs_batch_[i].new_im_size_[1]
         //          <<  ")" << std::endl;
       }
       thread_num = std::min(thread_num, batch_size);
       #pragma omp parallel for num_threads(thread_num)
-      for(int i = 0; i < batch_size; ++i) {
+      for (int i = 0; i < batch_size; ++i) {
         int h = inputs_batch_[i].new_im_size_[0];
         int w = inputs_batch_[i].new_im_size_[1];
         int c = im_batch[i].channels();
-        if(max_h != h || max_w != w) {
+        if (max_h != h || max_w != w) {
           std::vector<float> temp_buffer(c * max_h * max_w);
-          float *temp_ptr = temp_buffer.data();
-          float *ptr = inputs_batch_[i].im_data_.data();
-          for(int cur_channel = c - 1; cur_channel >= 0; --cur_channel) {
+          float* temp_ptr = temp_buffer.data();
+          float* ptr = inputs_batch_[i].im_data_.data();
+          for (int cur_channel = c - 1; cur_channel >= 0; --cur_channel) {
             int ori_pos = cur_channel * h * w + (h - 1) * w;
             int des_pos = cur_channel * max_h * max_w + (h - 1) * max_w;
-            for(int start_pos = ori_pos; start_pos >= cur_channel * h * w; start_pos -= w, des_pos -= max_w) {
-              memcpy(temp_ptr + des_pos, ptr + start_pos, w * sizeof(float));
+            int last_pos = cur_channel * h * w;
+            for (; ori_pos >= last_pos; ori_pos -= w, des_pos -= max_w) {
+              memcpy(temp_ptr + des_pos, ptr + ori_pos, w * sizeof(float));
             }
           }
           inputs_batch_[i].im_data_.swap(temp_buffer);
           inputs_batch_[i].new_im_size_[0] = max_h;
-          inputs_batch_[i].new_im_size_[1] = max_w; 
+          inputs_batch_[i].new_im_size_[1] = max_w;
         }
       }
     }
@@ -387,16 +390,20 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
   auto im_tensor = predictor_->GetInputTensor("image");
   im_tensor->Reshape({batch_size, 3, h, w});
   std::vector<float> inputs_data(batch_size * 3 * h * w);
-  for(int i = 0; i < batch_size; ++i) {
-    std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
+  for (int i = 0; i < batch_size; ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(),
+              inputs_batch_[i].im_data_.end(),
+              inputs_data.begin() + i * 3 * h * w);
   }
   im_tensor->copy_from_cpu(inputs_data.data());
   if (name == "YOLOv3") {
     auto im_size_tensor = predictor_->GetInputTensor("im_size");
     im_size_tensor->Reshape({batch_size, 2});
-    std::vector<int> inputs_data_size(batch_size  * 2);
-    for(int i = 0; i < batch_size; ++i){
-      std::copy(inputs_batch_[i].ori_im_size_.begin(), inputs_batch_[i].ori_im_size_.end(), inputs_data_size.begin() + 2 * i);
+    std::vector<int> inputs_data_size(batch_size * 2);
+    for (int i = 0; i < batch_size; ++i) {
+      std::copy(inputs_batch_[i].ori_im_size_.begin(),
+                inputs_batch_[i].ori_im_size_.end(),
+                inputs_data_size.begin() + 2 * i);
     }
     im_size_tensor->copy_from_cpu(inputs_data_size.data());
   } else if (name == "FasterRCNN" || name == "MaskRCNN") {
@@ -404,10 +411,10 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
     auto im_shape_tensor = predictor_->GetInputTensor("im_shape");
     im_info_tensor->Reshape({batch_size, 3});
     im_shape_tensor->Reshape({batch_size, 3});
-    
+
     std::vector<float> im_info(3 * batch_size);
     std::vector<float> im_shape(3 * batch_size);
-    for(int i = 0; i < batch_size; ++i) {
+    for (int i = 0; i < batch_size; ++i) {
       float ori_h = static_cast<float>(inputs_batch_[i].ori_im_size_[0]);
       float ori_w = static_cast<float>(inputs_batch_[i].ori_im_size_[1]);
       float new_h = static_cast<float>(inputs_batch_[i].new_im_size_[0]);
@@ -444,9 +451,9 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
   int num_boxes = size / 6;
   // 解析预测框box
   for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
-    for(int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
+    for (int j = lod_vector[0][i]; j < lod_vector[0][i + 1]; ++j) {
       Box box;
-      box.category_id = static_cast<int> (round(output_box[j * 6]));
+      box.category_id = static_cast<int>(round(output_box[j * 6]));
       box.category = labels[box.category_id];
       box.score = output_box[j * 6 + 1];
       float xmin = output_box[j * 6 + 2];
@@ -456,7 +463,7 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
       float w = xmax - xmin + 1;
       float h = ymax - ymin + 1;
       box.coordinate = {xmin, ymin, w, h};
-      result[i].boxes.push_back(std::move(box));
+      (*result)[i].boxes.push_back(std::move(box));
     }
   }
 
@@ -474,11 +481,13 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<DetResult>
     output_mask.resize(masks_size);
     output_mask_tensor->copy_to_cpu(output_mask.data());
     int mask_idx = 0;
-    for(int i = 0; i < lod_vector[0].size() - 1; ++i) {
-      result[i].mask_resolution = output_mask_shape[2];
-      for(int j = 0; j < result[i].boxes.size(); ++j) {
-        Box* box = &result[i].boxes[j];
-        auto begin_mask = output_mask.begin() + (mask_idx * classes + box->category_id) * mask_pixels;
+    for (int i = 0; i < lod_vector[0].size() - 1; ++i) {
+      (*result)[i].mask_resolution = output_mask_shape[2];
+      for (int j = 0; j < (*result)[i].boxes.size(); ++j) {
+        Box* box = &(*result)[i].boxes[j];
+        int category_id = box->category_id;
+        auto begin_mask = output_mask.begin() +
+                          (mask_idx * classes + category_id) * mask_pixels;
         auto end_mask = begin_mask + mask_pixels;
         box->mask.data.assign(begin_mask, end_mask);
         box->mask.shape = {static_cast<int>(box->coordinate[2]),
@@ -495,13 +504,11 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
   inputs_.clear();
   if (type == "classifier") {
     std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   } else if (type == "detector") {
     std::cerr << "Loading model is a 'detector', DetResult should be passed to "
-                 "function predict()!"
-              << std::endl;
+                 "function predict()!" << std::endl;
     return false;
   }
 
@@ -586,7 +593,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                  cv::Size(resize_h, resize_w),
                  0,
                  0,
-                 cv::INTER_LINEAR); 
+                 cv::INTER_LINEAR);
     }
     ++idx;
   }
@@ -599,41 +606,43 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
   return true;
 }
 
-bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult> &result, int thread_num) {
-  for(auto &inputs: inputs_batch_) {
+bool Model::predict(const std::vector<cv::Mat>& im_batch,
+                    std::vector<SegResult>* result,
+                    int thread_num) {
+  for (auto& inputs : inputs_batch_) {
     inputs.clear();
   }
   if (type == "classifier") {
     std::cerr << "Loading model is a 'classifier', ClsResult should be passed "
-                 "to function predict()!"
-              << std::endl;
+                 "to function predict()!" << std::endl;
     return false;
   } else if (type == "detector") {
     std::cerr << "Loading model is a 'detector', DetResult should be passed to "
-                 "function predict()!"
-              << std::endl;
+                 "function predict()!" << std::endl;
     return false;
   }
 
   // 处理输入图像
-  if (!preprocess(im_batch, inputs_batch_, thread_num)) {
+  if (!preprocess(im_batch, &inputs_batch_, thread_num)) {
     std::cerr << "Preprocess failed!" << std::endl;
     return false;
   }
 
   int batch_size = im_batch.size();
-  result.clear();
-  result.resize(batch_size);
+  (*result).clear();
+  (*result).resize(batch_size);
   int h = inputs_batch_[0].new_im_size_[0];
   int w = inputs_batch_[0].new_im_size_[1];
   auto im_tensor = predictor_->GetInputTensor("image");
   im_tensor->Reshape({batch_size, 3, h, w});
   std::vector<float> inputs_data(batch_size * 3 * h * w);
-  for(int i = 0; i < batch_size; ++i) {
-    std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
+  for (int i = 0; i < batch_size; ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(),
+              inputs_batch_[i].im_data_.end(),
+              inputs_data.begin() + i * 3 * h * w);
   }
   im_tensor->copy_from_cpu(inputs_data.data());
-  //im_tensor->copy_from_cpu(inputs_.im_data_.data());
+  // im_tensor->copy_from_cpu(inputs_.im_data_.data());
 
   // 使用加载的模型进行预测
   predictor_->ZeroCopyRun();
@@ -652,13 +661,15 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
   auto output_labels_iter = output_labels.begin();
 
   int single_batch_size = size / batch_size;
-  for(int i = 0; i < batch_size; ++i) {
-    result[i].label_map.data.resize(single_batch_size);
-    result[i].label_map.shape.push_back(1);
-    for(int j = 1; j < output_label_shape.size(); ++j) {
-      result[i].label_map.shape.push_back(output_label_shape[j]);
+  for (int i = 0; i < batch_size; ++i) {
+    (*result)[i].label_map.data.resize(single_batch_size);
+    (*result)[i].label_map.shape.push_back(1);
+    for (int j = 1; j < output_label_shape.size(); ++j) {
+      (*result)[i].label_map.shape.push_back(output_label_shape[j]);
     }
-    std::copy(output_labels_iter + i * single_batch_size, output_labels_iter + (i + 1) * single_batch_size, result[i].label_map.data.data());
+    std::copy(output_labels_iter + i * single_batch_size,
+              output_labels_iter + (i + 1) * single_batch_size,
+              (*result)[i].label_map.data.data());
   }
 
   // 获取预测置信度scoremap
@@ -674,28 +685,30 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
   auto output_scores_iter = output_scores.begin();
 
   int single_batch_score_size = size / batch_size;
-  for(int i = 0; i < batch_size; ++i) {
-    result[i].score_map.data.resize(single_batch_score_size);
-    result[i].score_map.shape.push_back(1);
-    for(int j = 1; j < output_score_shape.size(); ++j) {
-      result[i].score_map.shape.push_back(output_score_shape[j]);
+  for (int i = 0; i < batch_size; ++i) {
+    (*result)[i].score_map.data.resize(single_batch_score_size);
+    (*result)[i].score_map.shape.push_back(1);
+    for (int j = 1; j < output_score_shape.size(); ++j) {
+      (*result)[i].score_map.shape.push_back(output_score_shape[j]);
     }
-    std::copy(output_scores_iter + i * single_batch_score_size, output_scores_iter + (i + 1) * single_batch_score_size, result[i].score_map.data.data());
+    std::copy(output_scores_iter + i * single_batch_score_size,
+              output_scores_iter + (i + 1) * single_batch_score_size,
+              (*result)[i].score_map.data.data());
   }
 
   // 解析输出结果到原图大小
-  for(int i = 0; i < batch_size; ++i) {
-    std::vector<uint8_t> label_map(result[i].label_map.data.begin(),
-                                   result[i].label_map.data.end());
-    cv::Mat mask_label(result[i].label_map.shape[1],
-                       result[i].label_map.shape[2],
+  for (int i = 0; i < batch_size; ++i) {
+    std::vector<uint8_t> label_map((*result)[i].label_map.data.begin(),
+                                   (*result)[i].label_map.data.end());
+    cv::Mat mask_label((*result)[i].label_map.shape[1],
+                       (*result)[i].label_map.shape[2],
                        CV_8UC1,
                        label_map.data());
-  
-    cv::Mat mask_score(result[i].score_map.shape[2],
-                       result[i].score_map.shape[3],
+
+    cv::Mat mask_score((*result)[i].score_map.shape[2],
+                       (*result)[i].score_map.shape[3],
                        CV_32FC1,
-                       result[i].score_map.data.data());
+                       (*result)[i].score_map.data.data());
     int idx = 1;
     int len_postprocess = inputs_batch_[i].im_size_before_resize_.size();
     for (std::vector<std::string>::reverse_iterator iter =
@@ -703,14 +716,16 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
          iter != inputs_batch_[i].reshape_order_.rend();
          ++iter) {
       if (*iter == "padding") {
-        auto before_shape = inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
+        auto before_shape =
+            inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
         inputs_batch_[i].im_size_before_resize_.pop_back();
         auto padding_w = before_shape[0];
         auto padding_h = before_shape[1];
         mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
         mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
       } else if (*iter == "resize") {
-        auto before_shape = inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
+        auto before_shape =
+            inputs_batch_[i].im_size_before_resize_[len_postprocess - idx];
         inputs_batch_[i].im_size_before_resize_.pop_back();
         auto resize_w = before_shape[0];
         auto resize_h = before_shape[1];
@@ -725,18 +740,18 @@ bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<SegResult>
                    cv::Size(resize_h, resize_w),
                    0,
                    0,
-                   cv::INTER_LINEAR); 
+                   cv::INTER_LINEAR);
       }
       ++idx;
     }
-    result[i].label_map.data.assign(mask_label.begin<uint8_t>(),
-                                  mask_label.end<uint8_t>());
-    result[i].label_map.shape = {mask_label.rows, mask_label.cols};
-    result[i].score_map.data.assign(mask_score.begin<float>(),
-                                  mask_score.end<float>());
-    result[i].score_map.shape = {mask_score.rows, mask_score.cols};
+    (*result)[i].label_map.data.assign(mask_label.begin<uint8_t>(),
+                                       mask_label.end<uint8_t>());
+    (*result)[i].label_map.shape = {mask_label.rows, mask_label.cols};
+    (*result)[i].score_map.data.assign(mask_score.begin<float>(),
+                                       mask_score.end<float>());
+    (*result)[i].score_map.shape = {mask_score.rows, mask_score.cols};
   }
   return true;
 }
 
-}  // namespce of PaddleX
+}  // namespace PaddleX

+ 1 - 1
deploy/cpp/src/visualize.cpp

@@ -145,4 +145,4 @@ std::string generate_save_path(const std::string& save_dir,
   std::string image_name(file_path.substr(pos + 1));
   return save_dir + OS_PATH_SEP + image_name;
 }
-}  // namespace of PaddleX
+}  // namespace PaddleX