fast_scnn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle.nn as nn
  15. import paddle.nn.functional as F
  16. import paddle
  17. from paddlex.paddleseg.cvlibs import manager
  18. from paddlex.paddleseg.models import layers
  19. from paddlex.paddleseg.utils import utils
  20. __all__ = ['FastSCNN']
  21. @manager.MODELS.add_component
  22. class FastSCNN(nn.Layer):
  23. """
  24. The FastSCNN implementation based on PaddlePaddle.
  25. As mentioned in the original paper, FastSCNN is a real-time segmentation algorithm (123.5fps)
  26. even for high resolution images (1024x2048).
  27. The original article refers to
  28. Poudel, Rudra PK, et al. "Fast-scnn: Fast semantic segmentation network"
  29. (https://arxiv.org/pdf/1902.04502.pdf).
  30. Args:
  31. num_classes (int): The unique number of target classes.
  32. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss.
  33. If true, auxiliary loss will be added after LearningToDownsample module. Default: False.
  34. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  35. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.. Default: False.
  36. pretrained (str, optional): The path or url of pretrained model. Default: None.
  37. """
  38. def __init__(self,
  39. num_classes,
  40. enable_auxiliary_loss=True,
  41. align_corners=False,
  42. pretrained=None):
  43. super().__init__()
  44. self.learning_to_downsample = LearningToDownsample(32, 48, 64)
  45. self.global_feature_extractor = GlobalFeatureExtractor(
  46. in_channels=64,
  47. block_channels=[64, 96, 128],
  48. out_channels=128,
  49. expansion=6,
  50. num_blocks=[3, 3, 3],
  51. align_corners=True)
  52. self.feature_fusion = FeatureFusionModule(64, 128, 128, align_corners)
  53. self.classifier = Classifier(128, num_classes)
  54. if enable_auxiliary_loss:
  55. self.auxlayer = layers.AuxLayer(64, 32, num_classes)
  56. self.enable_auxiliary_loss = enable_auxiliary_loss
  57. self.align_corners = align_corners
  58. self.pretrained = pretrained
  59. self.init_weight()
  60. def forward(self, x):
  61. logit_list = []
  62. input_size = paddle.shape(x)[2:]
  63. higher_res_features = self.learning_to_downsample(x)
  64. x = self.global_feature_extractor(higher_res_features)
  65. x = self.feature_fusion(higher_res_features, x)
  66. logit = self.classifier(x)
  67. logit = F.interpolate(
  68. logit,
  69. input_size,
  70. mode='bilinear',
  71. align_corners=self.align_corners)
  72. logit_list.append(logit)
  73. if self.enable_auxiliary_loss:
  74. auxiliary_logit = self.auxlayer(higher_res_features)
  75. auxiliary_logit = F.interpolate(
  76. auxiliary_logit,
  77. input_size,
  78. mode='bilinear',
  79. align_corners=self.align_corners)
  80. logit_list.append(auxiliary_logit)
  81. return logit_list
  82. def init_weight(self):
  83. if self.pretrained is not None:
  84. utils.load_entire_model(self, self.pretrained)
  85. class LearningToDownsample(nn.Layer):
  86. """
  87. Learning to downsample module.
  88. This module consists of three downsampling blocks (one conv and two separable conv)
  89. Args:
  90. dw_channels1 (int, optional): The input channels of the first sep conv. Default: 32.
  91. dw_channels2 (int, optional): The input channels of the second sep conv. Default: 48.
  92. out_channels (int, optional): The output channels of LearningToDownsample module. Default: 64.
  93. """
  94. def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64):
  95. super(LearningToDownsample, self).__init__()
  96. self.conv_bn_relu = layers.ConvBNReLU(
  97. in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2)
  98. self.dsconv_bn_relu1 = layers.SeparableConvBNReLU(
  99. in_channels=dw_channels1,
  100. out_channels=dw_channels2,
  101. kernel_size=3,
  102. stride=2,
  103. padding=1)
  104. self.dsconv_bn_relu2 = layers.SeparableConvBNReLU(
  105. in_channels=dw_channels2,
  106. out_channels=out_channels,
  107. kernel_size=3,
  108. stride=2,
  109. padding=1)
  110. def forward(self, x):
  111. x = self.conv_bn_relu(x)
  112. x = self.dsconv_bn_relu1(x)
  113. x = self.dsconv_bn_relu2(x)
  114. return x
  115. class GlobalFeatureExtractor(nn.Layer):
  116. """
  117. Global feature extractor module.
  118. This module consists of three InvertedBottleneck blocks (like inverted residual introduced by MobileNetV2) and
  119. a PPModule (introduced by PSPNet).
  120. Args:
  121. in_channels (int): The number of input channels to the module.
  122. block_channels (tuple): A tuple represents output channels of each bottleneck block.
  123. out_channels (int): The number of output channels of the module. Default:
  124. expansion (int): The expansion factor in bottleneck.
  125. num_blocks (tuple): It indicates the repeat time of each bottleneck.
  126. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  127. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
  128. """
  129. def __init__(self, in_channels, block_channels, out_channels, expansion,
  130. num_blocks, align_corners):
  131. super(GlobalFeatureExtractor, self).__init__()
  132. self.bottleneck1 = self._make_layer(InvertedBottleneck, in_channels,
  133. block_channels[0], num_blocks[0],
  134. expansion, 2)
  135. self.bottleneck2 = self._make_layer(
  136. InvertedBottleneck, block_channels[0], block_channels[1],
  137. num_blocks[1], expansion, 2)
  138. self.bottleneck3 = self._make_layer(
  139. InvertedBottleneck, block_channels[1], block_channels[2],
  140. num_blocks[2], expansion, 1)
  141. self.ppm = layers.PPModule(
  142. block_channels[2],
  143. out_channels,
  144. bin_sizes=(1, 2, 3, 6),
  145. dim_reduction=True,
  146. align_corners=align_corners)
  147. def _make_layer(self,
  148. block,
  149. in_channels,
  150. out_channels,
  151. blocks,
  152. expansion=6,
  153. stride=1):
  154. layers = []
  155. layers.append(block(in_channels, out_channels, expansion, stride))
  156. for _ in range(1, blocks):
  157. layers.append(block(out_channels, out_channels, expansion, 1))
  158. return nn.Sequential(*layers)
  159. def forward(self, x):
  160. x = self.bottleneck1(x)
  161. x = self.bottleneck2(x)
  162. x = self.bottleneck3(x)
  163. x = self.ppm(x)
  164. return x
  165. class InvertedBottleneck(nn.Layer):
  166. """
  167. Single Inverted bottleneck implementation.
  168. Args:
  169. in_channels (int): The number of input channels to bottleneck block.
  170. out_channels (int): The number of output channels of bottleneck block.
  171. expansion (int, optional). The expansion factor in bottleneck. Default: 6.
  172. stride (int, optional). The stride used in depth-wise conv. Defalt: 2.
  173. """
  174. def __init__(self, in_channels, out_channels, expansion=6, stride=2):
  175. super().__init__()
  176. self.use_shortcut = stride == 1 and in_channels == out_channels
  177. expand_channels = in_channels * expansion
  178. self.block = nn.Sequential(
  179. # pw
  180. layers.ConvBNReLU(
  181. in_channels=in_channels,
  182. out_channels=expand_channels,
  183. kernel_size=1,
  184. bias_attr=False),
  185. # dw
  186. layers.ConvBNReLU(
  187. in_channels=expand_channels,
  188. out_channels=expand_channels,
  189. kernel_size=3,
  190. stride=stride,
  191. padding=1,
  192. groups=expand_channels,
  193. bias_attr=False),
  194. # pw-linear
  195. layers.ConvBN(
  196. in_channels=expand_channels,
  197. out_channels=out_channels,
  198. kernel_size=1,
  199. bias_attr=False))
  200. def forward(self, x):
  201. out = self.block(x)
  202. if self.use_shortcut:
  203. out = x + out
  204. return out
  205. class FeatureFusionModule(nn.Layer):
  206. """
  207. Feature Fusion Module Implementation.
  208. This module fuses high-resolution feature and low-resolution feature.
  209. Args:
  210. high_in_channels (int): The channels of high-resolution feature (output of LearningToDownsample).
  211. low_in_channels (int): The channels of low-resolution feature (output of GlobalFeatureExtractor).
  212. out_channels (int): The output channels of this module.
  213. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  214. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
  215. """
  216. def __init__(self, high_in_channels, low_in_channels, out_channels,
  217. align_corners):
  218. super().__init__()
  219. # Only depth-wise conv
  220. self.dwconv = layers.ConvBNReLU(
  221. in_channels=low_in_channels,
  222. out_channels=out_channels,
  223. kernel_size=3,
  224. padding=1,
  225. groups=128,
  226. bias_attr=False)
  227. self.conv_low_res = layers.ConvBN(out_channels, out_channels, 1)
  228. self.conv_high_res = layers.ConvBN(high_in_channels, out_channels, 1)
  229. self.align_corners = align_corners
  230. def forward(self, high_res_input, low_res_input):
  231. low_res_input = F.interpolate(
  232. low_res_input,
  233. paddle.shape(high_res_input)[2:],
  234. mode='bilinear',
  235. align_corners=self.align_corners)
  236. low_res_input = self.dwconv(low_res_input)
  237. low_res_input = self.conv_low_res(low_res_input)
  238. high_res_input = self.conv_high_res(high_res_input)
  239. x = high_res_input + low_res_input
  240. return F.relu(x)
  241. class Classifier(nn.Layer):
  242. """
  243. The Classifier module implementation.
  244. This module consists of two depth-wise conv and one conv.
  245. Args:
  246. input_channels (int): The input channels to this module.
  247. num_classes (int): The unique number of target classes.
  248. """
  249. def __init__(self, input_channels, num_classes):
  250. super().__init__()
  251. self.dsconv1 = layers.SeparableConvBNReLU(
  252. in_channels=input_channels,
  253. out_channels=input_channels,
  254. kernel_size=3,
  255. padding=1)
  256. self.dsconv2 = layers.SeparableConvBNReLU(
  257. in_channels=input_channels,
  258. out_channels=input_channels,
  259. kernel_size=3,
  260. padding=1)
  261. self.conv = nn.Conv2D(
  262. in_channels=input_channels,
  263. out_channels=num_classes,
  264. kernel_size=1)
  265. self.dropout = nn.Dropout(p=0.1) # dropout_prob
  266. def forward(self, x):
  267. x = self.dsconv1(x)
  268. x = self.dsconv2(x)
  269. x = self.dropout(x)
  270. x = self.conv(x)
  271. return x