Browse Source

Merge pull request #1059 from will-jl944/develop_jf

Add anchor clustering
FlyingQianMM 4 years ago
parent
commit
6c391d173b

+ 189 - 3
docs/apis/datasets.md

@@ -4,7 +4,9 @@
 
 * [ImageNet](#1)
 * [VOCDetection](#2)
+  * [cluster_yolo_anchor](#21)
 * [CocoDetection](#3)
+  * [cluster_yolo_anchor](#31)
 * [SegDataset](#4)
 
 ## <h2 id="1">paddlex.datasets.ImageNet</h2>
@@ -25,7 +27,7 @@ paddlex.datasets.ImageNet(data_dir, file_list, label_list, transforms=None, num_
 > > * **num_workers** (int|str):数据集中样本在预处理过程中的进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。  
 > > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 
-## <h2 id="1">paddlex.datasets.VOCDetection</h2>
+## <h2 id="2">paddlex.datasets.VOCDetection</h2>
 > **用于目标检测模型**  
 ```python
 paddlex.datasets.VOCDetection(data_dir, file_list, label_list, transforms=None, num_workers='auto', shuffle=False)
@@ -45,7 +47,101 @@ paddlex.datasets.VOCDetection(data_dir, file_list, label_list, transforms=None,
 > > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。
 > > * **allow_empty** (bool): 是否加载负样本。默认为False。
 
-## <h2 id="1">paddlex.datasets.CocoDetection</h2>
+### <h3 id="21">cluster_yolo_anchor</h3>
+
+```python
+cluster_yolo_anchor(num_anchors, image_size, cache=True, cache_path=None, iters=300, gen_iters=1000, thresh=.25)
+```
+
+> 分析数据集中所有图像的标签,聚类生成YOLO系列检测模型指定格式的anchor,返回结果按照由小到大排列。
+
+> **注解**
+>
+> 自定义YOLO系列模型的`anchor`需要同时指定`anchor_masks`参数。`anchor_masks`参数为一个二维的列表,其长度等于模型backbone获取到的特征图数量(对于PPYOLO的MobileNetV3和ResNet18_vd,特征图数量为2,其余情况为3)。列表中的每一个元素也为列表,代表对应特征图上所检测的anchor编号。
+> 以PPYOLO网络的默认参数`anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]`,`anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]`为例,代表在第一个特征图上检测尺度为`[116, 90], [156, 198], [373, 326]`的目标,在第二个特征图上检测尺度为`[30, 61], [62, 45], [59, 119]`的目标,以此类推。
+
+> **参数**
+>
+> > * **num_anchors** (int): 生成anchor的数量。PPYOLO,当backbone网络为MobileNetV3或ResNet18_vd时通常设置为6,其余情况通常设置为9。对于PPYOLOv2、PPYOLOTiny、YOLOv3,通常设置为9。
+> > * **image_size** (List[int] or int):训练时网络输入的尺寸。如果为list,长度须为2,分别代表高和宽;如果为int,代表输入尺寸高和宽相同。
+> > * **cache** (bool): 是否使用缓存。聚类生成anchor需要遍历数据集统计所有真值框的尺寸以及所有图片的尺寸,较为耗时。如果为True,会将真值框尺寸信息以及图片尺寸信息保存至`cache_path`路径下,若路径下已存缓存文件,则加载该缓存。如果为False,则不会保存或加载。默认为True。
+> > * **cache_path** (None or str):真值框尺寸信息以及图片尺寸信息缓存路径。 如果为None,则使用数据集所在的路径`data_dir`。默认为None。
+> > * **iters** (int):K-Means聚类算法迭代次数。
+> > * **gen_iters** (int):基因演算法迭代次数。
+> > * **thresh** (float):anchor尺寸与真值框尺寸之间比例的阈值。
+
+**代码示例**
+```python
+import paddlex as pdx
+from paddlex import transforms as T
+
+# 下载和解压昆虫检测数据集
+dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(dataset, path='./')
+
+# 定义训练和验证时的transforms
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms,
+    shuffle=False)
+
+# 在训练集上聚类生成9个anchor
+anchors = train_dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)
+anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/train/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='ResNet50_vd_dcn',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
+# 各参数介绍与调整说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)
+```
+
+## <h2 id="3">paddlex.datasets.CocoDetection</h2>
 > **用于实例分割/目标检测模型**  
 ```python
 paddlex.datasets.CocoDetection(data_dir, ann_file, transforms=None, num_workers='auto', shuffle=False)
@@ -64,7 +160,97 @@ paddlex.datasets.CocoDetection(data_dir, ann_file, transforms=None, num_workers=
 > > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。
 > > * **allow_empty** (bool): 是否加载负样本。默认为False。
 
-## <h2 id="1">paddlex.datasets.SegDataset</h2>
+### <h3 id="31">cluster_yolo_anchor</h3>
+
+```python
+cluster_yolo_anchor(num_anchors, image_size, cache=True, cache_path=None, iters=300, gen_iters=1000, thresh=.25)
+```
+
+> 分析数据集中所有图像的标签,聚类生成YOLO系列检测模型指定格式的anchor,返回结果按照由小到大排列。
+
+> **注解**
+>
+> 自定义YOLO系列模型的`anchor`需要同时指定`anchor_masks`参数。`anchor_masks`参数为一个二维的列表,其长度等于模型backbone获取到的特征图数量(对于PPYOLO的MobileNetV3和ResNet18_vd,特征图数量为2,其余情况为3)。列表中的每一个元素也为列表,代表对应特征图上所检测的anchor编号。
+> 以PPYOLO网络的默认参数`anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]`,`anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]`为例,代表在第一个特征图上检测尺度为`[116, 90], [156, 198], [373, 326]`的目标,在第二个特征图上检测尺度为`[30, 61], [62, 45], [59, 119]`的目标,以此类推。
+
+> **参数**
+>
+> > * **num_anchors** (int): 生成anchor的数量。PPYOLO,当backbone网络为MobileNetV3或ResNet18_vd时通常设置为6,其余情况通常设置为9。对于PPYOLOv2、PPYOLOTiny、YOLOv3,通常设置为9。
+> > * **image_size** (List[int] or int):训练时网络输入的尺寸。如果为list,长度须为2,分别代表高和宽;如果为int,代表输入尺寸高和宽相同。
+> > * **cache** (bool): 是否使用缓存。聚类生成anchor需要遍历数据集统计所有真值框的尺寸以及所有图片的尺寸,较为耗时。如果为True,会将真值框尺寸信息以及图片尺寸信息保存至`cache_path`路径下,若路径下已存缓存文件,则加载该缓存。如果为False,则不会保存或加载。默认为True。
+> > * **cache_path** (None or str):真值框尺寸信息以及图片尺寸信息缓存路径。 如果为None,则使用数据集所在的路径`data_dir`。默认为None。
+> > * **iters** (int):K-Means聚类算法迭代次数。
+> > * **gen_iters** (int):基因演算法迭代次数。
+> > * **thresh** (float):anchor尺寸与真值框尺寸之间比例的阈值。
+
+**代码示例**
+```python
+import paddlex as pdx
+from paddlex import transforms as T
+
+# 下载和解压昆虫检测数据集
+dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(dataset, path='./')
+
+# 定义训练和验证时的transforms
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
+train_dataset = pdx.datasets.CocoDetection(
+    data_dir='xiaoduxiong_ins_det/JPEGImages',
+    ann_file='xiaoduxiong_ins_det/train.json',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.CocoDetection(
+    data_dir='xiaoduxiong_ins_det/JPEGImages',
+    ann_file='xiaoduxiong_ins_det/val.json',
+    transforms=eval_transforms)
+
+# 在训练集上聚类生成9个anchor
+anchors = train_dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)
+anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/train/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='ResNet50_vd_dcn',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
+# 各参数介绍与调整说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)
+```
+
+## <h2 id="4">paddlex.datasets.SegDataset</h2>
 > **用于语义分割模型**  
 ```python
 paddlex.datasets.SegDataset(data_dir, file_list, label_list=None, transforms=None, num_workers='auto', shuffle=False)

+ 100 - 0
docs/apis/tools/anchor_clustering.md

@@ -0,0 +1,100 @@
+# YOLO系列模型anchor聚类
+
+YOLO系列模型均支持自定义anchor,我们提供的默认配置为在MS COCO检测数据集上聚类生成的anchor。用户可以使用在自定义数据集上聚类生成的anchor以提升模型在特定数据集上的的精度。
+
+## YOLOAnchorCluster
+
+```python
+class paddlex.tools.YOLOAnchorCluster(num_anchors, dataset, image_size, cache, cache_path=None, iters=300, gen_iters=1000, thresh=0.25)
+```
+分析数据集中所有图像的标签,聚类生成YOLO系列检测模型指定格式的anchor,返回结果按照由小到大排列。
+
+> **注解**
+>
+> 自定义YOLO系列模型的`anchor`需要同时指定`anchor_masks`参数。`anchor_masks`参数为一个二维的列表,其长度等于模型backbone获取到的特征图数量(对于PPYOLO的MobileNetV3和ResNet18_vd,特征图数量为2,其余情况为3)。列表中的每一个元素也为列表,代表对应特征图上所检测的anchor编号。
+> 以PPYOLO网络的默认参数`anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]`,`anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]`为例,代表在第一个特征图上检测尺度为`[116, 90], [156, 198], [373, 326]`的目标,在第二个特征图上检测尺度为`[30, 61], [62, 45], [59, 119]`的目标,以此类推。
+
+> **参数**
+>
+* **num_anchors** (int): 生成anchor的数量。PPYOLO,当backbone网络为MobileNetV3或ResNet18_vd时通常设置为6,其余情况通常设置为9。对于PPYOLOv2、PPYOLOTiny、YOLOv3,通常设置为9。
+* **dataset** (paddlex.dataset):用于聚类生成anchor的检测数据集,支持`VOCDetection`和`CocoDetection`格式。
+* **image_size** (List[int] or int):训练时网络输入的尺寸。如果为list,长度须为2,分别代表高和宽;如果为int,代表输入尺寸高和宽相同。
+* **cache** (bool): 是否使用缓存。聚类生成anchor需要遍历数据集统计所有真值框的尺寸以及所有图片的尺寸,较为耗时。如果为True,会将真值框尺寸信息以及图片尺寸信息保存至`cache_path`路径下,若路径下已存缓存文件,则加载该缓存。如果为False,则不会保存或加载。默认为True。
+* **cache_path** (None or str):真值框尺寸信息以及图片尺寸信息缓存路径。 如果为None,则使用数据集所在的路径`data_dir`。默认为None。
+* **iters** (int):K-Means聚类算法迭代次数。
+* **gen_iters** (int):基因演算法迭代次数。
+* **thresh** (float):anchor尺寸与真值框尺寸之间比例的阈值。
+
+**代码示例**
+```python
+import paddlex as pdx
+from paddlex import transforms as T
+
+# 下载和解压昆虫检测数据集
+dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(dataset, path='./')
+
+# 定义训练和验证时的transforms
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms,
+    shuffle=False)
+
+# 在训练集上聚类生成9个anchor
+cluster = pdx.tools.YOLOAnchorCluster(num_anchors=9,
+                                      dataset=train_dataset,
+                                      image_size=608)
+anchors = cluster()
+anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/train/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='ResNet50_vd_dcn',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+
+# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
+# 各参数介绍与调整说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)
+```

+ 1 - 0
paddlex/cv/datasets/coco.py

@@ -59,6 +59,7 @@ class CocoDetection(VOCDetection):
             six.reraise(*sys.exc_info())
 
         super(VOCDetection, self).__init__()
+        self.data_dir = data_dir
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50

+ 40 - 0
paddlex/cv/datasets/voc.py

@@ -24,6 +24,7 @@ import xml.etree.ElementTree as ET
 from paddle.io import Dataset
 from paddlex.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 from paddlex.cv.transforms import Decode, MixupImage
+from paddlex.tools import YOLOAnchorCluster
 
 
 class VOCDetection(Dataset):
@@ -58,6 +59,7 @@ class VOCDetection(Dataset):
         matplotlib.use('Agg')
         from pycocotools.coco import COCO
         super(VOCDetection, self).__init__()
+        self.data_dir = data_dir
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50
@@ -326,6 +328,44 @@ class VOCDetection(Dataset):
     def set_epoch(self, epoch_id):
         self._epoch = epoch_id
 
+    def cluster_yolo_anchor(self,
+                            num_anchors,
+                            image_size,
+                            cache=True,
+                            cache_path=None,
+                            iters=300,
+                            gen_iters=1000,
+                            thresh=.25):
+        """
+        Cluster YOLO anchors.
+
+        Reference:
+            https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
+
+        Args:
+            num_anchors (int): number of clusters
+            image_size (list or int): [h, w], being an int means image height and image width are the same.
+            cache (bool): whether using cache
+            cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset.
+            iters (int, optional): iters of kmeans algorithm
+            gen_iters (int, optional): iters of genetic algorithm
+            threshold (float, optional): anchor scale threshold
+            verbose (bool, optional): whether print results
+        """
+        if cache_path is None:
+            cache_path = self.data_dir
+        cluster = YOLOAnchorCluster(
+            num_anchors=num_anchors,
+            dataset=self,
+            image_size=image_size,
+            cache=cache,
+            cache_path=cache_path,
+            iters=iters,
+            gen_iters=gen_iters,
+            thresh=thresh)
+        anchors = cluster()
+        return anchors
+
     def add_negative_samples(self, image_dir, empty_ratio=1):
         """将背景图片加入训练
 

+ 0 - 4
paddlex/cv/models/base.py

@@ -434,13 +434,11 @@ class BaseModel:
                             criterion='l1_norm',
                             save_dir='output'):
         """
-
         Args:
             dataset(paddlex.dataset): Dataset used for evaluation during sensitivity analysis.
             batch_size(int, optional): Batch size used in evaluation. Defaults to 8.
             criterion({'l1_norm', 'fpgm'}, optional): Pruning criterion. Defaults to 'l1_norm'.
             save_dir(str, optional): The directory to save sensitivity file of the model. Defaults to 'output'.
-
         """
         if self.__class__.__name__ in ['FasterRCNN', 'MaskRCNN']:
             raise Exception("{} does not support pruning currently!".format(
@@ -476,12 +474,10 @@ class BaseModel:
 
     def prune(self, pruned_flops, save_dir=None):
         """
-
         Args:
             pruned_flops(float): Ratio of FLOPs to be pruned.
             save_dir(None or str, optional): If None, the pruned model will not be saved.
                 Otherwise, the pruned model will be saved at save_dir. Defaults to None.
-
         """
         if self.status == "Pruned":
             raise Exception(

+ 0 - 6
paddlex/cv/models/detector.py

@@ -192,7 +192,6 @@ class BaseDetector(BaseModel):
             resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                 If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                 `pretrain_weights` can be set simultaneously. Defaults to None.
-
         """
         if self.status == 'Infer':
             logging.error(
@@ -340,7 +339,6 @@ class BaseDetector(BaseModel):
                 configuration will be used. Defaults to None.
             resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
                 from. If None, no training checkpoint will be resumed. Defaults to None.
-
         """
         self._prepare_qat(quant_config)
         self.train(
@@ -378,10 +376,8 @@ class BaseDetector(BaseModel):
             metric({'VOC', 'COCO', None}, optional):
                 Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
             return_details(bool, optional): Whether to return evaluation details. Defaults to False.
-
         Returns:
             collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
-
         """
 
         if metric is None:
@@ -477,7 +473,6 @@ class BaseDetector(BaseModel):
                 meaning all images to be predicted as a mini-batch.
             transforms(paddlex.transforms.Compose or None, optional):
                 Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
-
         Returns:
             If img_file is a string or np.array, the result is a list of dict with key-value pairs:
             {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
@@ -487,7 +482,6 @@ class BaseDetector(BaseModel):
             bbox(list): bounding box in [x, y, w, h] format
             score(str): confidence
             mask(dict): Only for instance segmentation task. Mask of the object in RLE format
-
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")

+ 0 - 1
paddlex/cv/models/load_model.py

@@ -50,7 +50,6 @@ def load_model(model_dir, **params):
     Load saved model from a given directory.
     Args:
         model_dir(str): The directory where the model is saved.
-
     Returns:
         The model loaded from the directory.
     """

+ 0 - 3
paddlex/deploy.py

@@ -35,7 +35,6 @@ class Predictor(object):
                  max_trt_batch_size=1,
                  trt_precision_mode='float32'):
         """ 创建Paddle Predictor
-
             Args:
                 model_dir: 模型路径(必须是导出的部署或量化模型)
                 use_gpu: 是否使用gpu,默认False
@@ -189,7 +188,6 @@ class Predictor(object):
 
     def raw_predict(self, inputs):
         """ 接受预处理过后的数据进行预测
-
             Args:
                 inputs(dict): 预处理过后的数据
         """
@@ -242,7 +240,6 @@ class Predictor(object):
                 warmup_iters=0,
                 repeats=1):
         """ 图片预测
-
             Args:
                 img_file(List[np.ndarray or str], str or np.ndarray):
                     图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。

+ 1 - 0
paddlex/tools/__init__.py

@@ -14,3 +14,4 @@
 
 from .split import dataset_split
 from .convert import dataset_conversion
+from .anchor_clustering import YOLOAnchorCluster

+ 15 - 0
paddlex/tools/anchor_clustering/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .yolo_cluster import YOLOAnchorCluster

+ 177 - 0
paddlex/tools/anchor_clustering/yolo_cluster.py

@@ -0,0 +1,177 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import numpy as np
+from tqdm import tqdm
+from scipy.cluster.vq import kmeans
+from paddlex.utils import logging
+
+__all__ = ['YOLOAnchorCluster']
+
+
+class BaseAnchorCluster(object):
+    def __init__(self, num_anchors, cache, cache_path):
+        """
+        Base Anchor Cluster
+        Args:
+            num_anchors (int): number of clusters
+            cache (bool): whether using cache
+            cache_path (str): cache directory path
+        """
+        super(BaseAnchorCluster, self).__init__()
+        self.num_anchors = num_anchors
+        self.cache_path = cache_path
+        self.cache = cache
+
+    def print_result(self, centers):
+        raise NotImplementedError('%s.print_result is not available' %
+                                  self.__class__.__name__)
+
+    def get_whs(self):
+        whs_cache_path = os.path.join(self.cache_path, 'whs.npy')
+        shapes_cache_path = os.path.join(self.cache_path, 'shapes.npy')
+        if self.cache and os.path.exists(whs_cache_path) and os.path.exists(
+                shapes_cache_path):
+            self.whs = np.load(whs_cache_path)
+            self.shapes = np.load(shapes_cache_path)
+            return self.whs, self.shapes
+        whs = np.zeros((0, 2))
+        shapes = np.zeros((0, 2))
+        samples = copy.deepcopy(self.dataset.file_list)
+        for sample in tqdm(samples):
+            im_h, im_w = sample['image_shape']
+            bbox = sample['gt_bbox']
+            wh = bbox[:, 2:4] - bbox[:, 0:2]
+            wh = wh / np.array([[im_w, im_h]])
+            shape = np.ones_like(wh) * np.array([[im_w, im_h]])
+            whs = np.vstack((whs, wh))
+            shapes = np.vstack((shapes, shape))
+
+        if self.cache:
+            os.makedirs(self.cache_path, exist_ok=True)
+            np.save(whs_cache_path, whs)
+            np.save(shapes_cache_path, shapes)
+
+        self.whs = whs
+        self.shapes = shapes
+        return self.whs, self.shapes
+
+    def calc_anchors(self):
+        raise NotImplementedError('%s.calc_anchors is not available' %
+                                  self.__class__.__name__)
+
+    def __call__(self):
+        self.get_whs()
+        centers = self.calc_anchors()
+        return centers
+
+
+class YOLOAnchorCluster(BaseAnchorCluster):
+    def __init__(self,
+                 num_anchors,
+                 dataset,
+                 image_size,
+                 cache=True,
+                 cache_path=None,
+                 iters=300,
+                 gen_iters=1000,
+                 thresh=0.25):
+        """
+        YOLOv5 Anchor Cluster
+
+        Reference:
+            https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
+
+        Args:
+            num_anchors (int): number of clusters
+            dataset (DataSet): DataSet instance, VOC or COCO
+            image_size (list or int): [h, w], being an int means image height and image width are the same.
+            cache (bool): whether using cache。 Defaults to True.
+            cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset. Defaults to None.
+            iters (int, optional): iters of kmeans algorithm. Defaults to 300.
+            gen_iters (int, optional): iters of genetic algorithm. Defaults to 1000.
+            thresh (float, optional): anchor scale threshold. Defaults to 0.25.
+        """
+        self.dataset = dataset
+        if cache_path is None:
+            cache_path = self.dataset.data_dir
+        if isinstance(image_size, int):
+            image_size = [image_size] * 2
+        self.image_size = image_size
+        self.iters = iters
+        self.gen_iters = gen_iters
+        self.thresh = thresh
+        super(YOLOAnchorCluster, self).__init__(num_anchors, cache, cache_path)
+
+    def print_result(self, centers):
+        whs = self.whs
+        x, best = self.metric(whs, centers)
+        bpr, aat = (best > self.thresh).mean(), (
+            x > self.thresh).mean() * self.num_anchors
+        logging.info(
+            'thresh=%.2f: %.4f best possible recall, %.2f anchors past thr' %
+            (self.thresh, bpr, aat))
+        logging.info(
+            'n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thresh=%.3f-mean: '
+            % (self.num_anchors, self.image_size, x.mean(), best.mean(),
+               x[x > self.thresh].mean()))
+        logging.info('%d anchor cluster result: [w, h]' % self.num_anchors)
+        for w, h in centers:
+            logging.info('[%d, %d]' % (w, h))
+
+    def metric(self, whs, centers):
+        r = whs[:, None] / centers[None]
+        x = np.minimum(r, 1. / r).min(2)
+        return x, x.max(1)
+
+    def fitness(self, whs, centers):
+        _, best = self.metric(whs, centers)
+        return (best * (best > self.thresh)).mean()
+
+    def calc_anchors(self):
+        self.whs = self.whs * self.shapes / self.shapes.max(
+            1, keepdims=True) * np.array([self.image_size[::-1]])
+        wh0 = self.whs
+        i = (wh0 < 3.0).any(1).sum()
+        if i:
+            logging.warning('Extremely small objects found. %d of %d '
+                            'labels are < 3 pixels in width or height' %
+                            (i, len(wh0)))
+
+        wh = wh0[(wh0 >= 2.0).any(1)]
+        logging.info('Running kmeans for %g anchors on %g points...' %
+                     (self.num_anchors, len(wh)))
+        s = wh.std(0)
+        centers, dist = kmeans(wh / s, self.num_anchors, iter=self.iters)
+        centers *= s
+
+        f, sh, mp, s = self.fitness(wh, centers), centers.shape, 0.9, 0.1
+        pbar = tqdm(
+            range(self.gen_iters),
+            desc='Evolving anchors with Genetic Algorithm')
+        for _ in pbar:
+            v = np.ones(sh)
+            while (v == 1).all():
+                v = ((np.random.random(sh) < mp) * np.random.random() *
+                     np.random.randn(*sh) * s + 1).clip(0.3, 3.0)
+            new_centers = (centers.copy() * v).clip(min=2.0)
+            new_f = self.fitness(wh, new_centers)
+            if new_f > f:
+                f, centers = new_f, new_centers.copy()
+                pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
+
+        centers = np.round(centers[np.argsort(centers.prod(1))]).astype(int)
+        return centers