| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import cv2
- import numpy as np
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from paddlex.paddleseg.cvlibs import manager
- from paddlex.paddleseg.models import layers
- from paddlex.paddleseg.models.backbones import resnet_vd
- from paddlex.paddleseg.models import deeplab
- from paddlex.paddleseg.utils import utils
@manager.MODELS.add_component
class GSCNN(nn.Layer):
    """
    The GSCNN implementation based on PaddlePaddle.

    The original article refers to
    Towaki Takikawa, et al. "Gated-SCNN: Gated Shape CNNs for Semantic Segmentation"
    (https://arxiv.org/pdf/1907.05740.pdf)

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): Backbone network, currently support Resnet50_vd/Resnet101_vd.
            It must expose ``feat_channels`` (channels of each output) and a
            ``conv1_logit`` attribute that is readable after its forward pass.
        backbone_indices (tuple, optional): Four indices into the backbone outputs.
            The first one selects the low-level feature of the decoder, the second to
            fourth select the features feeding the gated shape stream, and the last one
            is also the input of the ASPP module. Default: (0, 1, 2, 3).
        aspp_ratios (tuple, optional): The dilation rates used in the ASPP module.
            If output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
            If output_stride=8, aspp_ratios is (1, 12, 24, 36).
            Default: (1, 6, 12, 18).
        aspp_out_channels (int, optional): The output channels of ASPP module. Default: 256.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(0, 1, 2, 3),
                 aspp_ratios=(1, 6, 12, 18),
                 aspp_out_channels=256,
                 align_corners=False,
                 pretrained=None):
        super().__init__()
        # The backbone is registered before the head so the sub-layer order is unchanged.
        self.backbone = backbone
        self.head = GSCNNHead(
            num_classes=num_classes,
            backbone_indices=backbone_indices,
            backbone_channels=self.backbone.feat_channels,
            aspp_ratios=aspp_ratios,
            aspp_out_channels=aspp_out_channels,
            align_corners=align_corners)
        self.align_corners = align_corners
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):
        """
        Run the backbone and the GSCNN head, then resize both head outputs to the input size.

        Args:
            x (Tensor): The input image batch.

        Returns:
            list: ``[seg_logit, (seg_logit, edge_logit), edge_logit, seg_logit]``.
            Both entries are bilinearly resized to ``x``'s spatial size. Despite
            its name, ``edge_logit`` holds the head's sigmoid edge map rather than a
            raw logit. NOTE(review): the four-element layout is presumably matched
            position-by-position with the configured training losses; confirm before
            changing it.
        """
        features = self.backbone(x)
        # conv1_logit is read only after the backbone has run; presumably the
        # backbone's forward pass is what sets it.
        seg_out, edge_out = self.head(x, features, self.backbone.conv1_logit)

        seg_logit = F.interpolate(
            seg_out,
            x.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        edge_logit = F.interpolate(
            edge_out,
            x.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        return [seg_logit, (seg_logit, edge_logit), edge_logit, seg_logit]

    def init_weight(self):
        """Load the entire model from ``self.pretrained`` if a path or url was given."""
        if self.pretrained is None:
            return
        utils.load_entire_model(self, self.pretrained)
class GSCNNHead(nn.Layer):
    """
    The GSCNNHead implementation based on PaddlePaddle.

    The head has two streams:
      * a shape stream (residual blocks + gated convolutions) that turns ``s_input``
        and three backbone features into a one-channel edge map, which is then
        fused with a Canny edge map of the input image;
      * a regular stream (ASPP module + DeepLab decoder) that produces the
        segmentation logit, with the fused edge map injected into the ASPP module.

    Args:
        num_classes (int): The unique number of target classes.
        backbone_indices (tuple): Four indices into the backbone outputs.
            The first index will be taken as a low-level feature in the Decoder component;
            the last one will be taken as input of the ASPP component; the second to fourth
            will be taken as inputs of the GCL (gated convolution) components.
            Usually a backbone consists of four downsampling stages and returns the output
            of each stage. If we set it as (0, 1, 2, 3), it means taking the feature map of
            the first stage as the low-level feature used in the Decoder, the feature map of
            the fourth stage as input of ASPP, and the feature maps of the second to fourth
            stages as inputs of GCL.
        backbone_channels (tuple): The channels of each backbone output.
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        aspp_out_channels (int): The output channels of ASPP module.
        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.

    Raises:
        ValueError: If ``backbone_indices`` does not contain exactly four indices.
    """

    def __init__(self, num_classes, backbone_indices, backbone_channels,
                 aspp_ratios, aspp_out_channels, align_corners):
        super().__init__()
        if len(backbone_indices) != 4:
            raise ValueError(
                'GSCNNHead expects exactly 4 backbone_indices, but got {}.'.
                format(len(backbone_indices)))
        self.backbone_indices = backbone_indices
        self.align_corners = align_corners

        # 1x1 convolutions projecting each gated-stage feature to a single channel.
        self.dsn1 = nn.Conv2D(
            backbone_channels[backbone_indices[1]], 1, kernel_size=1)
        self.dsn2 = nn.Conv2D(
            backbone_channels[backbone_indices[2]], 1, kernel_size=1)
        self.dsn3 = nn.Conv2D(
            backbone_channels[backbone_indices[3]], 1, kernel_size=1)

        # Shape stream: residual block -> 1x1 channel reduction -> gated conv, three times.
        # NOTE(review): res1 expects 64 input channels, presumably the channel count
        # of the backbone's conv1_logit passed in as s_input; confirm.
        self.res1 = resnet_vd.BasicBlock(64, 64, stride=1)
        self.d1 = nn.Conv2D(64, 32, kernel_size=1)
        self.gate1 = GatedSpatailConv2d(32, 32)
        self.res2 = resnet_vd.BasicBlock(32, 32, stride=1)
        self.d2 = nn.Conv2D(32, 16, kernel_size=1)
        self.gate2 = GatedSpatailConv2d(16, 16)
        self.res3 = resnet_vd.BasicBlock(16, 16, stride=1)
        self.d3 = nn.Conv2D(16, 8, kernel_size=1)
        self.gate3 = GatedSpatailConv2d(8, 8)
        self.fuse = nn.Conv2D(8, 1, kernel_size=1, bias_attr=False)
        # Fuses [learned edge map, Canny edge map] (2 channels) into one channel.
        self.cw = nn.Conv2D(2, 1, kernel_size=1, bias_attr=False)

        # Regular stream. Channel counts are looked up through backbone_indices so
        # they match the features actually fed to ASPP and to the decoder in forward().
        self.aspp = ASPPModule(
            aspp_ratios=aspp_ratios,
            in_channels=backbone_channels[backbone_indices[3]],
            out_channels=aspp_out_channels,
            align_corners=self.align_corners,
            image_pooling=True)
        self.decoder = deeplab.Decoder(
            num_classes=num_classes,
            in_channels=backbone_channels[backbone_indices[0]],
            align_corners=self.align_corners)

    @staticmethod
    def _canny_edges(x):
        """
        Compute a stop-gradient Canny edge map of ``x`` with shape (N, 1, H, W).

        ``x`` is read as NCHW and moved to numpy, so this branch is detached from
        autograd. NOTE(review): because ``x.numpy()`` is used, this presumably only
        works in dynamic-graph mode.
        """
        images = x.numpy().transpose((0, 2, 3, 1))
        # NOTE(review): assumes x was normalized with mean=0.5, std=0.5 (i.e. lies in
        # [-1, 1]); confirm. Clip before the uint8 cast so out-of-range values
        # saturate instead of wrapping around.
        images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
        batch, height, width = images.shape[:3]
        edges = np.zeros((batch, 1, height, width), dtype=np.float32)
        for i, image in enumerate(images):
            edges[i, 0] = cv2.Canny(image, 10, 100)
        # cv2.Canny emits 0/255 masks; rescale to {0, 1}.
        edges /= 255.0
        edge_tensor = paddle.to_tensor(edges)
        edge_tensor.stop_gradient = True
        return edge_tensor

    def forward(self, x, feat_list, s_input):
        """
        Args:
            x (Tensor): The input image batch, laid out as NCHW.
            feat_list (list[Tensor]): The backbone outputs, indexed by ``backbone_indices``.
            s_input (Tensor): The feature that seeds the shape stream.

        Returns:
            list: ``[logit, edge_out]`` where ``logit`` is the decoder output and
            ``edge_out`` is the sigmoid edge map of the shape stream, resized to
            ``x``'s spatial size.
        """
        input_shape = paddle.shape(x)

        m1f = F.interpolate(
            s_input,
            input_shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        l1, l2, l3 = [feat_list[idx] for idx in self.backbone_indices[1:]]

        # One-channel side outputs of the three gated-stage features at input size.
        s1, s2, s3 = [
            F.interpolate(
                dsn(feat),
                input_shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            for dsn, feat in zip((self.dsn1, self.dsn2, self.dsn3), (l1, l2,
                                                                      l3))
        ]

        # Get image gradient
        canny = self._canny_edges(x)

        # Shape stream.
        cs = m1f
        stages = ((self.res1, self.d1, self.gate1, s1),
                  (self.res2, self.d2, self.gate2, s2),
                  (self.res3, self.d3, self.gate3, s3))
        for res, reduce_conv, gate, side in stages:
            cs = res(cs)
            cs = F.interpolate(
                cs,
                input_shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            cs = reduce_conv(cs)
            cs = gate(cs, side)
        cs = self.fuse(cs)
        cs = F.interpolate(
            cs,
            input_shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        edge_out = F.sigmoid(cs)  # Output of the shape stream

        # Input of the fusion module: learned edges combined with Canny edges.
        acts = F.sigmoid(self.cw(paddle.concat([edge_out, canny], axis=1)))

        aspp_out = self.aspp(l3, acts)
        low_level_feat = feat_list[self.backbone_indices[0]]
        logit = self.decoder(aspp_out, low_level_feat)
        return [logit, edge_out]
class GatedSpatailConv2d(nn.Layer):
    """
    Gated convolution used by the GSCNN shape stream.

    A one-channel gate is predicted from the channel-wise concatenation of
    ``input_features`` and ``gating_features`` (SyncBN -> 1x1 conv -> ReLU ->
    1x1 conv to one channel -> SyncBN -> sigmoid). The input is re-weighted as
    ``input_features * (gate + 1)`` and then passed through a regular convolution.

    Args:
        in_channels (int): Channels of ``input_features``. The gate branch expects
            ``in_channels + 1`` channels, so ``gating_features`` must have one channel.
        out_channels (int): Channels produced by the final convolution.
        kernel_size (int, optional): Kernel size of the final convolution. Default: 1.
        stride (int, optional): Stride of the final convolution. Default: 1.
        padding (int, optional): Padding of the final convolution. Default: 0.
        dilation (int, optional): Dilation of the final convolution. Default: 1.
        groups (int, optional): Groups of the final convolution. Default: 1.
        bias_attr (optional): ``bias_attr`` of the final convolution. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=False):
        super().__init__()
        # Input features plus the single gating channel.
        gate_channels = in_channels + 1
        # Attribute names and layer order are kept as-is: they determine the
        # parameter names that pretrained checkpoints are loaded into.
        self._gate_conv = nn.Sequential(
            layers.SyncBatchNorm(gate_channels),
            nn.Conv2D(gate_channels, gate_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2D(gate_channels, 1, kernel_size=1),
            layers.SyncBatchNorm(1),
            nn.Sigmoid())
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)

    def forward(self, input_features, gating_features):
        """Gate ``input_features`` with a map predicted from both inputs, then convolve."""
        gate = self._gate_conv(
            paddle.concat([input_features, gating_features], axis=1))
        return self.conv(input_features * (gate + 1))
class ASPPModule(nn.Layer):
    """
    Atrous Spatial Pyramid Pooling with an additional edge branch.

    Every branch reads the same input feature map. The branch outputs are resized
    to the input's spatial size, concatenated along the channel axis, fused by a
    1x1 ConvBNReLU and passed through dropout (p=0.1).

    Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module. A ratio of 1
            builds a 1x1 branch; any other ratio builds a 3x3 dilated branch.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels (also the width of each branch).
        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
        use_sep_conv (bool, optional): If True, branches with ratio > 1 use separable convs. Default: False.
        image_pooling (bool, optional): If True, add an image-level (global average pooling) branch. Default: False
    """

    def __init__(self,
                 aspp_ratios,
                 in_channels,
                 out_channels,
                 align_corners,
                 use_sep_conv=False,
                 image_pooling=False):
        super().__init__()
        self.align_corners = align_corners

        self.aspp_blocks = nn.LayerList()
        for dilation in aspp_ratios:
            is_pointwise = dilation == 1
            use_separable = use_sep_conv and dilation > 1
            branch_cls = layers.SeparableConvBNReLU if use_separable else layers.ConvBNReLU
            self.aspp_blocks.append(
                branch_cls(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1 if is_pointwise else 3,
                    dilation=dilation,
                    padding=0 if is_pointwise else dilation))

        # Count of branches concatenated in forward(): dilated/pointwise ones,
        # the optional image-pooling one, and the always-present edge one.
        num_branches = len(self.aspp_blocks)
        if image_pooling:
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2D(output_size=(1, 1)),
                layers.ConvBNReLU(
                    in_channels, out_channels, kernel_size=1, bias_attr=False))
            num_branches += 1
        self.image_pooling = image_pooling

        # Lifts the one-channel edge map to out_channels.
        self.edge_conv = layers.ConvBNReLU(
            1, out_channels, kernel_size=1, bias_attr=False)
        num_branches += 1

        self.conv_bn_relu = layers.ConvBNReLU(
            in_channels=out_channels * num_branches,
            out_channels=out_channels,
            kernel_size=1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x, edge):
        """Concatenate all branch outputs (plus the edge features) and fuse them."""
        size = paddle.shape(x)[2:]

        branch_outputs = [
            F.interpolate(
                block(x),
                size,
                mode='bilinear',
                align_corners=self.align_corners) for block in self.aspp_blocks
        ]

        if self.image_pooling:
            # Broadcast the 1x1 pooled feature back to the input resolution.
            branch_outputs.append(
                F.interpolate(
                    self.global_avg_pool(x),
                    size,
                    mode='bilinear',
                    align_corners=self.align_corners))

        resized_edge = F.interpolate(
            edge,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)
        branch_outputs.append(self.edge_conv(resized_edge))

        fused = self.conv_bn_relu(paddle.concat(branch_outputs, axis=1))
        return self.dropout(fused)
|