
Merge pull request #1059 from will-jl944/develop_jf

Add anchor clustering
FlyingQianMM 4 years ago
commit 6c391d173b

+ 189 - 3
docs/apis/datasets.md

@@ -4,7 +4,9 @@
 
 
 * [ImageNet](#1)
 * [VOCDetection](#2)
+  * [cluster_yolo_anchor](#21)
 * [CocoDetection](#3)
+  * [cluster_yolo_anchor](#31)
 * [SegDataset](#4)
 
 ## <h2 id="1">paddlex.datasets.ImageNet</h2>
@@ -25,7 +27,7 @@ paddlex.datasets.ImageNet(data_dir, file_list, label_list, transforms=None, num_
 > > * **num_workers** (int|str): Number of processes used to preprocess the samples. Defaults to 'auto'. When set to 'auto', `num_workers` is determined from the machine's CPU core count: if half the core count is greater than 8, `num_workers` is 8; otherwise it is half the core count.  
 > > * **shuffle** (bool): Whether to shuffle the samples in the dataset. Defaults to False.  
 
-## <h2 id="1">paddlex.datasets.VOCDetection</h2>
+## <h2 id="2">paddlex.datasets.VOCDetection</h2>
 > **For object detection models**  
 ```python
 paddlex.datasets.VOCDetection(data_dir, file_list, label_list, transforms=None, num_workers='auto', shuffle=False)
@@ -45,7 +47,101 @@ paddlex.datasets.VOCDetection(data_dir, file_list, label_list, transforms=None,
 > > * **shuffle** (bool): Whether to shuffle the samples in the dataset. Defaults to False.
 > > * **allow_empty** (bool): Whether to load negative samples. Defaults to False.
 
-## <h2 id="1">paddlex.datasets.CocoDetection</h2>
+### <h3 id="21">cluster_yolo_anchor</h3>
+
+```python
+cluster_yolo_anchor(num_anchors, image_size, cache=True, cache_path=None, iters=300, gen_iters=1000, thresh=.25)
+```
+
+> Analyzes the labels of all images in the dataset and clusters anchors in the format required by YOLO-family detection models. The returned anchors are sorted from smallest to largest.
+
+> **Note**
+>
+> When customizing the `anchor` of a YOLO-family model, the `anchor_masks` parameter must be specified as well. `anchor_masks` is a two-dimensional list whose length equals the number of feature maps produced by the model backbone (2 for PPYOLO with a MobileNetV3 or ResNet18_vd backbone, 3 in all other cases; see the sketch after the example below for the 2-feature-map case). Each element is itself a list giving the indices of the anchors detected on the corresponding feature map.
+> Take the default PPYOLO parameters `anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]` and `anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]` as an example: objects at scales `[116, 90], [156, 198], [373, 326]` are detected on the first feature map, objects at scales `[30, 61], [62, 45], [59, 119]` on the second, and so on.
+
+> **Parameters**
+>
+> > * **num_anchors** (int): Number of anchors to generate. For PPYOLO, usually 6 when the backbone is MobileNetV3 or ResNet18_vd and 9 otherwise. For PPYOLOv2, PPYOLOTiny, and YOLOv3, usually 9.
+> > * **image_size** (List[int] or int): Network input size during training. If a list, its length must be 2, giving the height and width; if an int, the input height and width are the same.
+> > * **cache** (bool): Whether to use a cache. Clustering anchors requires traversing the dataset to collect the sizes of all ground-truth boxes and of all images, which is time-consuming. If True, the ground-truth box sizes and image sizes are saved under `cache_path`; if cache files already exist there, they are loaded instead. If False, nothing is saved or loaded. Defaults to True.
+> > * **cache_path** (None or str): Directory for caching the ground-truth box sizes and image sizes. If None, the dataset's `data_dir` is used. Defaults to None.
+> > * **iters** (int): Number of iterations of the k-means clustering algorithm.
+> > * **gen_iters** (int): Number of iterations of the genetic algorithm.
+> > * **thresh** (float): Threshold on the ratio between anchor size and ground-truth box size.
+
+**Code example**
+```python
+import paddlex as pdx
+from paddlex import transforms as T
+
+# Download and extract the insect detection dataset
+dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(dataset, path='./')
+
+# Define the transforms for training and evaluation
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Define the datasets used for training and evaluation
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms,
+    shuffle=False)
+
+# Cluster 9 anchors on the training set
+anchors = train_dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)
+anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+# Initialize the model and start training
+# Training metrics can be viewed with VisualDL; see https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/train/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='ResNet50_vd_dcn',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
+# Parameter descriptions and tuning guide: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)
+```
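+
+For a backbone that yields only two feature maps (PPYOLO with MobileNetV3 or ResNet18_vd, per the note above), 6 anchors split across 2 masks are used instead. A minimal sketch of that variant, reusing the dataset defined above; the backbone name and mask layout are illustrative assumptions, not defaults taken from this PR:
+
+```python
+# Cluster 6 anchors for a 2-feature-map backbone (hypothetical configuration)
+anchors = train_dataset.cluster_yolo_anchor(num_anchors=6, image_size=608)
+# Anchors are returned smallest to largest; the higher-indexed (larger)
+# anchors go to the first, coarser feature map
+anchor_masks = [[3, 4, 5], [0, 1, 2]]
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='MobileNetV3_large',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+```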
+
+## <h2 id="3">paddlex.datasets.CocoDetection</h2>
 > **For instance segmentation / object detection models**  
 ```python
 paddlex.datasets.CocoDetection(data_dir, ann_file, transforms=None, num_workers='auto', shuffle=False)
@@ -64,7 +160,97 @@ paddlex.datasets.CocoDetection(data_dir, ann_file, transforms=None, num_workers=
 > > * **shuffle** (bool): Whether to shuffle the samples in the dataset. Defaults to False.
 > > * **allow_empty** (bool): Whether to load negative samples. Defaults to False.
 
-## <h2 id="1">paddlex.datasets.SegDataset</h2>
+### <h3 id="31">cluster_yolo_anchor</h3>
+
+```python
+cluster_yolo_anchor(num_anchors, image_size, cache=True, cache_path=None, iters=300, gen_iters=1000, thresh=.25)
+```
+
+> Analyzes the labels of all images in the dataset and clusters anchors in the format required by YOLO-family detection models. The returned anchors are sorted from smallest to largest.
+
+> **Note**
+>
+> When customizing the `anchor` of a YOLO-family model, the `anchor_masks` parameter must be specified as well. `anchor_masks` is a two-dimensional list whose length equals the number of feature maps produced by the model backbone (2 for PPYOLO with a MobileNetV3 or ResNet18_vd backbone, 3 in all other cases). Each element is itself a list giving the indices of the anchors detected on the corresponding feature map.
+> Take the default PPYOLO parameters `anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]` and `anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]` as an example: objects at scales `[116, 90], [156, 198], [373, 326]` are detected on the first feature map, objects at scales `[30, 61], [62, 45], [59, 119]` on the second, and so on.
+
+> **Parameters**
+>
+> > * **num_anchors** (int): Number of anchors to generate. For PPYOLO, usually 6 when the backbone is MobileNetV3 or ResNet18_vd and 9 otherwise. For PPYOLOv2, PPYOLOTiny, and YOLOv3, usually 9.
+> > * **image_size** (List[int] or int): Network input size during training. If a list, its length must be 2, giving the height and width; if an int, the input height and width are the same.
+> > * **cache** (bool): Whether to use a cache. Clustering anchors requires traversing the dataset to collect the sizes of all ground-truth boxes and of all images, which is time-consuming. If True, the ground-truth box sizes and image sizes are saved under `cache_path`; if cache files already exist there, they are loaded instead. If False, nothing is saved or loaded. Defaults to True (see the reuse sketch after the example below).
+> > * **cache_path** (None or str): Directory for caching the ground-truth box sizes and image sizes. If None, the dataset's `data_dir` is used. Defaults to None.
+> > * **iters** (int): Number of iterations of the k-means clustering algorithm.
+> > * **gen_iters** (int): Number of iterations of the genetic algorithm.
+> > * **thresh** (float): Threshold on the ratio between anchor size and ground-truth box size.
+
+**Code example**
+```python
+import paddlex as pdx
+from paddlex import transforms as T
+
+# Download and extract the insect detection dataset
+dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(dataset, path='./')
+
+# Define the transforms for training and evaluation
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Define the datasets used for training and evaluation
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
+train_dataset = pdx.datasets.CocoDetection(
+    data_dir='xiaoduxiong_ins_det/JPEGImages',
+    ann_file='xiaoduxiong_ins_det/train.json',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.CocoDetection(
+    data_dir='xiaoduxiong_ins_det/JPEGImages',
+    ann_file='xiaoduxiong_ins_det/val.json',
+    transforms=eval_transforms)
+
+# Cluster 9 anchors on the training set
+anchors = train_dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)
+anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+# Initialize the model and start training
+# Training metrics can be viewed with VisualDL; see https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/train/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='ResNet50_vd_dcn',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
+# Parameter descriptions and tuning guide: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)
+```
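+
+Since the dataset traversal dominates the runtime, repeated clustering runs can reuse the cached box and image sizes. A small sketch; the cache directory name is an illustrative assumption:
+
+```python
+# Reuse a dedicated cache directory so later runs skip the dataset traversal
+anchors = train_dataset.cluster_yolo_anchor(
+    num_anchors=9, image_size=608, cache=True, cache_path='./anchor_cache')
+```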
+
+## <h2 id="4">paddlex.datasets.SegDataset</h2>
 > **For semantic segmentation models**  
 ```python
 paddlex.datasets.SegDataset(data_dir, file_list, label_list=None, transforms=None, num_workers='auto', shuffle=False)

+ 100 - 0
docs/apis/tools/anchor_clustering.md

@@ -0,0 +1,100 @@
+# Anchor clustering for YOLO-family models
+
+All YOLO-family models support custom anchors. The defaults we provide were clustered on the MS COCO detection dataset; clustering anchors on a custom dataset can improve model accuracy on that dataset.
+
+## YOLOAnchorCluster
+
+```python
+class paddlex.tools.YOLOAnchorCluster(num_anchors, dataset, image_size, cache, cache_path=None, iters=300, gen_iters=1000, thresh=0.25)
+```
+Analyzes the labels of all images in the dataset and clusters anchors in the format required by YOLO-family detection models. The returned anchors are sorted from smallest to largest.
+
+> **Note**
+>
+> When customizing the `anchor` of a YOLO-family model, the `anchor_masks` parameter must be specified as well. `anchor_masks` is a two-dimensional list whose length equals the number of feature maps produced by the model backbone (2 for PPYOLO with a MobileNetV3 or ResNet18_vd backbone, 3 in all other cases). Each element is itself a list giving the indices of the anchors detected on the corresponding feature map.
+> Take the default PPYOLO parameters `anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]` and `anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]` as an example: objects at scales `[116, 90], [156, 198], [373, 326]` are detected on the first feature map, objects at scales `[30, 61], [62, 45], [59, 119]` on the second, and so on.
+
+> **Parameters**
+>
+* **num_anchors** (int): Number of anchors to generate. For PPYOLO, usually 6 when the backbone is MobileNetV3 or ResNet18_vd and 9 otherwise. For PPYOLOv2, PPYOLOTiny, and YOLOv3, usually 9.
+* **dataset** (paddlex.dataset): Detection dataset used for anchor clustering; `VOCDetection` and `CocoDetection` formats are supported.
+* **image_size** (List[int] or int): Network input size during training. If a list, its length must be 2, giving the height and width; if an int, the input height and width are the same.
+* **cache** (bool): Whether to use a cache. Clustering anchors requires traversing the dataset to collect the sizes of all ground-truth boxes and of all images, which is time-consuming. If True, the ground-truth box sizes and image sizes are saved under `cache_path`; if cache files already exist there, they are loaded instead. If False, nothing is saved or loaded. Defaults to True.
+* **cache_path** (None or str): Directory for caching the ground-truth box sizes and image sizes. If None, the dataset's `data_dir` is used. Defaults to None.
+* **iters** (int): Number of iterations of the k-means clustering algorithm.
+* **gen_iters** (int): Number of iterations of the genetic algorithm.
+* **thresh** (float): Threshold on the ratio between anchor size and ground-truth box size (made concrete in the sketch below).
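+
+To make `thresh` concrete: each ground-truth box is compared to each anchor by the per-dimension ratio `min(r, 1/r)`, which is 1 for a perfect match and decays toward 0 as the shapes diverge; a box counts as covered when its best anchor scores above `thresh` (0.25 means within a factor of 4 per side). A minimal NumPy sketch of that check, mirroring `YOLOAnchorCluster.metric` from this PR; the sample values are illustrative:
+
+```python
+import numpy as np
+
+whs = np.array([[32., 48.]])                  # one ground-truth box (w, h)
+anchors = np.array([[30., 61.], [62., 45.]])  # two candidate anchors
+r = whs[:, None] / anchors[None]              # per-dimension size ratios
+x = np.minimum(r, 1. / r).min(2)              # worst-dimension score per pair
+best = x.max(1)                               # best anchor score per box
+print(best > 0.25)                            # [ True]: this box is covered
+```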
+
+**Code example**
+```python
+import paddlex as pdx
+from paddlex import transforms as T
+
+# Download and extract the insect detection dataset
+dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(dataset, path='./')
+
+# Define the transforms for training and evaluation
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Define the datasets used for training and evaluation
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms,
+    shuffle=False)
+
+# Cluster 9 anchors on the training set
+cluster = pdx.tools.YOLOAnchorCluster(num_anchors=9,
+                                      dataset=train_dataset,
+                                      image_size=608)
+anchors = cluster()
+anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+# Initialize the model and start training
+# Training metrics can be viewed with VisualDL; see https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/train/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdx.det.PPYOLO(num_classes=num_classes,
+                       backbone='ResNet50_vd_dcn',
+                       anchors=anchors,
+                       anchor_masks=anchor_masks)
+
+# API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
+# Parameter descriptions and tuning guide: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)
+```

+ 1 - 0
paddlex/cv/datasets/coco.py

@@ -59,6 +59,7 @@ class CocoDetection(VOCDetection):
             six.reraise(*sys.exc_info())
 
         super(VOCDetection, self).__init__()
+        self.data_dir = data_dir
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50

+ 40 - 0
paddlex/cv/datasets/voc.py

@@ -24,6 +24,7 @@ import xml.etree.ElementTree as ET
 from paddle.io import Dataset
 from paddlex.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 from paddlex.cv.transforms import Decode, MixupImage
+from paddlex.tools import YOLOAnchorCluster
 
 
 class VOCDetection(Dataset):
@@ -58,6 +59,7 @@ class VOCDetection(Dataset):
         matplotlib.use('Agg')
         from pycocotools.coco import COCO
         super(VOCDetection, self).__init__()
+        self.data_dir = data_dir
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
         self.num_max_boxes = 50
@@ -326,6 +328,44 @@ class VOCDetection(Dataset):
     def set_epoch(self, epoch_id):
         self._epoch = epoch_id
 
+    def cluster_yolo_anchor(self,
+                            num_anchors,
+                            image_size,
+                            cache=True,
+                            cache_path=None,
+                            iters=300,
+                            gen_iters=1000,
+                            thresh=.25):
+        """
+        Cluster YOLO anchors.
+
+        Reference:
+            https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
+
+        Args:
+            num_anchors (int): number of clusters
+            image_size (list or int): [h, w]; an int means image height and image width are the same.
+            cache (bool): whether to use the cache
+            cache_path (str or None, optional): cache directory path. If None, use `data_dir` of the dataset.
+            iters (int, optional): number of k-means iterations
+            gen_iters (int, optional): number of genetic algorithm iterations
+            thresh (float, optional): anchor scale threshold
+
+        Returns:
+            np.ndarray of shape (num_anchors, 2): [w, h] pairs sorted by area in ascending order.
+        """
+        if cache_path is None:
+            cache_path = self.data_dir
+        cluster = YOLOAnchorCluster(
+            num_anchors=num_anchors,
+            dataset=self,
+            image_size=image_size,
+            cache=cache,
+            cache_path=cache_path,
+            iters=iters,
+            gen_iters=gen_iters,
+            thresh=thresh)
+        anchors = cluster()
+        return anchors
+
     def add_negative_samples(self, image_dir, empty_ratio=1):
         """Add background images to the training set
 

+ 0 - 4
paddlex/cv/models/base.py

@@ -434,13 +434,11 @@ class BaseModel:
                             criterion='l1_norm',
                             save_dir='output'):
         """
-
         Args:
             dataset(paddlex.dataset): Dataset used for evaluation during sensitivity analysis.
             batch_size(int, optional): Batch size used in evaluation. Defaults to 8.
             criterion({'l1_norm', 'fpgm'}, optional): Pruning criterion. Defaults to 'l1_norm'.
             save_dir(str, optional): The directory to save sensitivity file of the model. Defaults to 'output'.
-
         """
         if self.__class__.__name__ in ['FasterRCNN', 'MaskRCNN']:
             raise Exception("{} does not support pruning currently!".format(
@@ -476,12 +474,10 @@ class BaseModel:
 
 
     def prune(self, pruned_flops, save_dir=None):
         """
-
         Args:
             pruned_flops(float): Ratio of FLOPs to be pruned.
             save_dir(None or str, optional): If None, the pruned model will not be saved.
                 Otherwise, the pruned model will be saved at save_dir. Defaults to None.
-
         """
         if self.status == "Pruned":
             raise Exception(

+ 0 - 6
paddlex/cv/models/detector.py

@@ -192,7 +192,6 @@ class BaseDetector(BaseModel):
             resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                 If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                 `pretrain_weights` can be set simultaneously. Defaults to None.
-
         """
         if self.status == 'Infer':
             logging.error(
@@ -340,7 +339,6 @@ class BaseDetector(BaseModel):
                 configuration will be used. Defaults to None.
             resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
                 from. If None, no training checkpoint will be resumed. Defaults to None.
-
         """
         self._prepare_qat(quant_config)
         self.train(
@@ -378,10 +376,8 @@ class BaseDetector(BaseModel):
             metric({'VOC', 'COCO', None}, optional):
                 Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
             return_details(bool, optional): Whether to return evaluation details. Defaults to False.
-
         Returns:
             collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
-
         """
 
         if metric is None:
@@ -477,7 +473,6 @@ class BaseDetector(BaseModel):
                 meaning all images to be predicted as a mini-batch.
             transforms(paddlex.transforms.Compose or None, optional):
                 Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
-
         Returns:
             If img_file is a string or np.array, the result is a list of dict with key-value pairs:
             {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
@@ -487,7 +482,6 @@ class BaseDetector(BaseModel):
             bbox(list): bounding box in [x, y, w, h] format
             score(str): confidence
             mask(dict): Only for instance segmentation task. Mask of the object in RLE format
-
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")

+ 0 - 1
paddlex/cv/models/load_model.py

@@ -50,7 +50,6 @@ def load_model(model_dir, **params):
     Load saved model from a given directory.
     Args:
         model_dir(str): The directory where the model is saved.
-
     Returns:
         The model loaded from the directory.
     """

+ 0 - 3
paddlex/deploy.py

@@ -35,7 +35,6 @@ class Predictor(object):
                  max_trt_batch_size=1,
                  trt_precision_mode='float32'):
         """ Create a Paddle Predictor
-
             Args:
                 model_dir: model directory (must be an exported deployment or quantized model)
                 use_gpu: whether to use the GPU; defaults to False
@@ -189,7 +188,6 @@ class Predictor(object):
 
 
     def raw_predict(self, inputs):
         """ Run prediction on preprocessed data
-
             Args:
                 inputs(dict): preprocessed data
         """
@@ -242,7 +240,6 @@ class Predictor(object):
                 warmup_iters=0,
                 repeats=1):
         """ Image prediction
-
             Args:
                 img_file(List[np.ndarray or str], str or np.ndarray):
                     Image path(s); or decoded array(s) in (H, W, C) layout with float32 dtype and BGR channel order.

+ 1 - 0
paddlex/tools/__init__.py

@@ -14,3 +14,4 @@
 
 
 from .split import dataset_split
 from .convert import dataset_conversion
+from .anchor_clustering import YOLOAnchorCluster

+ 15 - 0
paddlex/tools/anchor_clustering/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .yolo_cluster import YOLOAnchorCluster

+ 177 - 0
paddlex/tools/anchor_clustering/yolo_cluster.py

@@ -0,0 +1,177 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import numpy as np
+from tqdm import tqdm
+from scipy.cluster.vq import kmeans
+from paddlex.utils import logging
+
+__all__ = ['YOLOAnchorCluster']
+
+
+class BaseAnchorCluster(object):
+    def __init__(self, num_anchors, cache, cache_path):
+        """
+        Base Anchor Cluster
+        Args:
+            num_anchors (int): number of clusters
+            cache (bool): whether to use the cache
+            cache_path (str): cache directory path
+        """
+        super(BaseAnchorCluster, self).__init__()
+        self.num_anchors = num_anchors
+        self.cache_path = cache_path
+        self.cache = cache
+
+    def print_result(self, centers):
+        raise NotImplementedError('%s.print_result is not available' %
+                                  self.__class__.__name__)
+
+    def get_whs(self):
+        whs_cache_path = os.path.join(self.cache_path, 'whs.npy')
+        shapes_cache_path = os.path.join(self.cache_path, 'shapes.npy')
+        if self.cache and os.path.exists(whs_cache_path) and os.path.exists(
+                shapes_cache_path):
+            self.whs = np.load(whs_cache_path)
+            self.shapes = np.load(shapes_cache_path)
+            return self.whs, self.shapes
+        whs = np.zeros((0, 2))
+        shapes = np.zeros((0, 2))
+        samples = copy.deepcopy(self.dataset.file_list)
+        for sample in tqdm(samples):
+            im_h, im_w = sample['image_shape']
+            bbox = sample['gt_bbox']
+            wh = bbox[:, 2:4] - bbox[:, 0:2]
+            wh = wh / np.array([[im_w, im_h]])
+            shape = np.ones_like(wh) * np.array([[im_w, im_h]])
+            whs = np.vstack((whs, wh))
+            shapes = np.vstack((shapes, shape))
+
+        if self.cache:
+            os.makedirs(self.cache_path, exist_ok=True)
+            np.save(whs_cache_path, whs)
+            np.save(shapes_cache_path, shapes)
+
+        self.whs = whs
+        self.shapes = shapes
+        return self.whs, self.shapes
+
+    def calc_anchors(self):
+        raise NotImplementedError('%s.calc_anchors is not available' %
+                                  self.__class__.__name__)
+
+    def __call__(self):
+        self.get_whs()
+        centers = self.calc_anchors()
+        return centers
+
+
+class YOLOAnchorCluster(BaseAnchorCluster):
+    def __init__(self,
+                 num_anchors,
+                 dataset,
+                 image_size,
+                 cache=True,
+                 cache_path=None,
+                 iters=300,
+                 gen_iters=1000,
+                 thresh=0.25):
+        """
+        YOLOv5 Anchor Cluster
+
+        Reference:
+            https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
+
+        Args:
+            num_anchors (int): number of clusters
+            dataset (DataSet): DataSet instance, VOC or COCO
+            image_size (list or int): [h, w]; an int means image height and image width are the same.
+            cache (bool): whether to use the cache. Defaults to True.
+            cache_path (str or None, optional): cache directory path. If None, use `data_dir` of the dataset. Defaults to None.
+            iters (int, optional): number of k-means iterations. Defaults to 300.
+            gen_iters (int, optional): number of genetic algorithm iterations. Defaults to 1000.
+            thresh (float, optional): anchor scale threshold. Defaults to 0.25.
+        """
+        self.dataset = dataset
+        if cache_path is None:
+            cache_path = self.dataset.data_dir
+        if isinstance(image_size, int):
+            image_size = [image_size] * 2
+        self.image_size = image_size
+        self.iters = iters
+        self.gen_iters = gen_iters
+        self.thresh = thresh
+        super(YOLOAnchorCluster, self).__init__(num_anchors, cache, cache_path)
+
+    def print_result(self, centers):
+        whs = self.whs
+        x, best = self.metric(whs, centers)
+        bpr, aat = (best > self.thresh).mean(), (
+            x > self.thresh).mean() * self.num_anchors
+        logging.info(
+            'thresh=%.2f: %.4f best possible recall, %.2f anchors past thr' %
+            (self.thresh, bpr, aat))
+        logging.info(
+            'n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thresh=%.3f-mean: '
+            % (self.num_anchors, self.image_size, x.mean(), best.mean(),
+               x[x > self.thresh].mean()))
+        logging.info('%d anchor cluster result: [w, h]' % self.num_anchors)
+        for w, h in centers:
+            logging.info('[%d, %d]' % (w, h))
+
+    def metric(self, whs, centers):
+        # Per-dimension ratio of each box wh to each anchor wh; min(r, 1/r)
+        # is 1 for a perfect match and decays toward 0 as shapes diverge.
+        r = whs[:, None] / centers[None]
+        x = np.minimum(r, 1. / r).min(2)
+        # Return the per-pair score and the best-anchor score for each box.
+        return x, x.max(1)
+
+    def fitness(self, whs, centers):
+        # Mean best-anchor score over boxes, counting only matches above thresh.
+        _, best = self.metric(whs, centers)
+        return (best * (best > self.thresh)).mean()
+
+    def calc_anchors(self):
+        # Scale normalized box sizes to pixels at the target input resolution,
+        # resizing each image so its longest side matches the input size.
+        self.whs = self.whs * self.shapes / self.shapes.max(
+            1, keepdims=True) * np.array([self.image_size[::-1]])
+        wh0 = self.whs
+        i = (wh0 < 3.0).any(1).sum()
+        if i:
+            logging.warning('Extremely small objects found. %d of %d '
+                            'labels are < 3 pixels in width or height' %
+                            (i, len(wh0)))
+
+        # Keep boxes at least 2 pixels on a side, then run whitened k-means
+        # so that width and height contribute comparably to the distance.
+        wh = wh0[(wh0 >= 2.0).any(1)]
+        logging.info('Running kmeans for %g anchors on %g points...' %
+                     (self.num_anchors, len(wh)))
+        s = wh.std(0)
+        centers, dist = kmeans(wh / s, self.num_anchors, iter=self.iters)
+        centers *= s
+
+        # Genetic refinement: mp is the mutation probability, s the mutation
+        # scale (note: the name s is reused here for the noise sigma).
+        f, sh, mp, s = self.fitness(wh, centers), centers.shape, 0.9, 0.1
+        pbar = tqdm(
+            range(self.gen_iters),
+            desc='Evolving anchors with Genetic Algorithm')
+        for _ in pbar:
+            v = np.ones(sh)
+            while (v == 1).all():
+                v = ((np.random.random(sh) < mp) * np.random.random() *
+                     np.random.randn(*sh) * s + 1).clip(0.3, 3.0)
+            new_centers = (centers.copy() * v).clip(min=2.0)
+            new_f = self.fitness(wh, new_centers)
+            if new_f > f:
+                f, centers = new_f, new_centers.copy()
+                pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
+
+        centers = np.round(centers[np.argsort(centers.prod(1))]).astype(int)
+        return centers