浏览代码

add batch predict in classification task

jack 5 年之前
父节点
当前提交
cb6edab618
共有 4 个文件被更改,包括 146 次插入21 次删除
  1. 50 9
      deploy/cpp/demo/classifier.cpp
  2. 10 3
      deploy/cpp/include/paddlex/paddlex.h
  3. 3 3
      deploy/cpp/scripts/build.sh
  4. 83 6
      deploy/cpp/src/paddlex.cpp

+ 50 - 9
deploy/cpp/demo/classifier.cpp

@@ -14,13 +14,17 @@
 
 #include <glog/logging.h>
 
+#include <algorithm>
+#include <chrono>
 #include <fstream>
 #include <iostream>
 #include <string>
 #include <vector>
-
+#include <utility>
 #include "include/paddlex/paddlex.h"
 
+using namespace std::chrono;
+
 DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
 DEFINE_bool(use_trt, false, "Infering with TensorRT");
@@ -28,6 +32,7 @@ DEFINE_int32(gpu_id, 0, "GPU card id");
 DEFINE_string(key, "", "key of encryption");
 DEFINE_string(image, "", "Path of test image file");
 DEFINE_string(image_list, "", "Path of test image list file");
+DEFINE_int32(batch_size, 1, "Batch size when infering");
 
 int main(int argc, char** argv) {
   // Parsing command-line
@@ -44,32 +49,68 @@ int main(int argc, char** argv) {
 
   // 加载模型
   PaddleX::Model model;
-  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key);
+  model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id, FLAGS_key, FLAGS_batch_size);
 
   // 进行预测
+  double total_running_time_s = 0.0;
+  double total_imreaad_time_s = 0.0;
+
   if (FLAGS_image_list != "") {
     std::ifstream inf(FLAGS_image_list);
     if (!inf) {
       std::cerr << "Fail to open file " << FLAGS_image_list << std::endl;
       return -1;
     }
+    // 多batch预测
     std::string image_path;
+    std::vector<std::string> image_path_vec;
     while (getline(inf, image_path)) {
-      PaddleX::ClsResult result;
-      cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
-      std::cout << "Predict label: " << result.category
-                << ", label_id:" << result.category_id
-                << ", score: " << result.score << std::endl;
+      image_path_vec.push_back(image_path);
+    }
+    for(int i = 0; i < image_path_vec.size(); i += FLAGS_batch_size) {
+      auto start = system_clock::now();
+        // 读图像
+      int im_vec_size = std::min((int)image_path_vec.size(), i + FLAGS_batch_size);      
+      std::vector<cv::Mat> im_vec(im_vec_size - i);
+      std::vector<PaddleX::ClsResult> results(im_vec_size - i, PaddleX::ClsResult());
+      #pragma omp parallel for num_threads(im_vec_size - i)
+      for(int j = i; j < im_vec_size; ++j){
+        im_vec[j - i] = std::move(cv::imread(image_path_vec[j], 1));
+      }
+      auto imread_end = system_clock::now();
+      model.predict(im_vec, results);
+
+      auto imread_duration = duration_cast<microseconds>(imread_end - start);
+      total_imreaad_time_s += double(imread_duration.count()) * microseconds::period::num / microseconds::period::den;
+
+      auto end = system_clock::now();
+      auto duration = duration_cast<microseconds>(end - start);
+      total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
+      for(int j = i; j < im_vec_size; ++j) {
+            std::cout << "Path:" << image_path_vec[j]
+                      << ", predict label: " << results[j - i].category
+                      << ", label_id:" << results[j - i].category_id
+                      << ", score: " << results[j - i].score << std::endl;
+      }	
     }
   } else {
+    auto start = system_clock::now();
     PaddleX::ClsResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
     model.predict(im, &result);
+    auto end = system_clock::now();
+    auto duration = duration_cast<microseconds>(end - start);
+    total_running_time_s += double(duration.count()) * microseconds::period::num / microseconds::period::den;
     std::cout << "Predict label: " << result.category
               << ", label_id:" << result.category_id
               << ", score: " << result.score << std::endl;
   }
-
+  std::cout << "Total average running time: " 
+	    << total_running_time_s
+	    << " s, total average read img time: " 
+	    << total_imreaad_time_s
+	    << " s, batch_size = " 
+	    << FLAGS_batch_size 
+	    << std::endl;
   return 0;
 }

+ 10 - 3
deploy/cpp/include/paddlex/paddlex.h

@@ -45,22 +45,28 @@ class Model {
             bool use_gpu = false,
             bool use_trt = false,
             int gpu_id = 0,
-            std::string key = "") {
-    create_predictor(model_dir, use_gpu, use_trt, gpu_id, key);
+            std::string key = "",
+	    int batch_size = 1) {
+    create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, batch_size);
   }
 
   void create_predictor(const std::string& model_dir,
                         bool use_gpu = false,
                         bool use_trt = false,
                         int gpu_id = 0,
-                        std::string key = "");
+                        std::string key = "",
+			int batch_size = 1);
 
   bool load_config(const std::string& model_dir);
 
   bool preprocess(const cv::Mat& input_im, ImageBlob* blob);
+  
+  bool preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch);
 
   bool predict(const cv::Mat& im, ClsResult* result);
 
+  bool predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results);
+
   bool predict(const cv::Mat& im, DetResult* result);
 
   bool predict(const cv::Mat& im, SegResult* result);
@@ -74,6 +80,7 @@ class Model {
   std::map<int, std::string> labels;
   Transforms transforms_;
   ImageBlob inputs_;
+  std::vector<ImageBlob> inputs_batch_;
   std::vector<float> outputs_;
   std::unique_ptr<paddle::PaddlePredictor> predictor_;
 };

+ 3 - 3
deploy/cpp/scripts/build.sh

@@ -1,5 +1,5 @@
 # 是否使用GPU(即是否使用 CUDA)
-WITH_GPU=OFF
+WITH_GPU=ON
 # 使用MKL or openblas
 WITH_MKL=ON
 # 是否集成 TensorRT(仅WITH_GPU=ON 有效)
@@ -7,7 +7,7 @@ WITH_TENSORRT=OFF
 # TensorRT 的路径
 TENSORRT_DIR=/path/to/TensorRT/
 # Paddle 预测库路径
-PADDLE_DIR=/docker/jiangjiajun/PaddleDetection/deploy/cpp/fluid_inference
+PADDLE_DIR=/mnt/zhoushunjie/projects/fluid_inference
 # Paddle 的预测库是否使用静态库来编译
 # 使用TensorRT时,Paddle的预测库通常为动态库
 WITH_STATIC_LIB=OFF
@@ -42,4 +42,4 @@ cmake .. \
     -DCUDNN_LIB=${CUDNN_LIB} \
     -DENCRYPTION_DIR=${ENCRYPTION_DIR} \
     -DOPENCV_DIR=${OPENCV_DIR}
-make
+make -j4

+ 83 - 6
deploy/cpp/src/paddlex.cpp

@@ -11,16 +11,17 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
+#include <algorithm>
+#include <omp.h>
 #include "include/paddlex/paddlex.h"
-
 namespace PaddleX {
 
 void Model::create_predictor(const std::string& model_dir,
                              bool use_gpu,
                              bool use_trt,
                              int gpu_id,
-                             std::string key) {
+                             std::string key,
+			     int batch_size) {
   // 读取配置文件
   if (!load_config(model_dir)) {
     std::cerr << "Parse file 'model.yml' failed!" << std::endl;
@@ -58,6 +59,7 @@ void Model::create_predictor(const std::string& model_dir,
         false /* use_calib_mode*/);
   }
   predictor_ = std::move(CreatePaddlePredictor(config));
+  inputs_batch_.assign(batch_size, ImageBlob());
 }
 
 bool Model::load_config(const std::string& model_dir) {
@@ -104,6 +106,21 @@ bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
   return true;
 }
 
+// use openmp
+bool Model::preprocess(const std::vector<cv::Mat> &input_im_batch, std::vector<ImageBlob> &blob_batch) {
+  int batch_size = inputs_batch_.size();
+  bool success = true;
+  //int i;
+  #pragma omp parallel for num_threads(batch_size)
+  for(int i = 0; i < input_im_batch.size(); ++i) {
+    cv::Mat im = input_im_batch[i].clone();
+    if(!transforms_.Run(&im, &blob_batch[i])){
+      success = false;
+    }
+  }
+  return success;
+}
+
 bool Model::predict(const cv::Mat& im, ClsResult* result) {
   inputs_.clear();
   if (type == "detector") {
@@ -146,6 +163,64 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   result->category = labels[result->category_id];
 }
 
+bool Model::predict(const std::vector<cv::Mat> &im_batch, std::vector<ClsResult> &results) {
+  for(auto &inputs: inputs_batch_) {
+    inputs.clear();
+  }
+  if (type == "detector") {
+    std::cerr << "Loading model is a 'detector', DetResult should be passed to "
+                 "function predict()!"
+              << std::endl;
+    return false;
+  } else if (type == "segmenter") {
+    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
+                 "to function predict()!"
+              << std::endl;
+    return false;
+  }
+  // 处理输入图像
+  if (!preprocess(im_batch, inputs_batch_)) {
+    std::cerr << "Preprocess failed!" << std::endl;
+    return false;
+  }
+  // 使用加载的模型进行预测
+  int batch_size = im_batch.size();
+  auto in_tensor = predictor_->GetInputTensor("image");
+  int h = inputs_batch_[0].new_im_size_[0];
+  int w = inputs_batch_[0].new_im_size_[1];
+  in_tensor->Reshape({batch_size, 3, h, w});
+  std::vector<float> inputs_data(batch_size * 3 * h * w);
+  for(int i = 0; i <inputs_batch_.size(); ++i) {
+    std::copy(inputs_batch_[i].im_data_.begin(), inputs_batch_[i].im_data_.end(), inputs_data.begin() + i * 3 * h * w);
+  }
+  in_tensor->copy_from_cpu(inputs_data.data());
+  //in_tensor->copy_from_cpu(inputs_.im_data_.data());
+  predictor_->ZeroCopyRun();
+  // 取出模型的输出结果
+  auto output_names = predictor_->GetOutputNames();
+  auto output_tensor = predictor_->GetOutputTensor(output_names[0]);
+  std::vector<int> output_shape = output_tensor->shape();
+  int size = 1;
+  for (const auto& i : output_shape) {
+    size *= i;
+  }
+  outputs_.resize(size);
+  output_tensor->copy_to_cpu(outputs_.data());
+  // 对模型输出结果进行后处理
+  int single_batch_size = size / batch_size;
+  for(int i = 0; i < batch_size; ++i) {
+    auto start_ptr = std::begin(outputs_);
+    auto end_ptr = std::begin(outputs_);
+    std::advance(start_ptr, i * single_batch_size);
+    std::advance(end_ptr, (i + 1) * single_batch_size);
+    auto ptr = std::max_element(start_ptr, end_ptr);
+    results[i].category_id = std::distance(start_ptr, ptr);
+    results[i].score = *ptr;
+    results[i].category = labels[results[i].category_id];
+  }
+  return true;
+}
+
 bool Model::predict(const cv::Mat& im, DetResult* result) {
   result->clear();
   inputs_.clear();
@@ -288,6 +363,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
     size *= i;
     result->label_map.shape.push_back(i);
   }
+
   result->label_map.data.resize(size);
   output_label_tensor->copy_to_cpu(result->label_map.data.data());
 
@@ -299,6 +375,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
     size *= i;
     result->score_map.shape.push_back(i);
   }
+
   result->score_map.data.resize(size);
   output_score_tensor->copy_to_cpu(result->score_map.data.data());
 
@@ -325,8 +402,8 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
       inputs_.im_size_before_resize_.pop_back();
       auto padding_w = before_shape[0];
       auto padding_h = before_shape[1];
-      mask_label = mask_label(cv::Rect(0, 0, padding_w, padding_h));
-      mask_score = mask_score(cv::Rect(0, 0, padding_w, padding_h));
+      mask_label = mask_label(cv::Rect(0, 0, padding_h, padding_w));
+      mask_score = mask_score(cv::Rect(0, 0, padding_h, padding_w));
     } else if (*iter == "resize") {
       auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
       inputs_.im_size_before_resize_.pop_back();
@@ -343,7 +420,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
                  cv::Size(resize_h, resize_w),
                  0,
                  0,
-                 cv::INTER_NEAREST);
+                 cv::INTER_LINEAR); 
     }
     ++idx;
   }