# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlex.paddleseg import utils
from paddlex.paddleseg.cvlibs import manager, param_init
from paddlex.paddleseg.models import layers


@manager.MODELS.add_component
class OCRNet(nn.Layer):
    """
    The OCRNet implementation based on PaddlePaddle.

    The original article refers to
    Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
    (https://arxiv.org/pdf/1909.11065.pdf).

    Args:
        num_classes (int): The number of target classes.
        backbone (paddle.nn.Layer): The backbone network.
        backbone_indices (tuple): A tuple indicating which backbone outputs to use.
            It can hold one or two values. With two values, the first selects the
            deep-supervision feature for the auxiliary head and the second selects
            the input of the pixel representation. With one value, that feature
            serves both roles.
        ocr_mid_channels (int, optional): The number of middle channels in OCRHead. Default: 512.
        ocr_key_channels (int, optional): The number of key channels in ObjectAttentionBlock. Default: 256.
        align_corners (bool): An argument of F.interpolate. It should be set to False
            when the output feature size is even, e.g. 1024x512; otherwise True,
            e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of the pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices,
                 ocr_mid_channels=512,
                 ocr_key_channels=256,
                 align_corners=False,
                 pretrained=None):
        super().__init__()

        self.backbone = backbone
        self.backbone_indices = backbone_indices
        in_channels = [
            self.backbone.feat_channels[i] for i in backbone_indices
        ]

        self.head = OCRHead(
            num_classes=num_classes,
            in_channels=in_channels,
            ocr_mid_channels=ocr_mid_channels,
            ocr_key_channels=ocr_key_channels)

        self.align_corners = align_corners
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):
        feats = self.backbone(x)
        feats = [feats[i] for i in self.backbone_indices]
        logit_list = self.head(feats)

        if not self.training:
            logit_list = [logit_list[0]]

        logit_list = [
            F.interpolate(
                logit,
                paddle.shape(x)[2:],
                mode='bilinear',
                align_corners=self.align_corners) for logit in logit_list
        ]
        return logit_list

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
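
# A minimal usage sketch, assuming HRNet_W18 is available among the registered
# backbones and exposes `feat_channels` (any backbone with that attribute works
# the same way; the import path and class name here are assumptions):
#
#     from paddlex.paddleseg.models.backbones import HRNet_W18
#
#     backbone = HRNet_W18()
#     model = OCRNet(num_classes=19, backbone=backbone, backbone_indices=(0, ))
#     model.eval()
#     logit, = model(paddle.rand([1, 3, 512, 1024]))  # -> (1, 19, 512, 1024)
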
class OCRHead(nn.Layer):
    """
    The Object contextual representation head.

    Args:
        num_classes (int): The number of target classes.
        in_channels (tuple): The number of channels of each input feature.
        ocr_mid_channels (int, optional): The number of middle channels in OCRHead. Default: 512.
        ocr_key_channels (int, optional): The number of key channels in ObjectAttentionBlock. Default: 256.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 ocr_mid_channels=512,
                 ocr_key_channels=256):
        super().__init__()

        self.num_classes = num_classes
        self.spatial_gather = SpatialGatherBlock(ocr_mid_channels, num_classes)
        self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                            ocr_mid_channels)

        self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]

        self.conv3x3_ocr = layers.ConvBNReLU(
            in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
        self.cls_head = nn.Conv2D(ocr_mid_channels, self.num_classes, 1)

        self.aux_head = nn.Sequential(
            layers.ConvBNReLU(in_channels[self.indices[0]],
                              in_channels[self.indices[0]], 1),
            nn.Conv2D(in_channels[self.indices[0]], self.num_classes, 1))

        self.init_weight()

    def forward(self, feat_list):
        feat_shallow, feat_deep = feat_list[self.indices[0]], feat_list[
            self.indices[1]]

        soft_regions = self.aux_head(feat_shallow)
        pixels = self.conv3x3_ocr(feat_deep)

        object_regions = self.spatial_gather(pixels, soft_regions)
        ocr = self.spatial_ocr(pixels, object_regions)

        logit = self.cls_head(ocr)
        return [logit, soft_regions]

    def init_weight(self):
        """Initialize the parameters of model parts."""
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                param_init.normal_init(sublayer.weight, std=0.001)
            elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
                param_init.constant_init(sublayer.weight, value=1.0)
                param_init.constant_init(sublayer.bias, value=0.0)
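
# Shape walk-through of OCRHead.forward for a single backbone feature f of
# shape (n, c, h, w), following the code above (sizes illustrative):
#   soft_regions   = aux_head(f)                 -> (n, num_classes, h, w)
#   pixels         = conv3x3_ocr(f)              -> (n, ocr_mid_channels, h, w)
#   object_regions = spatial_gather(pixels, ...) -> (n, ocr_mid_channels, num_classes, 1)
#   ocr            = spatial_ocr(pixels, ...)    -> (n, ocr_mid_channels, h, w)
#   logit          = cls_head(ocr)               -> (n, num_classes, h, w)
# Both logits are supervised during training; at inference OCRNet.forward
# keeps only the first entry of the returned list.
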
class SpatialGatherBlock(nn.Layer):
    """Aggregation layer to compute the pixel-region representation."""

    def __init__(self, pixels_channels, regions_channels):
        super().__init__()
        self.pixels_channels = pixels_channels
        self.regions_channels = regions_channels

    def forward(self, pixels, regions):
        # pixels: from (n, c, h, w) to (n, h*w, c)
        pixels = paddle.reshape(pixels, (0, self.pixels_channels, -1))
        pixels = paddle.transpose(pixels, (0, 2, 1))

        # regions: from (n, k, h, w) to (n, k, h*w)
        regions = paddle.reshape(regions, (0, self.regions_channels, -1))
        regions = F.softmax(regions, axis=2)

        # feats: from (n, k, c) to (n, c, k, 1)
        feats = paddle.bmm(regions, pixels)
        feats = paddle.transpose(feats, (0, 2, 1))
        feats = paddle.unsqueeze(feats, axis=-1)

        return feats
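
# Each object-region vector is thus a softmax-weighted average of the pixel
# features, with the k-th soft region map supplying the spatial weights:
#     f_k = sum_i w_{k,i} * pixels_i,   w_{k,:} = softmax over the h*w axis
# This is the object region representation described in the OCR paper.
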
class SpatialOCRModule(nn.Layer):
    """Aggregate the global object representation to update the representation for each pixel."""

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 dropout_rate=0.1):
        super().__init__()
        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.conv1x1 = nn.Sequential(
            layers.ConvBNReLU(2 * in_channels, out_channels, 1),
            nn.Dropout2D(dropout_rate))

    def forward(self, pixels, regions):
        context = self.attention_block(pixels, regions)
        feats = paddle.concat([context, pixels], axis=1)
        feats = self.conv1x1(feats)
        return feats
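
# Note: the concat doubles the channel count (attention context plus the
# original pixel features), which is why conv1x1 above takes 2 * in_channels
# input channels before projecting back down to out_channels.
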
class ObjectAttentionBlock(nn.Layer):
    """An attention module in which pixel features (queries) attend to
    object-region representations (keys and values)."""

    def __init__(self, in_channels, key_channels):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels

        self.f_pixel = nn.Sequential(
            layers.ConvBNReLU(in_channels, key_channels, 1),
            layers.ConvBNReLU(key_channels, key_channels, 1))

        self.f_object = nn.Sequential(
            layers.ConvBNReLU(in_channels, key_channels, 1),
            layers.ConvBNReLU(key_channels, key_channels, 1))

        self.f_down = layers.ConvBNReLU(in_channels, key_channels, 1)

        self.f_up = layers.ConvBNReLU(key_channels, in_channels, 1)

    def forward(self, x, proxy):
        x_shape = paddle.shape(x)
        # query: from (n, c1, h1, w1) to (n, h1*w1, key_channels)
        query = self.f_pixel(x)
        query = paddle.reshape(query, (0, self.key_channels, -1))
        query = paddle.transpose(query, (0, 2, 1))

        # key: from (n, c2, h2, w2) to (n, key_channels, h2*w2)
        key = self.f_object(proxy)
        key = paddle.reshape(key, (0, self.key_channels, -1))

        # value: from (n, c2, h2, w2) to (n, h2*w2, key_channels)
        value = self.f_down(proxy)
        value = paddle.reshape(value, (0, self.key_channels, -1))
        value = paddle.transpose(value, (0, 2, 1))

        # sim_map: scaled dot-product attention weights, (n, h1*w1, h2*w2)
        sim_map = paddle.bmm(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        # context: from (n, h1*w1, key_channels) to (n, in_channels, h1, w1)
        context = paddle.bmm(sim_map, value)
        context = paddle.transpose(context, (0, 2, 1))
        context = paddle.reshape(
            context, (0, self.key_channels, x_shape[2], x_shape[3]))
        context = self.f_up(context)

        return context
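

if __name__ == '__main__':
    # A quick, self-contained shape check of the gather/attention pipeline;
    # the channel and region counts below are arbitrary illustrative values,
    # not part of the original module.
    pixels = paddle.rand([2, 512, 32, 64])    # (n, c, h, w) pixel features
    regions = paddle.rand([2, 19, 32, 64])    # (n, k, h, w) soft region maps

    gather = SpatialGatherBlock(512, 19)
    ocr_module = SpatialOCRModule(512, 256, 512)

    object_regions = gather(pixels, regions)     # -> (2, 512, 19, 1)
    out = ocr_module(pixels, object_regions)     # -> (2, 512, 32, 64)
    print(object_regions.shape, out.shape)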