sfnet.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.models import layers
  18. from paddlex.paddleseg.cvlibs import manager
  19. from paddlex.paddleseg.utils import utils
  20. @manager.MODELS.add_component
  21. class SFNet(nn.Layer):
  22. """
  23. The SFNet implementation based on PaddlePaddle.
  24. The original article refers to
  25. Li, Xiangtai, et al. "Semantic Flow for Fast and Accurate Scene Parsing"
  26. (https://arxiv.org/pdf/2002.10120.pdf).
  27. Args:
  28. num_classes (int): The unique number of target classes.
  29. backbone (Paddle.nn.Layer): Backbone network, currently support Resnet50/101.
  30. backbone_indices (tuple): Four values in the tuple indicate the indices of output of backbone.
  31. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: False.
  32. align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
  33. e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  34. pretrained (str, optional): The path or url of pretrained model. Default: None.
  35. """
  36. def __init__(self,
  37. num_classes,
  38. backbone,
  39. backbone_indices,
  40. enable_auxiliary_loss=False,
  41. align_corners=False,
  42. pretrained=None):
  43. super(SFNet, self).__init__()
  44. self.backbone = backbone
  45. self.backbone_indices = backbone_indices
  46. self.in_channels = [
  47. self.backbone.feat_channels[i] for i in backbone_indices
  48. ]
  49. self.align_corners = align_corners
  50. self.pretrained = pretrained
  51. self.enable_auxiliary_loss = enable_auxiliary_loss
  52. if self.backbone.layers == 18:
  53. fpn_dim = 128
  54. inplane_head = 512
  55. fpn_inplanes = [64, 128, 256, 512]
  56. else:
  57. fpn_dim = 256
  58. inplane_head = 2048
  59. fpn_inplanes = [256, 512, 1024, 2048]
  60. self.head = SFNetHead(
  61. inplane=inplane_head,
  62. num_class=num_classes,
  63. fpn_inplanes=fpn_inplanes,
  64. fpn_dim=fpn_dim,
  65. enable_auxiliary_loss=self.enable_auxiliary_loss)
  66. self.init_weight()
  67. def forward(self, x):
  68. feats = self.backbone(x)
  69. feats = [feats[i] for i in self.backbone_indices]
  70. logit_list = self.head(feats)
  71. logit_list = [
  72. F.interpolate(
  73. logit,
  74. x.shape[2:],
  75. mode='bilinear',
  76. align_corners=self.align_corners) for logit in logit_list
  77. ]
  78. return logit_list
  79. def init_weight(self):
  80. if self.pretrained is not None:
  81. utils.load_entire_model(self, self.pretrained)
  82. class SFNetHead(nn.Layer):
  83. """
  84. The SFNetHead implementation.
  85. Args:
  86. inplane (int): Input channels of PPM module.
  87. num_class (int): The unique number of target classes.
  88. fpn_inplanes (list): The feature channels from backbone.
  89. fpn_dim (int, optional): The input channels of FAM module. Default: 256.
  90. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: False.
  91. """
  92. def __init__(self,
  93. inplane,
  94. num_class,
  95. fpn_inplanes,
  96. fpn_dim=256,
  97. enable_auxiliary_loss=False):
  98. super(SFNetHead, self).__init__()
  99. self.ppm = layers.PPModule(
  100. in_channels=inplane,
  101. out_channels=fpn_dim,
  102. bin_sizes=(1, 2, 3, 6),
  103. dim_reduction=True,
  104. align_corners=True)
  105. self.enable_auxiliary_loss = enable_auxiliary_loss
  106. self.fpn_in = []
  107. for fpn_inplane in fpn_inplanes[:-1]:
  108. self.fpn_in.append(
  109. nn.Sequential(
  110. nn.Conv2D(fpn_inplane, fpn_dim, 1),
  111. layers.SyncBatchNorm(fpn_dim), nn.ReLU()))
  112. self.fpn_in = nn.LayerList(self.fpn_in)
  113. self.fpn_out = []
  114. self.fpn_out_align = []
  115. self.dsn = []
  116. for i in range(len(fpn_inplanes) - 1):
  117. self.fpn_out.append(
  118. nn.Sequential(
  119. layers.ConvBNReLU(
  120. fpn_dim, fpn_dim, 3, bias_attr=False)))
  121. self.fpn_out_align.append(
  122. AlignedModule(
  123. inplane=fpn_dim, outplane=fpn_dim // 2))
  124. if self.enable_auxiliary_loss:
  125. self.dsn.append(
  126. nn.Sequential(
  127. layers.AuxLayer(fpn_dim, fpn_dim, num_class)))
  128. self.fpn_out = nn.LayerList(self.fpn_out)
  129. self.fpn_out_align = nn.LayerList(self.fpn_out_align)
  130. if self.enable_auxiliary_loss:
  131. self.dsn = nn.LayerList(self.dsn)
  132. self.conv_last = nn.Sequential(
  133. layers.ConvBNReLU(
  134. len(fpn_inplanes) * fpn_dim, fpn_dim, 3, bias_attr=False),
  135. nn.Conv2D(
  136. fpn_dim, num_class, kernel_size=1))
  137. def forward(self, conv_out):
  138. psp_out = self.ppm(conv_out[-1])
  139. f = psp_out
  140. fpn_feature_list = [psp_out]
  141. out = []
  142. for i in reversed(range(len(conv_out) - 1)):
  143. conv_x = conv_out[i]
  144. conv_x = self.fpn_in[i](conv_x)
  145. f = self.fpn_out_align[i]([conv_x, f])
  146. f = conv_x + f
  147. fpn_feature_list.append(self.fpn_out[i](f))
  148. if self.enable_auxiliary_loss:
  149. out.append(self.dsn[i](f))
  150. fpn_feature_list.reverse()
  151. output_size = fpn_feature_list[0].shape[2:]
  152. fusion_list = [fpn_feature_list[0]]
  153. for i in range(1, len(fpn_feature_list)):
  154. fusion_list.append(
  155. F.interpolate(
  156. fpn_feature_list[i],
  157. output_size,
  158. mode='bilinear',
  159. align_corners=True))
  160. fusion_out = paddle.concat(fusion_list, 1)
  161. x = self.conv_last(fusion_out)
  162. if self.enable_auxiliary_loss:
  163. out.append(x)
  164. return out
  165. else:
  166. return [x]
  167. class AlignedModule(nn.Layer):
  168. """
  169. The FAM module implementation.
  170. Args:
  171. inplane (int): Input channles of FAM module.
  172. outplane (int): Output channels of FAN module.
  173. kernel_size (int, optional): Kernel size of semantic flow convolution layer. Default: 3.
  174. """
  175. def __init__(self, inplane, outplane, kernel_size=3):
  176. super(AlignedModule, self).__init__()
  177. self.down_h = nn.Conv2D(inplane, outplane, 1, bias_attr=False)
  178. self.down_l = nn.Conv2D(inplane, outplane, 1, bias_attr=False)
  179. self.flow_make = nn.Conv2D(
  180. outplane * 2,
  181. 2,
  182. kernel_size=kernel_size,
  183. padding=1,
  184. bias_attr=False)
  185. def flow_warp(self, inputs, flow, size):
  186. out_h, out_w = size
  187. n, c, h, w = inputs.shape
  188. norm = paddle.to_tensor([[[[out_w, out_h]]]]).astype('float32')
  189. h = paddle.linspace(-1.0, 1.0, out_h).reshape([-1, 1]).tile([1, out_w])
  190. w = paddle.linspace(-1.0, 1.0, out_w).tile([out_h, 1])
  191. grid = paddle.concat([paddle.unsqueeze(w, 2), paddle.unsqueeze(h, 2)],
  192. 2)
  193. grid = grid.tile([n, 1, 1, 1]).astype('float32')
  194. grid = grid + flow.transpose([0, 2, 3, 1]) / norm
  195. output = F.grid_sample(inputs, grid)
  196. return output
  197. def forward(self, x):
  198. low_feature, h_feature = x
  199. h_feature_orign = h_feature
  200. h, w = low_feature.shape[2:]
  201. size = (h, w)
  202. low_feature = self.down_l(low_feature)
  203. h_feature = self.down_h(h_feature)
  204. h_feature = F.interpolate(
  205. h_feature, size=size, mode='bilinear', align_corners=True)
  206. flow = self.flow_make(paddle.concat([h_feature, low_feature], 1))
  207. h_feature = self.flow_warp(h_feature_orign, flow, size=size)
  208. return h_feature