cascade_head.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import Normal
  18. from paddlex.ppdet.core.workspace import register
  19. from .bbox_head import BBoxHead, TwoFCHead, XConvNormHead
  20. from .roi_extractor import RoIAlign
  21. from ..shape_spec import ShapeSpec
  22. from ..bbox_utils import delta2bbox, clip_bbox, nonempty_bbox
  23. from ..cls_utils import _get_class_default_kwargs
  24. __all__ = ['CascadeTwoFCHead', 'CascadeXConvNormHead', 'CascadeHead']
  25. @register
  26. class CascadeTwoFCHead(nn.Layer):
  27. __shared__ = ['num_cascade_stage']
  28. """
  29. Cascade RCNN bbox head with Two fc layers to extract feature
  30. Args:
  31. in_channel (int): Input channel which can be derived by from_config
  32. out_channel (int): Output channel
  33. resolution (int): Resolution of input feature map, default 7
  34. num_cascade_stage (int): The number of cascade stage, default 3
  35. """
  36. def __init__(self,
  37. in_channel=256,
  38. out_channel=1024,
  39. resolution=7,
  40. num_cascade_stage=3):
  41. super(CascadeTwoFCHead, self).__init__()
  42. self.in_channel = in_channel
  43. self.out_channel = out_channel
  44. self.head_list = []
  45. for stage in range(num_cascade_stage):
  46. head_per_stage = self.add_sublayer(
  47. str(stage), TwoFCHead(in_channel, out_channel, resolution))
  48. self.head_list.append(head_per_stage)
  49. @classmethod
  50. def from_config(cls, cfg, input_shape):
  51. s = input_shape
  52. s = s[0] if isinstance(s, (list, tuple)) else s
  53. return {'in_channel': s.channels}
  54. @property
  55. def out_shape(self):
  56. return [ShapeSpec(channels=self.out_channel, )]
  57. def forward(self, rois_feat, stage=0):
  58. out = self.head_list[stage](rois_feat)
  59. return out
  60. @register
  61. class CascadeXConvNormHead(nn.Layer):
  62. __shared__ = ['norm_type', 'freeze_norm', 'num_cascade_stage']
  63. """
  64. Cascade RCNN bbox head with serveral convolution layers
  65. Args:
  66. in_channel (int): Input channels which can be derived by from_config
  67. num_convs (int): The number of conv layers
  68. conv_dim (int): The number of channels for the conv layers
  69. out_channel (int): Output channels
  70. resolution (int): Resolution of input feature map
  71. norm_type (string): Norm type, bn, gn, sync_bn are available,
  72. default `gn`
  73. freeze_norm (bool): Whether to freeze the norm
  74. num_cascade_stage (int): The number of cascade stage, default 3
  75. """
  76. def __init__(self,
  77. in_channel=256,
  78. num_convs=4,
  79. conv_dim=256,
  80. out_channel=1024,
  81. resolution=7,
  82. norm_type='gn',
  83. freeze_norm=False,
  84. num_cascade_stage=3):
  85. super(CascadeXConvNormHead, self).__init__()
  86. self.in_channel = in_channel
  87. self.out_channel = out_channel
  88. self.head_list = []
  89. for stage in range(num_cascade_stage):
  90. head_per_stage = self.add_sublayer(
  91. str(stage),
  92. XConvNormHead(
  93. in_channel,
  94. num_convs,
  95. conv_dim,
  96. out_channel,
  97. resolution,
  98. norm_type,
  99. freeze_norm,
  100. stage_name='stage{}_'.format(stage)))
  101. self.head_list.append(head_per_stage)
  102. @classmethod
  103. def from_config(cls, cfg, input_shape):
  104. s = input_shape
  105. s = s[0] if isinstance(s, (list, tuple)) else s
  106. return {'in_channel': s.channels}
  107. @property
  108. def out_shape(self):
  109. return [ShapeSpec(channels=self.out_channel, )]
  110. def forward(self, rois_feat, stage=0):
  111. out = self.head_list[stage](rois_feat)
  112. return out
  113. @register
  114. class CascadeHead(BBoxHead):
  115. __shared__ = ['num_classes', 'num_cascade_stages']
  116. __inject__ = ['bbox_assigner', 'bbox_loss']
  117. """
  118. Cascade RCNN bbox head
  119. Args:
  120. head (nn.Layer): Extract feature in bbox head
  121. in_channel (int): Input channel after RoI extractor
  122. roi_extractor (object): The module of RoI Extractor
  123. bbox_assigner (object): The module of Box Assigner, label and sample the
  124. box.
  125. num_classes (int): The number of classes
  126. bbox_weight (List[List[float]]): The weight to get the decode box and the
  127. length of weight is the number of cascade stage
  128. num_cascade_stages (int): THe number of stage to refine the box
  129. """
  130. def __init__(self,
  131. head,
  132. in_channel,
  133. roi_extractor=_get_class_default_kwargs(RoIAlign),
  134. bbox_assigner='BboxAssigner',
  135. num_classes=80,
  136. bbox_weight=[[10., 10., 5., 5.], [20.0, 20.0, 10.0, 10.0],
  137. [30.0, 30.0, 15.0, 15.0]],
  138. num_cascade_stages=3,
  139. bbox_loss=None):
  140. nn.Layer.__init__(self, )
  141. self.head = head
  142. self.roi_extractor = roi_extractor
  143. if isinstance(roi_extractor, dict):
  144. self.roi_extractor = RoIAlign(**roi_extractor)
  145. self.bbox_assigner = bbox_assigner
  146. self.num_classes = num_classes
  147. self.bbox_weight = bbox_weight
  148. self.num_cascade_stages = num_cascade_stages
  149. self.bbox_loss = bbox_loss
  150. self.bbox_score_list = []
  151. self.bbox_delta_list = []
  152. for i in range(num_cascade_stages):
  153. score_name = 'bbox_score_stage{}'.format(i)
  154. delta_name = 'bbox_delta_stage{}'.format(i)
  155. bbox_score = self.add_sublayer(
  156. score_name,
  157. nn.Linear(
  158. in_channel,
  159. self.num_classes + 1,
  160. weight_attr=paddle.ParamAttr(initializer=Normal(
  161. mean=0.0, std=0.01))))
  162. bbox_delta = self.add_sublayer(
  163. delta_name,
  164. nn.Linear(
  165. in_channel,
  166. 4,
  167. weight_attr=paddle.ParamAttr(initializer=Normal(
  168. mean=0.0, std=0.001))))
  169. self.bbox_score_list.append(bbox_score)
  170. self.bbox_delta_list.append(bbox_delta)
  171. self.assigned_label = None
  172. self.assigned_rois = None
  173. def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
  174. """
  175. body_feats (list[Tensor]): Feature maps from backbone
  176. rois (Tensor): RoIs generated from RPN module
  177. rois_num (Tensor): The number of RoIs in each image
  178. inputs (dict{Tensor}): The ground-truth of image
  179. """
  180. targets = []
  181. if self.training:
  182. rois, rois_num, targets = self.bbox_assigner(rois, rois_num,
  183. inputs)
  184. targets_list = [targets]
  185. self.assigned_rois = (rois, rois_num)
  186. self.assigned_targets = targets
  187. pred_bbox = None
  188. head_out_list = []
  189. for i in range(self.num_cascade_stages):
  190. if i > 0:
  191. rois, rois_num = self._get_rois_from_boxes(pred_bbox,
  192. inputs['im_shape'])
  193. if self.training:
  194. rois, rois_num, targets = self.bbox_assigner(
  195. rois, rois_num, inputs, i, is_cascade=True)
  196. targets_list.append(targets)
  197. rois_feat = self.roi_extractor(body_feats, rois, rois_num)
  198. bbox_feat = self.head(rois_feat, i)
  199. scores = self.bbox_score_list[i](bbox_feat)
  200. deltas = self.bbox_delta_list[i](bbox_feat)
  201. head_out_list.append([scores, deltas, rois])
  202. pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i])
  203. if self.training:
  204. loss = {}
  205. for stage, value in enumerate(zip(head_out_list, targets_list)):
  206. (scores, deltas, rois), targets = value
  207. loss_stage = self.get_loss(scores, deltas, targets, rois,
  208. self.bbox_weight[stage])
  209. for k, v in loss_stage.items():
  210. loss[k + "_stage{}".format(
  211. stage)] = v / self.num_cascade_stages
  212. return loss, bbox_feat
  213. else:
  214. scores, deltas, self.refined_rois = self.get_prediction(
  215. head_out_list)
  216. return (deltas, scores), self.head
  217. def _get_rois_from_boxes(self, boxes, im_shape):
  218. rois = []
  219. for i, boxes_per_image in enumerate(boxes):
  220. clip_box = clip_bbox(boxes_per_image, im_shape[i])
  221. if self.training:
  222. keep = nonempty_bbox(clip_box)
  223. if keep.shape[0] == 0:
  224. keep = paddle.zeros([1], dtype='int32')
  225. clip_box = paddle.gather(clip_box, keep)
  226. rois.append(clip_box)
  227. rois_num = paddle.concat([paddle.shape(r)[0] for r in rois])
  228. return rois, rois_num
  229. def _get_pred_bbox(self, deltas, proposals, weights):
  230. pred_proposals = paddle.concat(proposals) if len(
  231. proposals) > 1 else proposals[0]
  232. pred_bbox = delta2bbox(deltas, pred_proposals, weights)
  233. pred_bbox = paddle.reshape(pred_bbox, [-1, deltas.shape[-1]])
  234. num_prop = []
  235. for p in proposals:
  236. num_prop.append(p.shape[0])
  237. return pred_bbox.split(num_prop)
  238. def get_prediction(self, head_out_list):
  239. """
  240. head_out_list(List[Tensor]): scores, deltas, rois
  241. """
  242. pred_list = []
  243. scores_list = [F.softmax(head[0]) for head in head_out_list]
  244. scores = paddle.add_n(scores_list) / self.num_cascade_stages
  245. # Get deltas and rois from the last stage
  246. _, deltas, rois = head_out_list[-1]
  247. return scores, deltas, rois
  248. def get_refined_rois(self, ):
  249. return self.refined_rois