rpn_head.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import Normal
  18. from paddle.regularizer import L2Decay
  19. from paddlex.ppdet.core.workspace import register
  20. from paddlex.ppdet.modeling import ops
  21. from .anchor_generator import AnchorGenerator
  22. from .target_layer import RPNTargetAssign
  23. from .proposal_generator import ProposalGenerator
  class RPNFeat(nn.Layer):
      """
      Shared feature transform of the RPN head: one 3x3 convolution followed
      by ReLU, applied with the same weights to every input feature level.

      Args:
          in_channel (int): number of channels of each input feature map
          out_channel (int): number of channels produced by the shared conv
      """

      def __init__(self, in_channel=1024, out_channel=1024):
          super(RPNFeat, self).__init__()
          conv_weight = paddle.ParamAttr(initializer=Normal(mean=0., std=0.01))
          # A single conv layer is reused for all levels (shared weights).
          self.rpn_conv = nn.Conv2D(
              in_channels=in_channel,
              out_channels=out_channel,
              kernel_size=3,
              padding=1,
              weight_attr=conv_weight)
          # NOTE(review): skip_quant presumably tells Paddle's quantization
          # passes to leave this conv unquantized — confirm.
          self.rpn_conv.skip_quant = True

      def forward(self, feats):
          """Return a list with relu(rpn_conv(level)) for each level in feats."""
          return [F.relu(self.rpn_conv(level)) for level in feats]
  @register
  class RPNHead(nn.Layer):
      """
      Region Proposal Network head.

      Applies a shared conv (RPNFeat) to every input level, then two shared
      1x1 convs predict per-anchor objectness logits and box deltas. It then
      generates proposals and, in training mode, computes the RPN losses.

      Args:
          anchor_generator (dict|AnchorGenerator): anchor generation config,
              or an already-built instance
          rpn_target_assign (dict|RPNTargetAssign): RPN target assignment
              config, or an already-built instance
          train_proposal (dict|ProposalGenerator): proposal generation config
              used in training mode
          test_proposal (dict|ProposalGenerator): proposal generation config
              used outside training mode
          in_channel (int): channel of input feature maps which can be
              derived by from_config
      """

      # NOTE(review): the default dicts are built once at import time. They are
      # only unpacked with ** below and never mutated, so sharing them is safe.
      def __init__(self,
                   anchor_generator=AnchorGenerator().__dict__,
                   rpn_target_assign=RPNTargetAssign().__dict__,
                   train_proposal=ProposalGenerator(12000, 2000).__dict__,
                   test_proposal=ProposalGenerator().__dict__,
                   in_channel=1024):
          super(RPNHead, self).__init__()
          self.anchor_generator = anchor_generator
          self.rpn_target_assign = rpn_target_assign
          self.train_proposal = train_proposal
          self.test_proposal = test_proposal
          # Config dicts are turned into objects; prebuilt objects are kept.
          if isinstance(anchor_generator, dict):
              self.anchor_generator = AnchorGenerator(**anchor_generator)
          if isinstance(rpn_target_assign, dict):
              self.rpn_target_assign = RPNTargetAssign(**rpn_target_assign)
          if isinstance(train_proposal, dict):
              self.train_proposal = ProposalGenerator(**train_proposal)
          if isinstance(test_proposal, dict):
              self.test_proposal = ProposalGenerator(**test_proposal)

          num_anchors = self.anchor_generator.num_anchors
          self.rpn_feat = RPNFeat(in_channel, in_channel)
          # rpn head is shared with each level
          # rpn roi classification scores: one logit per anchor
          self.rpn_rois_score = nn.Conv2D(
              in_channels=in_channel,
              out_channels=num_anchors,
              kernel_size=1,
              padding=0,
              weight_attr=paddle.ParamAttr(initializer=Normal(
                  mean=0., std=0.01)))
          self.rpn_rois_score.skip_quant = True
          # rpn roi bbox regression deltas: four values per anchor
          self.rpn_rois_delta = nn.Conv2D(
              in_channels=in_channel,
              out_channels=4 * num_anchors,
              kernel_size=1,
              padding=0,
              weight_attr=paddle.ParamAttr(initializer=Normal(
                  mean=0., std=0.01)))
          self.rpn_rois_delta.skip_quant = True

      @classmethod
      def from_config(cls, cfg, input_shape):
          """Derive `in_channel` from the (first) input feature shape."""
          # FPN share same rpn head
          if isinstance(input_shape, (list, tuple)):
              input_shape = input_shape[0]
          return {'in_channel': input_shape.channels}

      def forward(self, feats, inputs):
          """
          Args:
              feats (list[Tensor]): multi-level input feature maps
              inputs (dict): batch inputs. 'im_shape' is read here, and the
                  whole dict is forwarded to target assignment in training.

          Returns:
              tuple: (rois, rois_num, loss). rois is a list with one proposal
              Tensor per processed image. rois_num concatenates the
              per-image proposal counts. loss is the dict from get_loss in
              training mode, None otherwise.
          """
          rpn_feats = self.rpn_feat(feats)
          scores = []
          deltas = []
          for rpn_feat in rpn_feats:
              scores.append(self.rpn_rois_score(rpn_feat))
              deltas.append(self.rpn_rois_delta(rpn_feat))

          anchors = self.anchor_generator(rpn_feats)

          # TODO: Fix batch_size > 1 when testing.
          # Outside training only the first image of the batch is processed.
          batch_size = inputs['im_shape'].shape[0] if self.training else 1
          rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs,
                                              batch_size)
          if not self.training:
              return rois, rois_num, None
          loss = self.get_loss(scores, deltas, anchors, inputs)
          return rois, rois_num, loss

      def _gen_proposal(self, scores, bbox_deltas, anchors, inputs, batch_size):
          """
          Generate proposals image by image across all levels.

          Args:
              scores (list[Tensor]): Multi-level scores prediction
              bbox_deltas (list[Tensor]): Multi-level deltas prediction
              anchors (list[Tensor]): Multi-level anchors
              inputs (dict): batch inputs; only 'im_shape' is read
              batch_size (int): number of leading images to process

          Returns:
              tuple: (list of per-image proposal Tensors, concatenated
              per-image proposal counts)
          """
          prop_gen = self.train_proposal if self.training else self.test_proposal
          im_shape = inputs['im_shape']
          # Collect multi-level proposals for each batch
          # Get 'topk' of them as final output
          bs_rois_collect = []
          bs_rois_num_collect = []
          # Generate proposals for each level and each batch.
          # Discard batch-computing to avoid sorting bbox cross different batches.
          for i in range(batch_size):
              rpn_rois_list = []
              rpn_prob_list = []
              for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
                                                      anchors):
                  rpn_rois, rpn_rois_prob, _, post_nms_top_n = prop_gen(
                      scores=rpn_score[i:i + 1],
                      bbox_deltas=rpn_delta[i:i + 1],
                      anchors=anchor,
                      im_shape=im_shape[i:i + 1])
                  # Levels without any proposal are left out of the merge.
                  if rpn_rois.shape[0] > 0:
                      rpn_rois_list.append(rpn_rois)
                      rpn_prob_list.append(rpn_rois_prob)

              if not rpn_rois_list:
                  # Every level came back empty for this image. Pass the last
                  # level's (empty) output through instead of indexing or
                  # concatenating an empty list, which used to crash here.
                  topk_rois = rpn_rois
              elif len(scores) > 1:
                  # Multi-level: keep the post_nms_top_n highest-scoring
                  # proposals across all levels.
                  rpn_rois = paddle.concat(rpn_rois_list)
                  rpn_prob = paddle.concat(rpn_prob_list).flatten()
                  if rpn_prob.shape[0] > post_nms_top_n:
                      _, topk_inds = paddle.topk(rpn_prob, post_nms_top_n)
                      topk_rois = paddle.gather(rpn_rois, topk_inds)
                  else:
                      topk_rois = rpn_rois
              else:
                  topk_rois = rpn_rois_list[0]
              bs_rois_collect.append(topk_rois)
              bs_rois_num_collect.append(paddle.shape(topk_rois)[0])
          bs_rois_num_collect = paddle.concat(bs_rois_num_collect)
          return bs_rois_collect, bs_rois_num_collect

      @staticmethod
      def _flatten_levels(preds, channels):
          """
          Concatenate multi-level NCHW predictions into (N, -1, channels).

          Each level is transposed to NHWC and reshaped to
          (N, H*W*(C/channels), channels), then all levels are concatenated
          along axis 1.
          """
          # NOTE(review): the resulting (h, w, anchor) order presumably matches
          # the anchor order produced by anchor_generator — confirm.
          flat = [
              paddle.reshape(
                  paddle.transpose(
                      v, perm=[0, 2, 3, 1]),
                  shape=(v.shape[0], -1, channels)) for v in preds
          ]
          return paddle.concat(flat, axis=1)

      def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
          """
          Compute the RPN classification and regression losses.

          Args:
              pred_scores (list[Tensor]): Multi-level scores prediction
              pred_deltas (list[Tensor]): Multi-level deltas prediction
              anchors (list[Tensor]): Multi-level anchors
              inputs (dict): ground truth info, passed as-is to
                  rpn_target_assign

          Returns:
              dict with two entries, both divided by the `norm` value returned
              by rpn_target_assign:
                  'loss_rpn_cls': summed sigmoid cross-entropy over anchors
                      whose target is >= 0
                  'loss_rpn_reg': summed L1 distance over anchors whose
                      target is 1
          """
          anchors = paddle.concat(
              [paddle.reshape(a, shape=(-1, 4)) for a in anchors])
          scores = self._flatten_levels(pred_scores, 1)
          deltas = self._flatten_levels(pred_deltas, 4)

          # The second returned value (bbox targets) is not used by this loss.
          score_tgt, _, loc_tgt, norm = self.rpn_target_assign(inputs, anchors)

          scores = paddle.reshape(x=scores, shape=(-1, ))
          deltas = paddle.reshape(x=deltas, shape=(-1, 4))

          score_tgt = paddle.concat(score_tgt)
          score_tgt.stop_gradient = True

          # Target 1 marks anchors used for regression. Anchors with a
          # negative target are excluded from the classification loss.
          pos_ind = paddle.nonzero(score_tgt == 1)
          valid_ind = paddle.nonzero(score_tgt >= 0)

          # cls loss
          if valid_ind.shape[0] == 0:
              loss_rpn_cls = paddle.zeros([1], dtype='float32')
          else:
              score_pred = paddle.gather(scores, valid_ind)
              score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
              score_label.stop_gradient = True
              loss_rpn_cls = F.binary_cross_entropy_with_logits(
                  logit=score_pred, label=score_label, reduction="sum")

          # reg loss
          if pos_ind.shape[0] == 0:
              loss_rpn_reg = paddle.zeros([1], dtype='float32')
          else:
              loc_pred = paddle.gather(deltas, pos_ind)
              loc_tgt = paddle.gather(paddle.concat(loc_tgt), pos_ind)
              loc_tgt.stop_gradient = True
              loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()

          return {
              'loss_rpn_cls': loss_rpn_cls / norm,
              'loss_rpn_reg': loss_rpn_reg / norm
          }