FlyingQianMM пре 5 година
родитељ
комит
cb477ebd12

+ 5 - 1
docs/apis/models/classification.md

@@ -35,7 +35,7 @@ train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, s
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
-> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 > > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
@@ -186,3 +186,7 @@ paddlex.cls.DenseNet161(num_classes=1000)
 paddlex.cls.DenseNet201(num_classes=1000)
 ```
 
+### HRNet_W18
+```python
+paddlex.cls.HRNet_W18(num_classes=1000)
+```

+ 3 - 3
docs/apis/models/detection.md

@@ -9,7 +9,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 > 构建YOLOv3检测器。**注意在YOLOv3,num_classes不需要包含背景类,如目标包括human、dog两种,则num_classes设为2即可,这里与FasterRCNN/MaskRCNN有差别**
 
 > **参数**
-> 
+>
 > > - **num_classes** (int): 类别数。默认为80。
 > > - **backbone** (str): YOLOv3的backbone网络,取值范围为['DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large']。默认为'MobileNetV1'。
 > > - **anchors** (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
@@ -53,7 +53,7 @@ train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, sa
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
-> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 > > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
@@ -107,7 +107,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 > **参数**
 
 > > - **num_classes** (int): 包含了背景类的类别数。默认为81。
-> > - **backbone** (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd']。默认为'ResNet50'。
+> > - **backbone** (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']。默认为'ResNet50'。
 > > - **with_fpn** (bool): 是否使用FPN结构。默认为True。
 > > - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
 > > - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。

+ 1 - 1
docs/apis/models/instance_segmentation.md

@@ -12,7 +12,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 > **参数**
 
 > > - **num_classes** (int): 包含了背景类的类别数。默认为81。
-> > - **backbone** (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd']。默认为'ResNet50'。
+> > - **backbone** (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']。默认为'ResNet50'。
 > > - **with_fpn** (bool): 是否使用FPN结构。默认为True。
 > > - **aspect_ratios** (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
 > > - **anchor_sizes** (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。

+ 87 - 2
docs/apis/models/semantic_segmentation.md

@@ -47,7 +47,7 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
-> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 > > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
@@ -124,7 +124,7 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev
 > > - **save_interval_epochs** (int): 模型保存间隔(单位:迭代轮数)。默认为1。
 > > - **log_interval_steps** (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
 > > - **save_dir** (str): 模型保存路径。默认'output'
-> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在COCO图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认'COCO'。
+> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'COCO',则自动下载在COCO图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认'COCO'。
 > > - **optimizer** (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认的优化器:使用fluid.optimizer.Momentum优化方法,polynomial的学习率衰减策略。
 > > - **learning_rate** (float): 默认优化器的初始学习率。默认0.01。
 > > - **lr_decay_power** (float): 默认优化器学习率衰减指数。默认0.9。
@@ -173,3 +173,88 @@ predict(self, im_file, transforms=None):
 > **返回值**
 > >
 > > - **dict**: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)。
+
+
+## HRNet类
+
+```python
+paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
+```
+
+> 构建HRNet分割器。
+
+> **参数**
+
+> > - **num_classes** (int): 类别数。
+> > - **width** (int): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64]。
+> > - **use_bce_loss** (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
+> > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
+> > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
+> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
+
+### train 训练接口
+
+```python
+train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
+```
+
+> HRNet模型训练接口。
+
+> **参数**
+> >
+> > - **num_epochs** (int): 训练迭代轮数。
+> > - **train_dataset** (paddlex.datasets): 训练数据读取器。
+> > - **train_batch_size** (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。
+> > - **eval_dataset** (paddlex.datasets): 评估数据读取器。
+> > - **save_interval_epochs** (int): 模型保存间隔(单位:迭代轮数)。默认为1。
+> > - **log_interval_steps** (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
+> > - **save_dir** (str): 模型保存路径。默认'output'
+> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在ImageNet数据集上预训练的模型权重;若为None,则不使用预训练模型。默认'IMAGENET'。
+> > - **optimizer** (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认的优化器:使用fluid.optimizer.Momentum优化方法,polynomial的学习率衰减策略。
+> > - **learning_rate** (float): 默认优化器的初始学习率。默认0.01。
+> > - **lr_decay_power** (float): 默认优化器学习率衰减指数。默认0.9。
+> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
+> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
+> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
+
+#### evaluate 评估接口
+
+```
+evaluate(self, eval_dataset, batch_size=1, epoch_id=None, return_details=False):
+```
+
+> HRNet模型评估接口。
+
+> **参数**
+> >
+> > - **eval_dataset** (paddlex.datasets): 评估数据读取器。
+> > - **batch_size** (int): 评估时的batch大小。默认1。
+> > - **epoch_id** (int): 当前评估模型所在的训练轮数。
+> > - **return_details** (bool): 是否返回详细信息。默认False。
+
+> **返回值**
+> >
+> > - **dict**: 当return_details为False时,返回dict。包含关键字:'miou'、'category_iou'、'macc'、
+> >   'category_acc'和'kappa',分别表示平均iou、各类别iou、平均准确率、各类别准确率和kappa系数。
+> > - **tuple** (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
+> >   包含关键字:'confusion_matrix',表示评估的混淆矩阵。
+
+#### predict 预测接口
+
+```
+predict(self, im_file, transforms=None):
+```
+
+> HRNet模型预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`UNet.test_transforms`和`UNet.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`predict`接口时,用户需要再重新定义test_transforms传入给`predict`接口。
+
+> **参数**
+> >
+> > - **img_file** (str): 预测图像路径。
+> > - **transforms** (paddlex.seg.transforms): 数据预处理操作。
+
+> **返回值**
+> >
+> > - **dict**: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)。

+ 2 - 1
docs/appendix/model_zoo.md

@@ -27,6 +27,7 @@
 | DenseNet161|116.3MB  | 8.863       | 78.6     | 94.1     |
 | DenseNet201|  84.6MB   | 8.173       | 77.6     | 93.7     |
 | ShuffleNetV2 | 9.0MB   | 10.941        | 68.8     | 88.5     |
+| HRNet_W18 | 21.29MB | 7.368 (V100 GPU) | 76.9 | 93.4 |
 
 ## 目标检测模型
 
@@ -41,6 +42,7 @@
 |FasterRCNN-ResNet50_vd-FPN|168.7MB | 45.773 | 38.9 |
 |FasterRCNN-ResNet101-FPN| 251.7MB | 55.782 | 38.7 |
 |FasterRCNN-ResNet101_vd-FPN |252MB | 58.785 | 40.5 |
+|FasterRCNN-HRNet_W18-FPN |115.5MB | 57.11 | 36 |
 |YOLOv3-DarkNet53|252.4MB | 21.944 | 38.9 |
 |YOLOv3-MobileNetv1 |101.2MB | 12.771 | 29.3 |
 |YOLOv3-MobileNetv3|94.6MB | - | 31.6 |
@@ -49,4 +51,3 @@
 ## 实例分割模型
 
 > 表中模型相关指标均为在MSCOCO数据集上测试得到。
-

+ 2 - 2
paddlex/cv/models/base.py

@@ -31,8 +31,6 @@ from collections import OrderedDict
 from os import path as osp
 from paddle.fluid.framework import Program
 from .utils.pretrain_weights import get_pretrain_weights
-fluid.default_startup_program().random_seed = 1000
-fluid.default_main_program().random_seed = 1000
 
 
 def dict2str(dict_input):
@@ -200,6 +198,8 @@ class BaseAPI:
                 backbone = self.backbone
             else:
                 backbone = self.__class__.__name__
+                if backbone == "HRNet":
+                    backbone = backbone + "_W{}".format(self.width)
             pretrain_weights = get_pretrain_weights(
                 pretrain_weights, self.model_type, backbone, pretrain_dir)
         if startup_prog is None:

+ 3 - 1
paddlex/cv/models/mask_rcnn.py

@@ -195,7 +195,9 @@ class MaskRCNN(FasterRCNN):
         # 构建训练、验证、测试网络
         self.build_program()
         fuse_bn = True
-        if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']:
+        if self.with_fpn and self.backbone in [
+                'ResNet18', 'ResNet50', 'HRNet_W18'
+        ]:
             fuse_bn = False
         self.net_initialize(
             startup_prog=fluid.default_startup_program(),

+ 13 - 1
paddlex/cv/models/utils/pretrain_weights.py

@@ -58,6 +58,18 @@ image_pretrain = {
     'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
     'HRNet_W18':
     'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
+    'HRNet_W30':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
+    'HRNet_W32':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
+    'HRNet_W40':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
+    'HRNet_W48':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
+    'HRNet_W60':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
+    'HRNet_W64':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
 }
 
 coco_pretrain = {
@@ -87,7 +99,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
                 backbone = 'DetResNet50'
         assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
             backbone)
-        if backbone == "HRNet_W18":
+        if backbone.startswith("HRNet"):
             url = image_pretrain[backbone]
             fname = osp.split(url)[-1].split('.')[0]
             paddlex.utils.download_and_decompress(url, path=new_save_dir)

+ 7 - 6
paddlex/cv/nets/darknet.py

@@ -68,13 +68,14 @@ class DarkNet(object):
             bias_attr=False)
 
         bn_name = name + ".bn"
-
+        if self.num_classes:
+            regularizer = None
+        else:
+            regularizer = L2Decay(float(self.norm_decay))
         bn_param_attr = ParamAttr(
-            regularizer=L2Decay(float(self.norm_decay)),
-            name=bn_name + '.scale')
+            regularizer=regularizer, name=bn_name + '.scale')
         bn_bias_attr = ParamAttr(
-            regularizer=L2Decay(float(self.norm_decay)),
-            name=bn_name + '.offset')
+            regularizer=regularizer, name=bn_name + '.offset')
 
         out = fluid.layers.batch_norm(
             input=conv,
@@ -182,4 +183,4 @@ class DarkNet(object):
                 bias_attr=ParamAttr(name='fc_offset'))
             return out
 
-        return blocks
+        return blocks

+ 1 - 1
paddlex/cv/nets/densenet.py

@@ -173,4 +173,4 @@ class DenseNet(object):
             bn_ac_conv = fluid.layers.dropout(
                 x=bn_ac_conv, dropout_prob=dropout)
         bn_ac_conv = fluid.layers.concat([input, bn_ac_conv], axis=1)
-        return bn_ac_conv
+        return bn_ac_conv

+ 1 - 2
paddlex/cv/nets/detection/mask_rcnn.py

@@ -204,8 +204,7 @@ class MaskRCNN(object):
                 bg_thresh_hi=self.bg_thresh_hi,
                 bg_thresh_lo=self.bg_thresh_lo,
                 bbox_reg_weights=self.bbox_reg_weights,
-                class_nums=self.num_classes,
-                use_random=self.rpn_head.use_random)
+                class_nums=self.num_classes)
 
             rois = outputs[0]
             labels_int32 = outputs[1]

+ 13 - 9
paddlex/cv/nets/mobilenet_v1.py

@@ -79,10 +79,14 @@ class MobileNetV1(object):
 
         bn_name = name + "_bn"
         norm_decay = self.norm_decay
+        if self.num_classes:
+            regularizer = None
+        else:
+            regularizer = L2Decay(norm_decay)
         bn_param_attr = ParamAttr(
-            regularizer=L2Decay(norm_decay), name=bn_name + '_scale')
+            regularizer=regularizer, name=bn_name + '_scale')
         bn_bias_attr = ParamAttr(
-            regularizer=L2Decay(norm_decay), name=bn_name + '_offset')
+            regularizer=regularizer, name=bn_name + '_offset')
         return fluid.layers.batch_norm(
             input=conv,
             act=act,
@@ -189,12 +193,12 @@ class MobileNetV1(object):
         if self.num_classes:
             out = fluid.layers.pool2d(
                 input=out, pool_type='avg', global_pooling=True)
-            output = fluid.layers.fc(
-                input=out,
-                size=self.num_classes,
-                param_attr=ParamAttr(
-                    initializer=fluid.initializer.MSRA(), name="fc7_weights"),
-                bias_attr=ParamAttr(name="fc7_offset"))
+            output = fluid.layers.fc(input=out,
+                                     size=self.num_classes,
+                                     param_attr=ParamAttr(
+                                         initializer=fluid.initializer.MSRA(),
+                                         name="fc7_weights"),
+                                     bias_attr=ParamAttr(name="fc7_offset"))
             return output
 
         if not self.with_extra_blocks:
@@ -213,4 +217,4 @@ class MobileNetV1(object):
         module17 = self._extra_block(module16, num_filters[3][0],
                                      num_filters[3][1], 1, 2,
                                      self.prefix_name + "conv7_4")
-        return module11, module13, module14, module15, module16, module17
+        return module11, module13, module14, module15, module16, module17

+ 129 - 113
paddlex/cv/nets/mobilenet_v3.py

@@ -31,6 +31,7 @@ class MobileNetV3():
         with_extra_blocks (bool): if extra blocks should be added.
         extra_block_filters (list): number of filter for each extra block.
     """
+
     def __init__(self,
                  scale=1.0,
                  model_name='small',
@@ -113,29 +114,36 @@ class MobileNetV3():
         lr_idx = self.curr_stage // self.lr_interval
         lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
         lr_mult = self.lr_mult_list[lr_idx]
-        conv_param_attr = ParamAttr(name=name + '_weights',
-                                    learning_rate=lr_mult,
-                                    regularizer=L2Decay(self.conv_decay))
-        conv = fluid.layers.conv2d(input=input,
-                                   num_filters=num_filters,
-                                   filter_size=filter_size,
-                                   stride=stride,
-                                   padding=padding,
-                                   groups=num_groups,
-                                   act=None,
-                                   use_cudnn=use_cudnn,
-                                   param_attr=conv_param_attr,
-                                   bias_attr=False)
+        if self.num_classes:
+            regularizer = None
+        else:
+            regularizer = L2Decay(self.conv_decay)
+        conv_param_attr = ParamAttr(
+            name=name + '_weights',
+            learning_rate=lr_mult,
+            regularizer=regularizer)
+        conv = fluid.layers.conv2d(
+            input=input,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=padding,
+            groups=num_groups,
+            act=None,
+            use_cudnn=use_cudnn,
+            param_attr=conv_param_attr,
+            bias_attr=False)
         bn_name = name + '_bn'
-        bn_param_attr = ParamAttr(name=bn_name + "_scale",
-                                  regularizer=L2Decay(self.norm_decay))
-        bn_bias_attr = ParamAttr(name=bn_name + "_offset",
-                                 regularizer=L2Decay(self.norm_decay))
-        bn = fluid.layers.batch_norm(input=conv,
-                                     param_attr=bn_param_attr,
-                                     bias_attr=bn_bias_attr,
-                                     moving_mean_name=bn_name + '_mean',
-                                     moving_variance_name=bn_name + '_variance')
+        bn_param_attr = ParamAttr(
+            name=bn_name + "_scale", regularizer=L2Decay(self.norm_decay))
+        bn_bias_attr = ParamAttr(
+            name=bn_name + "_offset", regularizer=L2Decay(self.norm_decay))
+        bn = fluid.layers.batch_norm(
+            input=conv,
+            param_attr=bn_param_attr,
+            bias_attr=bn_bias_attr,
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
         if if_act:
             if act == 'relu':
                 bn = fluid.layers.relu(bn)
@@ -152,12 +160,10 @@ class MobileNetV3():
         lr_idx = self.curr_stage // self.lr_interval
         lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
         lr_mult = self.lr_mult_list[lr_idx]
-        
+
         num_mid_filter = int(num_out_filter // ratio)
-        pool = fluid.layers.pool2d(input=input,
-                                   pool_type='avg',
-                                   global_pooling=True,
-                                   use_cudnn=False)
+        pool = fluid.layers.pool2d(
+            input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
         conv1 = fluid.layers.conv2d(
             input=pool,
             filter_size=1,
@@ -191,43 +197,46 @@ class MobileNetV3():
                        use_se=False,
                        name=None):
         input_data = input
-        conv0 = self._conv_bn_layer(input=input,
-                                    filter_size=1,
-                                    num_filters=num_mid_filter,
-                                    stride=1,
-                                    padding=0,
-                                    if_act=True,
-                                    act=act,
-                                    name=name + '_expand')
+        conv0 = self._conv_bn_layer(
+            input=input,
+            filter_size=1,
+            num_filters=num_mid_filter,
+            stride=1,
+            padding=0,
+            if_act=True,
+            act=act,
+            name=name + '_expand')
         if self.block_stride == 16 and stride == 2:
             self.end_points.append(conv0)
-        conv1 = self._conv_bn_layer(input=conv0,
-                                    filter_size=filter_size,
-                                    num_filters=num_mid_filter,
-                                    stride=stride,
-                                    padding=int((filter_size - 1) // 2),
-                                    if_act=True,
-                                    act=act,
-                                    num_groups=num_mid_filter,
-                                    use_cudnn=False,
-                                    name=name + '_depthwise')
+        conv1 = self._conv_bn_layer(
+            input=conv0,
+            filter_size=filter_size,
+            num_filters=num_mid_filter,
+            stride=stride,
+            padding=int((filter_size - 1) // 2),
+            if_act=True,
+            act=act,
+            num_groups=num_mid_filter,
+            use_cudnn=False,
+            name=name + '_depthwise')
 
         if use_se:
-            conv1 = self._se_block(input=conv1,
-                                   num_out_filter=num_mid_filter,
-                                   name=name + '_se')
+            conv1 = self._se_block(
+                input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
 
-        conv2 = self._conv_bn_layer(input=conv1,
-                                    filter_size=1,
-                                    num_filters=num_out_filter,
-                                    stride=1,
-                                    padding=0,
-                                    if_act=False,
-                                    name=name + '_linear')
+        conv2 = self._conv_bn_layer(
+            input=conv1,
+            filter_size=1,
+            num_filters=num_out_filter,
+            stride=1,
+            padding=0,
+            if_act=False,
+            name=name + '_linear')
         if num_in_filter != num_out_filter or stride != 1:
             return conv2
         else:
-            return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None)
+            return fluid.layers.elementwise_add(
+                x=input_data, y=conv2, act=None)
 
     def _extra_block_dw(self,
                         input,
@@ -235,29 +244,32 @@ class MobileNetV3():
                         num_filters2,
                         stride,
                         name=None):
-        pointwise_conv = self._conv_bn_layer(input=input,
-                                             filter_size=1,
-                                             num_filters=int(num_filters1),
-                                             stride=1,
-                                             padding="SAME",
-                                             act='relu6',
-                                             name=name + "_extra1")
-        depthwise_conv = self._conv_bn_layer(input=pointwise_conv,
-                                             filter_size=3,
-                                             num_filters=int(num_filters2),
-                                             stride=stride,
-                                             padding="SAME",
-                                             num_groups=int(num_filters1),
-                                             act='relu6',
-                                             use_cudnn=False,
-                                             name=name + "_extra2_dw")
-        normal_conv = self._conv_bn_layer(input=depthwise_conv,
-                                          filter_size=1,
-                                          num_filters=int(num_filters2),
-                                          stride=1,
-                                          padding="SAME",
-                                          act='relu6',
-                                          name=name + "_extra2_sep")
+        pointwise_conv = self._conv_bn_layer(
+            input=input,
+            filter_size=1,
+            num_filters=int(num_filters1),
+            stride=1,
+            padding="SAME",
+            act='relu6',
+            name=name + "_extra1")
+        depthwise_conv = self._conv_bn_layer(
+            input=pointwise_conv,
+            filter_size=3,
+            num_filters=int(num_filters2),
+            stride=stride,
+            padding="SAME",
+            num_groups=int(num_filters1),
+            act='relu6',
+            use_cudnn=False,
+            name=name + "_extra2_dw")
+        normal_conv = self._conv_bn_layer(
+            input=depthwise_conv,
+            filter_size=1,
+            num_filters=int(num_filters2),
+            stride=1,
+            padding="SAME",
+            act='relu6',
+            name=name + "_extra2_sep")
         return normal_conv
 
     def __call__(self, input):
@@ -282,36 +294,39 @@ class MobileNetV3():
             self.block_stride *= layer_cfg[5]
             if layer_cfg[5] == 2:
                 blocks.append(conv)
-            conv = self._residual_unit(input=conv,
-                                       num_in_filter=inplanes,
-                                       num_mid_filter=int(scale * layer_cfg[1]),
-                                       num_out_filter=int(scale * layer_cfg[2]),
-                                       act=layer_cfg[4],
-                                       stride=layer_cfg[5],
-                                       filter_size=layer_cfg[0],
-                                       use_se=layer_cfg[3],
-                                       name='conv' + str(i + 2))
-            
+            conv = self._residual_unit(
+                input=conv,
+                num_in_filter=inplanes,
+                num_mid_filter=int(scale * layer_cfg[1]),
+                num_out_filter=int(scale * layer_cfg[2]),
+                act=layer_cfg[4],
+                stride=layer_cfg[5],
+                filter_size=layer_cfg[0],
+                use_se=layer_cfg[3],
+                name='conv' + str(i + 2))
+
             inplanes = int(scale * layer_cfg[2])
             i += 1
             self.curr_stage = i
         blocks.append(conv)
 
         if self.num_classes:
-            conv = self._conv_bn_layer(input=conv,
-                                       filter_size=1,
-                                       num_filters=int(scale * self.cls_ch_squeeze),
-                                       stride=1,
-                                       padding=0,
-                                       num_groups=1,
-                                       if_act=True,
-                                       act='hard_swish',
-                                       name='conv_last')
-            
-            conv = fluid.layers.pool2d(input=conv,
-                                       pool_type='avg',
-                                       global_pooling=True,
-                                       use_cudnn=False)
+            conv = self._conv_bn_layer(
+                input=conv,
+                filter_size=1,
+                num_filters=int(scale * self.cls_ch_squeeze),
+                stride=1,
+                padding=0,
+                num_groups=1,
+                if_act=True,
+                act='hard_swish',
+                name='conv_last')
+
+            conv = fluid.layers.pool2d(
+                input=conv,
+                pool_type='avg',
+                global_pooling=True,
+                use_cudnn=False)
             conv = fluid.layers.conv2d(
                 input=conv,
                 num_filters=self.cls_ch_expand,
@@ -326,22 +341,23 @@ class MobileNetV3():
             out = fluid.layers.fc(input=drop,
                                   size=self.num_classes,
                                   param_attr=ParamAttr(name='fc_weights'),
-                                  bias_attr=ParamAttr(name='fc_offset'))            
+                                  bias_attr=ParamAttr(name='fc_offset'))
             return out
 
         if not self.with_extra_blocks:
             return blocks
 
         # extra block
-        conv_extra = self._conv_bn_layer(conv,
-                                         filter_size=1,
-                                         num_filters=int(scale * cfg[-1][1]),
-                                         stride=1,
-                                         padding="SAME",
-                                         num_groups=1,
-                                         if_act=True,
-                                         act='hard_swish',
-                                         name='conv' + str(i + 2))
+        conv_extra = self._conv_bn_layer(
+            conv,
+            filter_size=1,
+            num_filters=int(scale * cfg[-1][1]),
+            stride=1,
+            padding="SAME",
+            num_groups=1,
+            if_act=True,
+            act='hard_swish',
+            name='conv' + str(i + 2))
         self.end_points.append(conv_extra)
         i += 1
         for block_filter in self.extra_block_filters:

+ 16 - 9
paddlex/cv/nets/resnet.py

@@ -135,8 +135,10 @@ class ResNet(object):
             filter_size=filter_size,
             stride=stride,
             padding=padding,
-            param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
-            bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
+            param_attr=ParamAttr(
+                initializer=Constant(0.0), name=name + ".w_0"),
+            bias_attr=ParamAttr(
+                initializer=Constant(0.0), name=name + ".b_0"),
             act=act,
             name=name)
         return out
@@ -151,7 +153,8 @@ class ResNet(object):
                    name=None,
                    dcn_v2=False,
                    use_lr_mult_list=False):
-        lr_mult = self.lr_mult_list[self.curr_stage] if use_lr_mult_list else 1.0
+        lr_mult = self.lr_mult_list[
+            self.curr_stage] if use_lr_mult_list else 1.0
         _name = self.prefix_name + name if self.prefix_name != '' else name
         if not dcn_v2:
             conv = fluid.layers.conv2d(
@@ -162,8 +165,8 @@ class ResNet(object):
                 padding=(filter_size - 1) // 2,
                 groups=groups,
                 act=None,
-                param_attr=ParamAttr(name=_name + "_weights",
-                                     learning_rate=lr_mult),
+                param_attr=ParamAttr(
+                    name=_name + "_weights", learning_rate=lr_mult),
                 bias_attr=False,
                 name=_name + '.conv2d.output.1')
         else:
@@ -202,14 +205,18 @@ class ResNet(object):
 
         norm_lr = 0. if self.freeze_norm else lr_mult
         norm_decay = self.norm_decay
+        if self.num_classes:
+            regularizer = None
+        else:
+            regularizer = L2Decay(norm_decay)
         pattr = ParamAttr(
             name=bn_name + '_scale',
             learning_rate=norm_lr,
-            regularizer=L2Decay(norm_decay))
+            regularizer=regularizer)
         battr = ParamAttr(
             name=bn_name + '_offset',
             learning_rate=norm_lr,
-            regularizer=L2Decay(norm_decay))
+            regularizer=regularizer)
 
         if self.norm_type in ['bn', 'sync_bn']:
             global_stats = True if self.freeze_norm else False
@@ -262,8 +269,8 @@ class ResNet(object):
                     pool_padding=0,
                     ceil_mode=True,
                     pool_type='avg')
-                return self._conv_norm(input, ch_out, 1, 1, name=name,
-                                      use_lr_mult_list=True)
+                return self._conv_norm(
+                    input, ch_out, 1, 1, name=name, use_lr_mult_list=True)
             return self._conv_norm(input, ch_out, 1, stride, name=name)
         else:
             return input

+ 1 - 0
paddlex/cv/nets/segmentation/__init__.py

@@ -14,5 +14,6 @@
 
 from .unet import UNet
 from .deeplabv3p import DeepLabv3p
+from .hrnet import HRNet
 from .model_utils import libs
 from .model_utils import loss

+ 1 - 0
paddlex/seg.py

@@ -17,5 +17,6 @@ from . import cv
 
 UNet = cv.models.UNet
 DeepLabv3p = cv.models.DeepLabv3p
+HRNet = cv.models.HRNet
 transforms = cv.transforms.seg_transforms
 visualize = cv.models.utils.visualize.visualize_segmentation

+ 50 - 0
tutorials/train/segmentation/hrnet.py

@@ -0,0 +1,50 @@
+import os
+# 选择使用0号卡
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+# 下载和解压视盘分割数据集
+optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
+pdx.utils.download_and_decompress(optic_dataset, path='./')
+
+# 定义训练和验证时的transforms
+train_transforms = transforms.Compose([
+    transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
+    transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.ResizeByLong(long_size=512),
+    transforms.Padding(target_size=512), transforms.Normalize()
+])
+
+# 定义训练和验证所用的数据集
+train_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/train_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=eval_transforms)
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标
+# VisualDL启动方式: visualdl --logdir output/unet/vdl_log --port 8001
+# 浏览器打开 https://0.0.0.0:8001即可
+# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
+num_classes = len(train_dataset.labels)
+model = pdx.seg.HRNet(num_classes=num_classes, width=64)
+model.train(
+    num_epochs=20,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    save_dir='output/hrnet',
+    use_vdl=True)