emanet.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.models import layers
  18. from paddlex.paddleseg.cvlibs import manager
  19. from paddlex.paddleseg.utils import utils
  20. @manager.MODELS.add_component
  21. class EMANet(nn.Layer):
  22. """
  23. Expectation Maximization Attention Networks for Semantic Segmentation based on PaddlePaddle.
  24. The original article refers to
  25. Xia Li, et al. "Expectation-Maximization Attention Networks for Semantic Segmentation"
  26. (https://arxiv.org/abs/1907.13426)
  27. Args:
  28. num_classes (int): The unique number of target classes.
  29. backbone (Paddle.nn.Layer): A backbone network.
  30. backbone_indices (tuple): The values in the tuple indicate the indices of output of backbone.
  31. ema_channels (int): EMA module channels.
  32. gc_channels (int): The input channels to Global Context Block.
  33. num_bases (int): Number of bases.
  34. stage_num (int): The iteration number for EM.
  35. momentum (float): The parameter for updating bases.
  36. concat_input (bool): Whether concat the input and output of convs before classification layer. Default: True
  37. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  38. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  39. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  40. pretrained (str, optional): The path or url of pretrained model. Default: None.
  41. """
  42. def __init__(self,
  43. num_classes,
  44. backbone,
  45. backbone_indices=(2, 3),
  46. ema_channels=512,
  47. gc_channels=256,
  48. num_bases=64,
  49. stage_num=3,
  50. momentum=0.1,
  51. concat_input=True,
  52. enable_auxiliary_loss=True,
  53. align_corners=False,
  54. pretrained=None):
  55. super().__init__()
  56. self.backbone = backbone
  57. self.backbone_indices = backbone_indices
  58. in_channels = [
  59. self.backbone.feat_channels[i] for i in backbone_indices
  60. ]
  61. self.head = EMAHead(num_classes, in_channels, ema_channels,
  62. gc_channels, num_bases, stage_num, momentum,
  63. concat_input, enable_auxiliary_loss)
  64. self.align_corners = align_corners
  65. self.pretrained = pretrained
  66. self.init_weight()
  67. def forward(self, x):
  68. feats = self.backbone(x)
  69. feats = [feats[i] for i in self.backbone_indices]
  70. logit_list = self.head(feats)
  71. logit_list = [
  72. F.interpolate(
  73. logit,
  74. paddle.shape(x)[2:],
  75. mode='bilinear',
  76. align_corners=self.align_corners) for logit in logit_list
  77. ]
  78. return logit_list
  79. def init_weight(self):
  80. if self.pretrained is not None:
  81. utils.load_entire_model(self, self.pretrained)
  82. class EMAHead(nn.Layer):
  83. """
  84. The EMANet head.
  85. Args:
  86. num_classes (int): The unique number of target classes.
  87. in_channels (tuple): The number of input channels.
  88. ema_channels (int): EMA module channels.
  89. gc_channels (int): The input channels to Global Context Block.
  90. num_bases (int): Number of bases.
  91. stage_num (int): The iteration number for EM.
  92. momentum (float): The parameter for updating bases.
  93. concat_input (bool): Whether concat the input and output of convs before classification layer. Default: True
  94. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  95. """
  96. def __init__(self,
  97. num_classes,
  98. in_channels,
  99. ema_channels,
  100. gc_channels,
  101. num_bases,
  102. stage_num,
  103. momentum,
  104. concat_input=True,
  105. enable_auxiliary_loss=True):
  106. super(EMAHead, self).__init__()
  107. self.in_channels = in_channels[-1]
  108. self.concat_input = concat_input
  109. self.enable_auxiliary_loss = enable_auxiliary_loss
  110. self.emau = EMAU(ema_channels, num_bases, stage_num, momentum=momentum)
  111. self.ema_in_conv = layers.ConvBNReLU(
  112. in_channels=self.in_channels,
  113. out_channels=ema_channels,
  114. kernel_size=3)
  115. self.ema_mid_conv = nn.Conv2D(
  116. ema_channels, ema_channels, kernel_size=1)
  117. self.ema_out_conv = layers.ConvBNReLU(
  118. in_channels=ema_channels, out_channels=ema_channels, kernel_size=1)
  119. self.bottleneck = layers.ConvBNReLU(
  120. in_channels=ema_channels, out_channels=gc_channels, kernel_size=3)
  121. self.cls = nn.Sequential(
  122. nn.Dropout2D(p=0.1), nn.Conv2D(gc_channels, num_classes, 1))
  123. self.aux = nn.Sequential(
  124. layers.ConvBNReLU(
  125. in_channels=1024, out_channels=256, kernel_size=3),
  126. nn.Dropout2D(p=0.1),
  127. nn.Conv2D(256, num_classes, 1))
  128. if self.concat_input:
  129. self.conv_cat = layers.ConvBNReLU(
  130. self.in_channels + gc_channels, gc_channels, kernel_size=3)
  131. def forward(self, feat_list):
  132. C3, C4 = feat_list
  133. feats = self.ema_in_conv(C4)
  134. identity = feats
  135. feats = self.ema_mid_conv(feats)
  136. recon = self.emau(feats)
  137. recon = F.relu(recon)
  138. recon = self.ema_out_conv(recon)
  139. output = F.relu(identity + recon)
  140. output = self.bottleneck(output)
  141. if self.concat_input:
  142. output = self.conv_cat(paddle.concat([C4, output], axis=1))
  143. output = self.cls(output)
  144. if self.enable_auxiliary_loss:
  145. auxout = self.aux(C3)
  146. return [output, auxout]
  147. else:
  148. return [output]
  149. class EMAU(nn.Layer):
  150. '''The Expectation-Maximization Attention Unit (EMAU).
  151. Arguments:
  152. c (int): The input and output channel number.
  153. k (int): The number of the bases.
  154. stage_num (int): The iteration number for EM.
  155. momentum (float): The parameter for updating bases.
  156. '''
  157. def __init__(self, c, k, stage_num=3, momentum=0.1):
  158. super(EMAU, self).__init__()
  159. assert stage_num >= 1
  160. self.stage_num = stage_num
  161. self.momentum = momentum
  162. self.c = c
  163. tmp_mu = self.create_parameter(
  164. shape=[1, c, k],
  165. default_initializer=paddle.nn.initializer.KaimingNormal(k))
  166. mu = F.normalize(paddle.to_tensor(tmp_mu), axis=1, p=2)
  167. self.register_buffer('mu', mu)
  168. def forward(self, x):
  169. x_shape = paddle.shape(x)
  170. x = x.flatten(2)
  171. mu = paddle.tile(self.mu, [x_shape[0], 1, 1])
  172. with paddle.no_grad():
  173. for i in range(self.stage_num):
  174. x_t = paddle.transpose(x, [0, 2, 1])
  175. z = paddle.bmm(x_t, mu)
  176. z = F.softmax(z, axis=2)
  177. z_ = F.normalize(z, axis=1, p=1)
  178. mu = paddle.bmm(x, z_)
  179. mu = F.normalize(mu, axis=1, p=2)
  180. z_t = paddle.transpose(z, [0, 2, 1])
  181. x = paddle.matmul(mu, z_t)
  182. x = paddle.reshape(x, [0, self.c, x_shape[2], x_shape[3]])
  183. if self.training:
  184. mu = paddle.mean(mu, 0, keepdim=True)
  185. mu = F.normalize(mu, axis=1, p=2)
  186. mu = self.mu * (1 - self.momentum) + mu * self.momentum
  187. if paddle.distributed.get_world_size() > 1:
  188. mu = paddle.distributed.all_reduce(mu)
  189. mu /= paddle.distributed.get_world_size()
  190. self.mu = mu
  191. return x