ttf_head.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle import ParamAttr
  18. from paddle.nn.initializer import Constant, Normal
  19. from paddle.regularizer import L2Decay
  20. from paddlex.ppdet.core.workspace import register
  21. from paddlex.ppdet.modeling.layers import DeformableConvV2, LiteConv
  22. import numpy as np
  23. @register
  24. class HMHead(nn.Layer):
  25. """
  26. Args:
  27. ch_in (int): The channel number of input Tensor.
  28. ch_out (int): The channel number of output Tensor.
  29. num_classes (int): Number of classes.
  30. conv_num (int): The convolution number of hm_feat.
  31. dcn_head(bool): whether use dcn in head. False by default.
  32. lite_head(bool): whether use lite version. False by default.
  33. norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
  34. bn by default
  35. Return:
  36. Heatmap head output
  37. """
  38. __shared__ = ['num_classes', 'norm_type']
  39. def __init__(
  40. self,
  41. ch_in,
  42. ch_out=128,
  43. num_classes=80,
  44. conv_num=2,
  45. dcn_head=False,
  46. lite_head=False,
  47. norm_type='bn', ):
  48. super(HMHead, self).__init__()
  49. head_conv = nn.Sequential()
  50. for i in range(conv_num):
  51. name = 'conv.{}'.format(i)
  52. if lite_head:
  53. lite_name = 'hm.' + name
  54. head_conv.add_sublayer(
  55. lite_name,
  56. LiteConv(
  57. in_channels=ch_in if i == 0 else ch_out,
  58. out_channels=ch_out,
  59. norm_type=norm_type))
  60. else:
  61. if dcn_head:
  62. head_conv.add_sublayer(
  63. name,
  64. DeformableConvV2(
  65. in_channels=ch_in if i == 0 else ch_out,
  66. out_channels=ch_out,
  67. kernel_size=3,
  68. weight_attr=ParamAttr(initializer=Normal(0,
  69. 0.01))))
  70. else:
  71. head_conv.add_sublayer(
  72. name,
  73. nn.Conv2D(
  74. in_channels=ch_in if i == 0 else ch_out,
  75. out_channels=ch_out,
  76. kernel_size=3,
  77. padding=1,
  78. weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
  79. bias_attr=ParamAttr(
  80. learning_rate=2., regularizer=L2Decay(0.))))
  81. head_conv.add_sublayer(name + '.act', nn.ReLU())
  82. self.feat = head_conv
  83. bias_init = float(-np.log((1 - 0.01) / 0.01))
  84. weight_attr = None if lite_head else ParamAttr(initializer=Normal(
  85. 0, 0.01))
  86. self.head = nn.Conv2D(
  87. in_channels=ch_out,
  88. out_channels=num_classes,
  89. kernel_size=1,
  90. weight_attr=weight_attr,
  91. bias_attr=ParamAttr(
  92. learning_rate=2.,
  93. regularizer=L2Decay(0.),
  94. initializer=Constant(bias_init)))
  95. def forward(self, feat):
  96. out = self.feat(feat)
  97. out = self.head(out)
  98. return out
  99. @register
  100. class WHHead(nn.Layer):
  101. """
  102. Args:
  103. ch_in (int): The channel number of input Tensor.
  104. ch_out (int): The channel number of output Tensor.
  105. conv_num (int): The convolution number of wh_feat.
  106. dcn_head(bool): whether use dcn in head. False by default.
  107. lite_head(bool): whether use lite version. False by default.
  108. norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
  109. bn by default
  110. Return:
  111. Width & Height head output
  112. """
  113. __shared__ = ['norm_type']
  114. def __init__(self,
  115. ch_in,
  116. ch_out=64,
  117. conv_num=2,
  118. dcn_head=False,
  119. lite_head=False,
  120. norm_type='bn'):
  121. super(WHHead, self).__init__()
  122. head_conv = nn.Sequential()
  123. for i in range(conv_num):
  124. name = 'conv.{}'.format(i)
  125. if lite_head:
  126. lite_name = 'wh.' + name
  127. head_conv.add_sublayer(
  128. lite_name,
  129. LiteConv(
  130. in_channels=ch_in if i == 0 else ch_out,
  131. out_channels=ch_out,
  132. norm_type=norm_type))
  133. else:
  134. if dcn_head:
  135. head_conv.add_sublayer(
  136. name,
  137. DeformableConvV2(
  138. in_channels=ch_in if i == 0 else ch_out,
  139. out_channels=ch_out,
  140. kernel_size=3,
  141. weight_attr=ParamAttr(initializer=Normal(0,
  142. 0.01))))
  143. else:
  144. head_conv.add_sublayer(
  145. name,
  146. nn.Conv2D(
  147. in_channels=ch_in if i == 0 else ch_out,
  148. out_channels=ch_out,
  149. kernel_size=3,
  150. padding=1,
  151. weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
  152. bias_attr=ParamAttr(
  153. learning_rate=2., regularizer=L2Decay(0.))))
  154. head_conv.add_sublayer(name + '.act', nn.ReLU())
  155. weight_attr = None if lite_head else ParamAttr(initializer=Normal(
  156. 0, 0.01))
  157. self.feat = head_conv
  158. self.head = nn.Conv2D(
  159. in_channels=ch_out,
  160. out_channels=4,
  161. kernel_size=1,
  162. weight_attr=weight_attr,
  163. bias_attr=ParamAttr(
  164. learning_rate=2., regularizer=L2Decay(0.)))
  165. def forward(self, feat):
  166. out = self.feat(feat)
  167. out = self.head(out)
  168. out = F.relu(out)
  169. return out
  170. @register
  171. class TTFHead(nn.Layer):
  172. """
  173. TTFHead
  174. Args:
  175. in_channels (int): the channel number of input to TTFHead.
  176. num_classes (int): the number of classes, 80 by default.
  177. hm_head_planes (int): the channel number in heatmap head,
  178. 128 by default.
  179. wh_head_planes (int): the channel number in width & height head,
  180. 64 by default.
  181. hm_head_conv_num (int): the number of convolution in heatmap head,
  182. 2 by default.
  183. wh_head_conv_num (int): the number of convolution in width & height
  184. head, 2 by default.
  185. hm_loss (object): Instance of 'CTFocalLoss'.
  186. wh_loss (object): Instance of 'GIoULoss'.
  187. wh_offset_base (float): the base offset of width and height,
  188. 16.0 by default.
  189. down_ratio (int): the actual down_ratio is calculated by base_down_ratio
  190. (default 16) and the number of upsample layers.
  191. lite_head(bool): whether use lite version. False by default.
  192. norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
  193. bn by default
  194. ags_module(bool): whether use AGS module to reweight location feature.
  195. false by default.
  196. """
  197. __shared__ = ['num_classes', 'down_ratio', 'norm_type']
  198. __inject__ = ['hm_loss', 'wh_loss']
  199. def __init__(self,
  200. in_channels,
  201. num_classes=80,
  202. hm_head_planes=128,
  203. wh_head_planes=64,
  204. hm_head_conv_num=2,
  205. wh_head_conv_num=2,
  206. hm_loss='CTFocalLoss',
  207. wh_loss='GIoULoss',
  208. wh_offset_base=16.,
  209. down_ratio=4,
  210. dcn_head=False,
  211. lite_head=False,
  212. norm_type='bn',
  213. ags_module=False):
  214. super(TTFHead, self).__init__()
  215. self.in_channels = in_channels
  216. self.hm_head = HMHead(in_channels, hm_head_planes, num_classes,
  217. hm_head_conv_num, dcn_head, lite_head, norm_type)
  218. self.wh_head = WHHead(in_channels, wh_head_planes, wh_head_conv_num,
  219. dcn_head, lite_head, norm_type)
  220. self.hm_loss = hm_loss
  221. self.wh_loss = wh_loss
  222. self.wh_offset_base = wh_offset_base
  223. self.down_ratio = down_ratio
  224. self.ags_module = ags_module
  225. @classmethod
  226. def from_config(cls, cfg, input_shape):
  227. if isinstance(input_shape, (list, tuple)):
  228. input_shape = input_shape[0]
  229. return {'in_channels': input_shape.channels, }
  230. def forward(self, feats):
  231. hm = self.hm_head(feats)
  232. wh = self.wh_head(feats) * self.wh_offset_base
  233. return hm, wh
  234. def filter_box_by_weight(self, pred, target, weight):
  235. """
  236. Filter out boxes where ttf_reg_weight is 0, only keep positive samples.
  237. """
  238. index = paddle.nonzero(weight > 0)
  239. index.stop_gradient = True
  240. weight = paddle.gather_nd(weight, index)
  241. pred = paddle.gather_nd(pred, index)
  242. target = paddle.gather_nd(target, index)
  243. return pred, target, weight
  244. def filter_loc_by_weight(self, score, weight):
  245. index = paddle.nonzero(weight > 0)
  246. index.stop_gradient = True
  247. score = paddle.gather_nd(score, index)
  248. return score
  249. def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight):
  250. pred_hm = paddle.clip(F.sigmoid(pred_hm), 1e-4, 1 - 1e-4)
  251. hm_loss = self.hm_loss(pred_hm, target_hm)
  252. H, W = target_hm.shape[2:]
  253. mask = paddle.reshape(target_weight, [-1, H, W])
  254. avg_factor = paddle.sum(mask) + 1e-4
  255. base_step = self.down_ratio
  256. shifts_x = paddle.arange(0, W * base_step, base_step, dtype='int32')
  257. shifts_y = paddle.arange(0, H * base_step, base_step, dtype='int32')
  258. shift_y, shift_x = paddle.tensor.meshgrid([shifts_y, shifts_x])
  259. base_loc = paddle.stack([shift_x, shift_y], axis=0)
  260. base_loc.stop_gradient = True
  261. pred_boxes = paddle.concat(
  262. [
  263. 0 - pred_wh[:, 0:2, :, :] + base_loc, pred_wh[:, 2:4] +
  264. base_loc
  265. ],
  266. axis=1)
  267. pred_boxes = paddle.transpose(pred_boxes, [0, 2, 3, 1])
  268. boxes = paddle.transpose(box_target, [0, 2, 3, 1])
  269. boxes.stop_gradient = True
  270. if self.ags_module:
  271. pred_hm_max = paddle.max(pred_hm, axis=1, keepdim=True)
  272. pred_hm_max_softmax = F.softmax(pred_hm_max, axis=1)
  273. pred_hm_max_softmax = paddle.transpose(pred_hm_max_softmax,
  274. [0, 2, 3, 1])
  275. pred_hm_max_softmax = self.filter_loc_by_weight(
  276. pred_hm_max_softmax, mask)
  277. else:
  278. pred_hm_max_softmax = None
  279. pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes,
  280. mask)
  281. mask.stop_gradient = True
  282. wh_loss = self.wh_loss(
  283. pred_boxes,
  284. boxes,
  285. iou_weight=mask.unsqueeze(1),
  286. loc_reweight=pred_hm_max_softmax)
  287. wh_loss = wh_loss / avg_factor
  288. ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss}
  289. return ttf_loss