# File: nonlocal2d.py
#
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.models import layers
  18. class NonLocal2D(nn.Layer):
  19. """Basic Non-local module.
  20. This model is the implementation of "Non-local Neural Networks"
  21. (https://arxiv.org/abs/1711.07971)
  22. Args:
  23. in_channels (int): Channels of the input feature map.
  24. reduction (int): Channel reduction ratio. Default: 2.
  25. use_scale (bool): Whether to scale pairwise_weight by `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`. Default: True.
  26. sub_sample (bool): Whether to utilize max pooling after pairwise function. Default: False.
  27. mode (str): Options are `gaussian`, `concatenation`, `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
  28. """
  29. def __init__(self,
  30. in_channels,
  31. reduction=2,
  32. use_scale=True,
  33. sub_sample=False,
  34. mode='embedded_gaussian'):
  35. super(NonLocal2D, self).__init__()
  36. self.in_channels = in_channels
  37. self.reduction = reduction
  38. self.use_scale = use_scale
  39. self.sub_sample = sub_sample
  40. self.mode = mode
  41. if mode not in [
  42. 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
  43. ]:
  44. raise ValueError(
  45. "Mode should be in 'gaussian', 'concatenation','embedded_gaussian' or 'dot_product'."
  46. )
  47. self.inter_channels = max(in_channels // reduction, 1)
  48. self.g = nn.Conv2D(
  49. in_channels=self.in_channels,
  50. out_channels=self.inter_channels,
  51. kernel_size=1)
  52. self.conv_out = layers.ConvBNReLU(
  53. in_channels=self.inter_channels,
  54. out_channels=self.in_channels,
  55. kernel_size=1,
  56. bias_attr=False)
  57. if self.mode != "gaussian":
  58. self.theta = nn.Conv2D(
  59. in_channels=self.in_channels,
  60. out_channels=self.inter_channels,
  61. kernel_size=1)
  62. self.phi = nn.Conv2D(
  63. in_channels=self.in_channels,
  64. out_channels=self.inter_channels,
  65. kernel_size=1)
  66. if self.mode == "concatenation":
  67. self.concat_project = layers.ConvBNReLU(
  68. in_channels=self.inter_channels * 2,
  69. out_channels=1,
  70. kernel_size=1,
  71. bias_attr=False)
  72. if self.sub_sample:
  73. max_pool_layer = nn.MaxPool2D(kernel_size=(2, 2))
  74. self.g = nn.Sequential(self.g, max_pool_layer)
  75. if self.mode != 'gaussian':
  76. self.phi = nn.Sequential(self.phi, max_pool_layer)
  77. else:
  78. self.phi = max_pool_layer
  79. def gaussian(self, theta_x, phi_x):
  80. pairwise_weight = paddle.matmul(theta_x, phi_x)
  81. pairwise_weight = F.softmax(pairwise_weight, axis=-1)
  82. return pairwise_weight
  83. def embedded_gaussian(self, theta_x, phi_x):
  84. pairwise_weight = paddle.matmul(theta_x, phi_x)
  85. if self.use_scale:
  86. pairwise_weight /= theta_x.shape[-1]**0.5
  87. pairwise_weight = F.softmax(pairwise_weight, -1)
  88. return pairwise_weight
  89. def dot_product(self, theta_x, phi_x):
  90. pairwise_weight = paddle.matmul(theta_x, phi_x)
  91. pairwise_weight /= pairwise_weight.shape[-1]
  92. return pairwise_weight
  93. def concatenation(self, theta_x, phi_x):
  94. h = theta_x.shape[2]
  95. w = phi_x.shape[3]
  96. theta_x = paddle.tile(theta_x, [1, 1, 1, w])
  97. phi_x = paddle.tile(phi_x, [1, 1, h, 1])
  98. concat_feature = paddle.concat([theta_x, phi_x], axis=1)
  99. pairwise_weight = self.concat_project(concat_feature)
  100. n, _, h, w = pairwise_weight.shape
  101. pairwise_weight = paddle.reshape(pairwise_weight, [n, h, w])
  102. pairwise_weight /= pairwise_weight.shape[-1]
  103. return pairwise_weight
  104. def forward(self, x):
  105. n, c, h, w = x.shape
  106. g_x = paddle.reshape(self.g(x), [n, self.inter_channels, -1])
  107. g_x = paddle.transpose(g_x, [0, 2, 1])
  108. if self.mode == 'gaussian':
  109. theta_x = paddle.reshape(x, [n, self.inter_channels, -1])
  110. theta_x = paddle.transpose(theta_x, [0, 2, 1])
  111. if self.sub_sample:
  112. phi_x = paddle.reshape(
  113. self.phi(x), [n, self.inter_channels, -1])
  114. else:
  115. phi_x = paddle.reshape(x, [n, self.in_channels, -1])
  116. elif self.mode == 'concatenation':
  117. theta_x = paddle.reshape(
  118. self.theta(x), [n, self.inter_channels, -1, 1])
  119. phi_x = paddle.reshape(
  120. self.phi(x), [n, self.inter_channels, 1, -1])
  121. else:
  122. theta_x = paddle.reshape(
  123. self.theta(x), [n, self.inter_channels, -1])
  124. theta_x = paddle.transpose(theta_x, [0, 2, 1])
  125. phi_x = paddle.reshape(self.phi(x), [n, self.inter_channels, -1])
  126. pairwise_func = getattr(self, self.mode)
  127. pairwise_weight = pairwise_func(theta_x, phi_x)
  128. y = paddle.matmul(pairwise_weight, g_x)
  129. y = paddle.transpose(y, [0, 2, 1])
  130. y = paddle.reshape(y, [n, self.inter_channels, h, w])
  131. output = x + self.conv_out(y)
  132. return output