Bladeren bron

update faiss api

gaotingquan 1 jaar geleden
bovenliggende
commit
942524ce63

+ 54 - 24
docs/pipeline_usage/tutorials/cv_pipelines/general_image_recognition.md

@@ -107,16 +107,17 @@ PaddleX 所提供的预训练的模型产线均可以快速体验效果,你可
 
 #### 2.2.2 Python脚本方式集成
 
-* 在该产线的运行示例中需要预先构建特征向量库,您可以下载官方提供的饮料识别测试数据集[drink_dataset_v2.0](https://paddle-model-ecology.bj.bcebos.com/paddlex/data/drink_dataset_v2.0.tar) 构建特征向量库。若您希望用私有数据集,可以参考[2.3节 构建特征库的数据组织方式](#23-构建特征库的数据组织方式)。之后通过几行代码即可完成建立特征向量库和通用图像识别产线的快速推理。
+* 在该产线的运行示例中需要预先构建索引库,您可以下载官方提供的饮料识别测试数据集[drink_dataset_v2.0](https://paddle-model-ecology.bj.bcebos.com/paddlex/data/drink_dataset_v2.0.tar) 构建索引库。若您希望用私有数据集,可以参考[2.3节 构建索引库的数据组织方式](#23-构建索引库的数据组织方式)。之后通过几行代码即可完成建立索引库和通用图像识别产线的快速推理。
 
 ```python
 from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="PP-ShiTuV2")
 
-pipeline.build_index(data_root="drink_dataset_v2.0/", index_dir="index_dir")
+index_data = pipeline.build_index(data_root="drink_dataset_v2.0/", label_path="drink_dataset_v2.0/gallery.txt")
+index_data.save("drink.index")
 
-output = pipeline.predict("./drink_dataset_v2.0/test_images/", index_dir="index_dir")
+output = pipeline.predict("./drink_dataset_v2.0/test_images/", index=index_data)
 for res in output:
     res.print()
     res.save_to_img("./output/")
@@ -143,8 +144,8 @@ for res in output:
 <td>无</td>
 </tr>
 <tr>
-<td><code>index_dir</code></td>
-<td>产线推理预测所用的检索库文件所在的目录,如不传入该参数,则需要在<code>predict()</code>中指定<code>index_dir</code>。</td>
+<td><code>index</code></td>
+<td>产线推理预测所用的索引库文件路径,如不传入该参数,则需要在<code>predict()</code>中指定<code>index</code>。</td>
 <td><code>str</code></td>
 <td>None</td>
 </tr>
@@ -162,7 +163,7 @@ for res in output:
 </tr>
 </tbody>
 </table>
-(2)调用通用图像识别产线对象的 `build_index` 方法,构建特征向量库。具体参数为说明如下:
+(2)调用通用图像识别产线对象的 `build_index` 方法,构建索引库。具体参数为说明如下:
 
 <table>
 <thead>
@@ -176,18 +177,40 @@ for res in output:
 <tbody>
 <tr>
 <td><code>data_root</code></td>
-<td>数据集的根目录,数据组织方式参考<a href="#2.3-构建特征库的数据组织方式">2.3节 构建特征库的数据组织方式</a></td>
+<td>数据集的根目录,数据组织方式参考<a href="#2.3-构建索引库的数据组织方式">2.3节 构建索引库的数据组织方式</a></td>
 <td><code>str</code></td>
 <td>无</td>
 </tr>
 <tr>
-<td><code>index_dir</code></td>
-<td>特征库的保存路径。成功调用<code>build_index</code>方法后会在改路径下生成两个文件: <code>"id_map.pkl"</code> 保存了图像ID与图像特征标签之间的映射关系;<code>“vector.index”</code>存储了每张图像的特征向量</td>
+<td><code>label_path</code></td>
+<td>数据标注文件路径,数据组织方式参考<a href="#2.3-构建索引库的数据组织方式">2.3节 构建索引库的数据组织方式</a></td>
 <td><code>str</code></td>
 <td>无</td>
 </tr>
 </tbody>
 </table>
+
+索引库对象 `index` 支持 `save` 方法,用于将索引库保存到磁盘:
+
+<table>
+<thead>
+<tr>
+<th>参数</th>
+<th>参数说明</th>
+<th>参数类型</th>
+<th>默认值</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td><code>save_path</code></td>
+<td>索引库的保存路径,如<code>drink.index</code>。</td>
+<td><code>str</code></td>
+<td>无</td>
+</tr>
+</tbody>
+</table>
+
 (3)调用通用图像识别产线对象的 `predict` 方法进行推理预测:`predict` 方法参数为 `input`,用于输入待预测数据,支持多种输入方式,具体示例如下:
 
 <table>
@@ -224,7 +247,7 @@ for res in output:
 </tr>
 </tbody>
 </table>
-另外,`predict`方法支持参数`index_dir`用于设置索库:
+另外,`predict`方法支持参数`index`用于设置索库:
 <table>
 <thead>
 <tr>
@@ -234,8 +257,8 @@ for res in output:
 </thead>
 <tbody>
 <tr>
-<td><code>index_dir</code></td>
-<td>产线推理预测所用的检索库文件所在的目录,如不传入该参数,则默认使用在<code>create_pipeline()</code>中通过参数<code>index_dir</code>指定的索库。</td>
+<td><code>index</code></td>
+<td>产线推理预测所用的索引库文件路径或是索引库对象,如不传入该参数,则默认使用在<code>create_pipeline()</code>中通过参数<code>index</code>指定的索库。</td>
 </tr>
 </tbody>
 </table>
@@ -275,7 +298,7 @@ for res in output:
 
 ```python
 from paddlex import create_pipeline
-pipeline = create_pipeline(pipeline="./my_path/PP-ShiTuV2.yaml", index_dir="index_dir")
+pipeline = create_pipeline(pipeline="./my_path/PP-ShiTuV2.yaml", index="drink.index")
 
 output = pipeline.predict("./drink_dataset_v2.0/test_images/")
 for res in output:
@@ -284,17 +307,18 @@ for res in output:
 ```
 
 
-#### 2.2.3 特征库的添加和删除操作
+#### 2.2.3 索引库的添加和删除操作
 
-若您希望将更多的图像添加到特征库中,则可以调用 `append_index` 方法;删除图像特征,则可以调用 `remove_index` 方法。
+若您希望将更多的图像添加到索引库中,则可以调用 `append_index` 方法;删除图像特征,则可以调用 `remove_index` 方法。
 
 ```python
 from paddlex import create_pipeline
 
 pipeline = create_pipeline("PP-ShiTuV2")
-pipeline.build_index(data_root="drink_dataset_v2.0/", index_dir="index_dir", index_type="IVF")
-pipeline.append_index(data_root="drink_dataset_v2.0/", index_dir="index_dir", index_type="IVF")
-pipeline.remove_index(data_root="drink_dataset_v2.0/", index_dir="index_dir", index_type="IVF")
+index_data = pipeline.build_index(data_root="drink_dataset_v2.0/", label_path="drink_dataset_v2.0/gallery.txt", index="drink.index", index_type="IVF")
+index_data = pipeline.append_index(data_root="drink_dataset_v2.0/", label_path="drink_dataset_v2.0/gallery.txt", index="drink.index", index_type="IVF")
+index_data = pipeline.remove_index(data_root="drink_dataset_v2.0/", label_path="drink_dataset_v2.0/gallery.txt", index="drink.index", index_type="IVF")
+index_data.save("drink.index")
 ```
 
 上述方法参数说明如下:
@@ -310,13 +334,19 @@ pipeline.remove_index(data_root="drink_dataset_v2.0/", index_dir="index_dir", in
 <tbody>
 <tr>
 <td><code>data_root</code></td>
-<td>要添加的数据集的根目录。数据组织方式与构建特征库时相同,参考<a href="#2.3-构建特征库的数据组织方式">2.3节 构建特征库的数据组织方式</a></td>
+<td>要添加的数据集的根目录。数据组织方式与构建索引库时相同,参考<a href="#2.3-构建索引库的数据组织方式">2.3节 构建索引库的数据组织方式</a></td>
+<td><code>str</code></td>
+<td>无</td>
+</tr>
+<tr>
+<td><code>label_path</code></td>
+<td>要添加的数据集标注文件的路径。数据组织方式与构建索引库时相同,参考<a href="#2.3-构建索引库的数据组织方式">2.3节 构建索引库的数据组织方式</a></td>
 <td><code>str</code></td>
 <td>无</td>
 </tr>
 <tr>
-<td><code>index_dir</code></td>
-<td>特征库的存储目录,在 <code>append_index</code> 和 <code>remove_index</code> 中,同时也是被修改(或删除)的特征库的路径,</td>
+<td><code>index</code></td>
+<td>索引库文件的路径,或是索引库对象,仅在 <code>append_index</code> 和 <code>remove_index</code> 中有效,表示待修改的索引库。</td>
 <td><code>str</code></td>
 <td>无</td>
 </tr>
@@ -334,15 +364,15 @@ pipeline.remove_index(data_root="drink_dataset_v2.0/", index_dir="index_dir", in
 </tr>
 </tbody>
 </table>
-### 2.3 构建特征库的数据组织方式
+### 2.3 构建索引库的数据组织方式
 
-PaddleX 的通用图像识别产线示例需要使用预先构建好的特征库进行特征检索。如果您希望用私有数据构建特征向量库,则需要按照如下方式组织数据:
+PaddleX 的通用图像识别产线示例需要使用预先构建好的索引库进行特征检索。如果您希望用私有数据构建索引库,则需要按照如下方式组织数据:
 
 ```bash
 data_root             # 数据集根目录,目录名称可以改变
 ├── images            # 图像的保存目录,目录名称可以改变
 │   │   ...
-└── gallery.txt       # 特征库数据集标注文件,文件名称不可改变。每行给出待检索图像路径和图像标签,使用空格分隔,内容举例: “0/0.jpg 脉动”
+└── gallery.txt       # 索引库数据集标注文件,文件名称不可改变。每行给出待检索图像路径和图像标签,使用空格分隔,内容举例: “0/0.jpg 脉动”
 ```
 
 ## 3. 开发集成/部署

+ 137 - 148
paddlex/inference/components/retrieval/faiss.py

@@ -22,6 +22,47 @@ from ....utils import logging
 from ..base import BaseComponent
 
 
+class IndexData:
+    def __init__(self, index, id_map):
+        self._index = index
+        self._id_map = id_map
+
+    @property
+    def index(self):
+        return self._index
+
+    @property
+    def index_bytes(self):
+        return faiss.serialize_index(self._index)
+
+    @property
+    def id_map(self):
+        return self._id_map
+
+    def save(self, save_path):
+        index_data = {
+            "index_bytes": self.index_bytes,
+            "id_map": self.id_map,
+        }
+        with open(save_path, "wb") as fd:
+            pickle.dump(index_data, fd)
+
+    @classmethod
+    def load(self, index):
+        if isinstance(index, str):
+            with open(index, "rb") as fd:
+                index_data = pickle.load(fd)
+            index_ = faiss.deserialize_index(index_data["index_bytes"])
+            id_map = index_data["id_map"]
+            assert index_.ntotal == len(
+                id_map
+            ), "data number in index is not equal in in id_map"
+            return index_, id_map
+        else:
+            assert isinstance(index, IndexData)
+            return index.index, index.id_map
+
+
 class FaissIndexer(BaseComponent):
 
     INPUT_KEYS = "feature"
@@ -33,9 +74,7 @@ class FaissIndexer(BaseComponent):
 
     def __init__(
         self,
-        index_bytes=None,
-        vector_path=None,
-        id_map=None,
+        index,
         metric_type="IP",
         return_k=1,
         score_thres=None,
@@ -44,19 +83,11 @@ class FaissIndexer(BaseComponent):
         super().__init__()
 
         if metric_type == "hamming":
-            if index_bytes is not None:
-                self._indexer = faiss.deserialize_index(index_bytes)
-            else:
-                self._indexer = faiss.read_index_binary(vector_path)
             self.hamming_radius = hamming_radius
         else:
-            if index_bytes is not None:
-                self._indexer = faiss.deserialize_index(index_bytes)
-            else:
-                self._indexer = faiss.read_index(vector_path)
             self.score_thres = score_thres
 
-        self.id_map = id_map
+        self._indexer, self.id_map = IndexData.load(index)
         self.metric_type = metric_type
         self.return_k = return_k
 
@@ -82,192 +113,155 @@ class FaissIndexer(BaseComponent):
 
 class FaissBuilder:
 
-    SUPPORT_MODE = ("new", "remove", "append")
     SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
     SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
     BINARY_METRIC_TYPE = ("hamming",)
     BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
 
-    def __init__(self, predict, mode="new", index_type="HNSW32", metric_type="IP"):
-        super().__init__()
-        assert (
-            mode in self.SUPPORT_MODE
-        ), f"Supported modes only: {self.SUPPORT_MODE}. But received {mode}!"
-        assert (
-            metric_type in self.SUPPORT_METRIC_TYPE
-        ), f"Supported metric types only: {self.SUPPORT_METRIC_TYPE}!"
-        assert (
-            index_type in self.SUPPORT_INDEX_TYPE
-        ), f"Supported index types only: {self.SUPPORT_INDEX_TYPE}!"
-
-        self._predict = predict
-        self._mode = mode
-        self._metric_type = metric_type
-        self._index_type = index_type
-
-    def _get_index_type(self, num=None):
+    @classmethod
+    def _get_index_type(cls, metric_type, index_type, num=None):
         # if IVF method, cal ivf number automaticlly
-        if self._index_type == "IVF":
-            index_type = self._index_type + str(min(int(num // 8), 65536))
-            if self._metric_type in self.BINARY_METRIC_TYPE:
+        if index_type == "IVF":
+            index_type = index_type + str(min(int(num // 8), 65536))
+            if metric_type in cls.BINARY_METRIC_TYPE:
                 index_type += ",BFlat"
             else:
                 index_type += ",Flat"
 
         # for binary index, add B at head of index_type
-        if self._metric_type in self.BINARY_METRIC_TYPE:
+        if metric_type in cls.BINARY_METRIC_TYPE:
             assert (
-                self._index_type in self.BINARY_SUPPORT_INDEX_TYPE
-            ), f"The metric type({self._metric_type}) only support {self.BINARY_SUPPORT_INDEX_TYPE} index types!"
+                index_type in cls.BINARY_SUPPORT_INDEX_TYPE
+            ), f"The metric type({metric_type}) only support {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
             index_type = "B" + index_type
 
-        if self._index_type == "HNSW32":
+        if index_type == "HNSW32":
             logging.warning("The HNSW32 method dose not support 'remove' operation")
             index_type = "HNSW32"
 
-        if self._index_type == "Flat":
+        if index_type == "Flat":
             index_type = "Flat"
 
         return index_type
 
-    def _get_metric_type(self):
-        if self._metric_type == "hamming":
+    @classmethod
+    def _get_metric_type(cls, metric_type):
+        if metric_type == "hamming":
             return faiss.METRIC_Hamming
-        elif self._metric_type == "jaccard":
+        elif metric_type == "jaccard":
             return faiss.METRIC_Jaccard
-        elif self._metric_type == "IP":
+        elif metric_type == "IP":
             return faiss.METRIC_INNER_PRODUCT
-        elif self._metric_type == "L2":
+        elif metric_type == "L2":
             return faiss.METRIC_L2
 
+    @classmethod
     def build(
-        self,
-        label_file,
-        image_root,
-        index_dir=None,
+        cls,
+        gallery_imgs,
+        gallery_label,
+        predict_func,
+        metric_type="IP",
+        index_type="HNSW32",
     ):
-        file_list, gallery_docs = get_file_list(label_file, image_root)
+        assert (
+            metric_type in cls.SUPPORT_METRIC_TYPE
+        ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
 
-        features = [res["feature"] for res in self._predict(file_list)]
-        dtype = np.uint8 if self._metric_type in self.BINARY_METRIC_TYPE else np.float32
+        gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
+
+        features = [res["feature"] for res in predict_func(gallery_list)]
+        dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
         features = np.array(features).astype(dtype)
         vector_num, vector_dim = features.shape
 
-        if self._metric_type in self.BINARY_METRIC_TYPE:
+        if metric_type in cls.BINARY_METRIC_TYPE:
             index = faiss.index_binary_factory(
                 vector_dim,
-                self._get_index_type(vector_num),
-                self._get_metric_type(),
+                cls._get_index_type(metric_type, index_type, vector_num),
+                cls._get_metric_type(metric_type),
             )
         else:
             index = faiss.index_factory(
                 vector_dim,
-                self._get_index_type(vector_num),
-                self._get_metric_type(),
+                cls._get_index_type(metric_type, index_type, vector_num),
+                cls._get_metric_type(metric_type),
             )
             index = faiss.IndexIDMap2(index)
         ids = {}
 
         # calculate id for new data
-        index, ids = self._add_gallery(index, ids, features, gallery_docs)
-        if index_dir:
-            self._save_gallery(index, ids, index_dir)
-        return faiss.serialize_index(index), ids
+        index, ids = cls._add_gallery(
+            metric_type, index, ids, features, gallery_docs, mode="new"
+        )
+        return IndexData(index, ids)
 
+    @classmethod
     def remove(
-        self,
-        label_file,
-        image_root,
-        index_dir=None,
-        index_bytes=None,
-        vector_path=None,
-        id_map=None,
+        cls,
+        gallery_imgs,
+        gallery_label,
+        index,
+        index_type="HNSW32",
     ):
-        file_list, gallery_docs = get_file_list(label_file, image_root)
+        assert (
+            index_type in cls.SUPPORT_INDEX_TYPE
+        ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
 
-        if index_bytes is not None:
-            index = faiss.deserialize_index(index_bytes)
-            ids = id_map
-        else:
-            # load vector.index and id_map.pkl
-            index, ids = self._load_index(index_dir)
+        gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
+        index, ids = IndexData.load(index)
 
-        if self._index_type == "HNSW32":
+        if index_type == "HNSW32":
             raise RuntimeError(
                 "The index_type: HNSW32 dose not support 'remove' operation"
             )
 
         # remove ids in id_map, remove index data in faiss index
-        index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
-        if index_dir:
-            self._save_gallery(index, ids, index_dir)
-        return faiss.serialize_index(index), ids
+        index, ids = cls._rm_id_in_gallery(index, ids, gallery_docs)
+        return IndexData(index, ids)
 
-    def append(
-        self,
-        label_file,
-        image_root,
-        index_dir=None,
-        index_bytes=None,
-        vector_path=None,
-        id_map=None,
-    ):
-        file_list, gallery_docs = get_file_list(label_file, image_root)
-        features = [res["feature"] for res in self._predict(file_list)]
-        dtype = np.uint8 if self._metric_type in self.BINARY_METRIC_TYPE else np.float32
+    @classmethod
+    def append(cls, gallery_imgs, gallery_label, predict_func, index, metric_type="IP"):
+        assert (
+            metric_type in cls.SUPPORT_METRIC_TYPE
+        ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
+
+        gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
+        features = [res["feature"] for res in predict_func(gallery_list)]
+        dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
         features = np.array(features).astype(dtype)
 
-        if index_bytes is not None:
-            index = faiss.deserialize_index(index_bytes)
-            ids = id_map
-        else:
-            # load vector.index and id_map.pkl
-            index, ids = self._load_index(index_dir)
+        index, ids = IndexData.load(index)
 
         # calculate id for new data
-        index, ids = self._add_gallery(index, ids, features, gallery_docs)
-        if index_dir:
-            self._save_gallery(index, ids, index_dir)
-        return faiss.serialize_index(index), ids
-
-    def _load_index(self, index_dir):
-        assert os.path.join(
-            index_dir, "vector.index"
-        ), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
-            index_dir
-        )
-        assert os.path.join(
-            index_dir, "id_map.pkl"
-        ), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
-            index_dir
+        index, ids = cls._add_gallery(
+            metric_type, index, ids, features, gallery_docs, mode="append"
         )
-        index = faiss.read_index(os.path.join(index_dir, "vector.index"))
-        with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
-            ids = pickle.load(fd)
-        assert index.ntotal == len(
-            ids.keys()
-        ), "data number in index is not equal in in id_map"
-        return index, ids
+        return IndexData(index, ids)
 
-    def _add_gallery(self, index, ids, gallery_features, gallery_docs):
+    @classmethod
+    def _add_gallery(
+        cls, metric_type, index, ids, gallery_features, gallery_docs, mode
+    ):
         start_id = max(ids.keys()) + 1 if ids else 0
         ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
 
         # only train when new index file
-        if self._mode == "new":
-            if self._metric_type in self.BINARY_METRIC_TYPE:
+        if mode == "new":
+            if metric_type in cls.BINARY_METRIC_TYPE:
                 index.add(gallery_features)
             else:
                 index.train(gallery_features)
 
-        if not self._metric_type in self.BINARY_METRIC_TYPE:
+        if not metric_type in cls.BINARY_METRIC_TYPE:
             index.add_with_ids(gallery_features, ids_now)
 
         for i, d in zip(list(ids_now), gallery_docs):
             ids[i] = d
         return index, ids
 
-    def _rm_id_in_galllery(self, index, ids, gallery_docs):
+    @classmethod
+    def _rm_id_in_gallery(cls, index, ids, gallery_docs):
         remove_ids = list(filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
         remove_ids = np.asarray(remove_ids)
         index.remove_ids(remove_ids)
@@ -276,28 +270,23 @@ class FaissBuilder:
 
         return index, ids
 
-    def _save_gallery(self, index, ids, index_dir):
-        Path(index_dir).mkdir(parents=True, exist_ok=True)
-        if self._metric_type in self.BINARY_METRIC_TYPE:
-            faiss.write_index_binary(index, os.path.join(index_dir, "vector.index"))
+    @classmethod
+    def get_gallery(cls, gallery_imgs, gallery_label, delimiter=" "):
+        if isinstance(gallery_label, str):
+            assert isinstance(gallery_imgs, str)
+            gallery_imgs = Path(gallery_imgs)
+            files = []
+            labels = []
+            lines = []
+            with open(gallery_label, "r", encoding="utf-8") as f:
+                lines = f.readlines()
+            for line in lines:
+                path, label = line.strip().split(delimiter)
+                file_path = gallery_imgs / path
+                files.append(file_path.as_posix())
+                labels.append(label)
+            return files, labels
         else:
-            faiss.write_index(index, os.path.join(index_dir, "vector.index"))
-
-        with open(os.path.join(index_dir, "id_map.pkl"), "wb") as fd:
-            pickle.dump(ids, fd)
-
-
-def get_file_list(data_file, root_dir, delimiter=" "):
-    root_dir = Path(root_dir)
-    files = []
-    labels = []
-    lines = []
-    with open(data_file, "r", encoding="utf-8") as f:
-        lines = f.readlines()
-    for line in lines:
-        path, label = line.strip().split(delimiter)
-        file_path = root_dir / path
-        files.append(file_path.as_posix())
-        labels.append(label)
-
-    return files, labels
+            assert isinstance(gallery_imgs, list)
+            assert isinstance(gallery_label, list)
+            return gallery_imgs, gallery_label

+ 37 - 85
paddlex/inference/pipelines/pp_shitu_v2.py

@@ -34,7 +34,7 @@ class ShiTuV2Pipeline(BasePipeline):
         rec_model,
         det_batch_size=1,
         rec_batch_size=1,
-        index_dir=None,
+        index=None,
         metric_type="IP",
         score_thres=None,
         hamming_radius=None,
@@ -51,32 +51,16 @@ class ShiTuV2Pipeline(BasePipeline):
             score_thres,
             hamming_radius,
         )
-        self._indexer = self._build_indexer(index_dir=index_dir) if index_dir else None
+        self._indexer = self._build_indexer(index=index) if index else None
 
-    def _build_indexer(self, index_bytes=None, id_map=None, index_dir=None):
-        if index_bytes is not None and id_map is not None:
-            return FaissIndexer(
-                index_bytes=index_bytes,
-                id_map=id_map,
-                metric_type=self._metric_type,
-                return_k=self._return_k,
-                score_thres=self._score_thres,
-                hamming_radius=self._hamming_radius,
-            )
-        else:
-            assert index_dir
-            vector_path = (Path(index_dir) / "vector.index").as_posix()
-            with open(Path(index_dir) / "id_map.pkl", "rb") as fd:
-                id_map = pickle.load(fd)
-
-            return FaissIndexer(
-                vector_path=vector_path,
-                id_map=id_map,
-                metric_type=self._metric_type,
-                return_k=self._return_k,
-                score_thres=self._score_thres,
-                hamming_radius=self._hamming_radius,
-            )
+    def _build_indexer(self, index):
+        return FaissIndexer(
+            index=index,
+            metric_type=self._metric_type,
+            return_k=self._return_k,
+            score_thres=self._score_thres,
+            hamming_radius=self._hamming_radius,
+        )
 
     def _build_predictor(self, det_model, rec_model):
         self.det_model = self._create(model=det_model)
@@ -93,13 +77,8 @@ class ShiTuV2Pipeline(BasePipeline):
             self.det_model.set_predictor(device=device)
             self.rec_model.set_predictor(device=device)
 
-    def predict(self, input, index_bytes=None, id_map=None, index_dir=None, **kwargs):
-        if index_bytes is not None or index_dir is not None:
-            indexer = self._build_indexer(
-                index_bytes=index_bytes, id_map=id_map, index_dir=index_dir
-            )
-        else:
-            indexer = self._indexer
+    def predict(self, input, index=None, **kwargs):
+        indexer = self._build_indexer(index) if index is not None else self._indexer
         assert indexer
         self.set_predictor(**kwargs)
         for det_res in self.det_model(input):
@@ -143,72 +122,45 @@ class ShiTuV2Pipeline(BasePipeline):
             )
         return ShiTuResult(single_img_res)
 
-    def _build_index(
+    def build_index(
         self,
-        data_root,
-        index_dir=None,
-        mode="new",
+        gallery_imgs,
+        gallery_label,
         metric_type="IP",
         index_type="HNSW32",
-        **kwargs,
+        **kwargs
     ):
-        self.set_predictor(**kwargs)
-        self._metric_type = metric_type if metric_type else self._metric_type
-        builder = FaissBuilder(
+        return FaissBuilder.build(
+            gallery_imgs,
+            gallery_label,
             self.rec_model.predict,
-            mode=mode,
-            metric_type=self._metric_type,
-            index_type=index_type,
-        )
-        if mode == "new":
-            index_bytes, id_map = builder.build(
-                Path(data_root) / "gallery.txt", data_root, index_dir
-            )
-        elif mode == "remove":
-            index_bytes, id_map = builder.remove(
-                Path(data_root) / "gallery.txt", data_root, index_dir
-            )
-        elif mode == "append":
-            index_bytes, id_map = builder.append(
-                Path(data_root) / "gallery.txt", data_root, index_dir
-            )
-        else:
-            raise Exception("`mode` only support `new`, `remove` and `append`.")
-
-        return index_bytes, id_map
-
-    def build_index(
-        self, data_root, index_dir=None, metric_type="IP", index_type="HNSW32", **kwargs
-    ):
-        return self._build_index(
-            data_root=data_root,
-            index_dir=index_dir,
-            mode="new",
             metric_type=metric_type,
             index_type=index_type,
-            **kwargs,
+            **kwargs
         )
 
     def remove_index(
-        self, data_root, index_dir=None, metric_type="IP", index_type="HNSW32", **kwargs
+        self, gallery_imgs, gallery_label, index, index_type="HNSW32", **kwargs
     ):
-        return self._build_index(
-            data_root=data_root,
-            index_dir=index_dir,
-            mode="remove",
-            metric_type=metric_type,
-            index_type=index_type,
-            **kwargs,
+        return FaissBuilder.remove(
+            gallery_imgs, gallery_label, index, index_type=index_type, **kwargs
         )
 
     def append_index(
-        self, data_root, index_dir=None, metric_type="IP", index_type="HNSW32", **kwargs
+        self,
+        gallery_imgs,
+        gallery_label,
+        index,
+        id_map=None,
+        metric_type="IP",
+        index_type="HNSW32",
+        **kwargs
     ):
-        return self._build_index(
-            data_root=data_root,
-            index_dir=index_dir,
-            mode="append",
+        return FaissBuilder.append(
+            gallery_imgs,
+            gallery_label,
+            self.rec_model.predict,
+            index,
             metric_type=metric_type,
-            index_type=index_type,
-            **kwargs,
+            **kwargs
         )

+ 1 - 1
paddlex/pipelines/PP-ShiTuV2.yaml

@@ -8,6 +8,6 @@ Pipeline:
   det_batch_size: 1
   rec_batch_size: 1
   device: gpu
-  index_dir: "./drink_dataset_v2.0/index/"
+  index: None
   score_thres: 0.5
   return_k: 5