s2anet_head.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/models/anchor_heads_rotated/s2anet_head.py
  16. import paddle
  17. from paddle import ParamAttr
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle.nn.initializer import Normal, Constant
  21. from paddlex.ppdet.core.workspace import register
  22. from paddlex.ppdet.modeling import ops
  23. from paddlex.ppdet.modeling import bbox_utils
  24. from paddlex.ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
  25. from ..cls_utils import _get_class_default_kwargs
  26. import numpy as np
  class S2ANetAnchorGenerator(nn.Layer):
      """
      Axis-aligned anchor generator (paddle implementation) for one S2ANet
      feature level.

      Args:
          base_size (int|float): side length of the square reference box;
              S2ANetHead passes the level stride here.
          scales (list): anchor scales, stored as a tensor.
          ratios (list): anchor aspect ratios (height / width), stored as a
              tensor.
          scale_major (bool): order in which ratios and scales are
              multiplied into the base size.
          ctr (tuple|None): (x, y) centre of the base anchors; ``None`` uses
              ``0.5 * (base_size - 1)`` for both coordinates.

      NOTE(review): ratios and scales are combined with element-wise
      broadcasting (not an outer product) and ``forward`` reshapes its output
      to ``[feat_h * feat_w, 4]``, so in practice only one base anchor per
      location (e.g. a single scale and a single ratio) is supported.
      """

      def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
          super(S2ANetAnchorGenerator, self).__init__()
          self.base_size = base_size
          self.scales = paddle.to_tensor(scales)
          self.ratios = paddle.to_tensor(ratios)
          self.scale_major = scale_major
          self.ctr = ctr
          # Reads the attributes set above, so it has to come last.
          self.base_anchors = self.gen_base_anchors()

      @property
      def num_base_anchors(self):
          """Number of base anchors, i.e. rows of ``self.base_anchors``."""
          return self.base_anchors.shape[0]

      def gen_base_anchors(self):
          """Return the rounded base anchors as (x1, y1, x2, y2) rows."""
          size = self.base_size
          if self.ctr is None:
              cx = 0.5 * (size - 1)
              cy = 0.5 * (size - 1)
          else:
              cx, cy = self.ctr

          h_ratios = paddle.sqrt(self.ratios)
          w_ratios = 1 / h_ratios
          # The multiplication order depends on ``scale_major`` and is kept
          # as-is so floating-point results do not change.
          if self.scale_major:
              widths = (size * w_ratios * self.scales).reshape([-1])
              heights = (size * h_ratios * self.scales).reshape([-1])
          else:
              widths = (size * self.scales * w_ratios).reshape([-1])
              heights = (size * self.scales * h_ratios).reshape([-1])

          half_w = 0.5 * (widths - 1)
          half_h = 0.5 * (heights - 1)
          corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
          return paddle.round(paddle.stack(corners, axis=-1))

      def _meshgrid(self, x, y, row_major=True):
          """Flattened meshgrid of ``x`` (columns) and ``y`` (rows).

          Returns ``(xx, yy)`` when ``row_major`` is True, else ``(yy, xx)``.
          """
          yy, xx = paddle.meshgrid(y, x)
          xx = xx.reshape([-1])
          yy = yy.reshape([-1])
          return (xx, yy) if row_major else (yy, xx)

      def forward(self, featmap_size, stride=16):
          """Shift the base anchors to every location of a feature map.

          Args:
              featmap_size: (feat_h, feat_w) of the feature map.
              stride: step in input pixels between neighbouring locations
                  (S2ANetHead passes an element of a paddle tensor).

          Returns:
              Tensor: ``[feat_h * feat_w, 4]`` anchors, i.e. the feature-map
              grid projected back onto the input image.
          """
          feat_h, feat_w = featmap_size[0], featmap_size[1]
          shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
          shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
          xs, ys = self._meshgrid(shift_x, shift_y)
          shifts = paddle.stack([xs, ys, xs, ys], axis=-1)
          anchors = self.base_anchors + shifts
          return anchors.reshape([feat_h * feat_w, 4])

      def valid_flags(self, featmap_size, valid_size):
          """Return a flat int32 mask with one entry per (location, base
          anchor); 1 marks locations inside ``valid_size``."""
          feat_h, feat_w = featmap_size
          valid_h, valid_w = valid_size
          assert valid_h <= feat_h and valid_w <= feat_w
          mask_x = paddle.zeros([feat_w], dtype='int32')
          mask_y = paddle.zeros([feat_h], dtype='int32')
          mask_x[:valid_w] = 1
          mask_y[:valid_h] = 1
          mask_xx, mask_yy = self._meshgrid(mask_x, mask_y)
          valid = paddle.reshape(mask_xx & mask_yy, [-1, 1])
          valid = paddle.expand(valid, [-1, self.num_base_anchors])
          return valid.reshape([-1])
  class AlignConv(nn.Layer):
      """Alignment convolution: a ``DeformConv2D`` whose sampling offsets are
      derived from rotated anchors, followed by ReLU.

      Args:
          in_channels (int): input channels.
          out_channels (int): output channels.
          kernel_size (int): square kernel size (padding keeps H and W).
          groups (int): convolution groups.
      """

      def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
          super(AlignConv, self).__init__()
          self.kernel_size = kernel_size
          self.align_conv = paddle.vision.ops.DeformConv2D(
              in_channels,
              out_channels,
              kernel_size=self.kernel_size,
              padding=(self.kernel_size - 1) // 2,
              groups=groups,
              weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
              bias_attr=None)

      @paddle.no_grad()
      def get_offset(self, anchors, featmap_size, stride):
          """Deformable-conv offsets that move each kernel tap onto the
          matching point of its rotated anchor.

          Args:
              anchors: rotated anchors, reshaped to [M, 5] as
                  (xc, yc, w, h, angle) in input pixels.
              featmap_size: (feat_h, feat_w).
              stride: level stride, used to convert to feature-map units.

          Returns:
              Tensor: offsets of shape ``[1, 2 * k * k, feat_h, feat_w]``,
              stored as (dy, dx) pairs per kernel tap.

          NOTE(review): the reshape to ``[feat_h * feat_w, ...]`` requires
          exactly one anchor per location, i.e. M == feat_h * feat_w
          (batch size 1).
          """
          anchors = paddle.reshape(anchors, [-1, 5])  # (NA, 5)
          dtype = anchors.dtype
          feat_h, feat_w = featmap_size[0], featmap_size[1]
          ks = self.kernel_size
          pad = (ks - 1) // 2

          # Kernel-relative tap positions of a regular ks x ks convolution.
          tap_idx = paddle.arange(-pad, pad + 1, dtype=dtype)
          tap_y, tap_x = paddle.meshgrid(tap_idx, tap_idx)
          tap_x = paddle.reshape(tap_x, [-1])
          tap_y = paddle.reshape(tap_y, [-1])

          # Absolute sampling locations of the regular conv at every pixel.
          grid_y, grid_x = paddle.meshgrid(
              paddle.arange(0, feat_h, dtype=dtype),
              paddle.arange(0, feat_w, dtype=dtype))
          x_conv = paddle.reshape(grid_x, [-1, 1]) + tap_x
          y_conv = paddle.reshape(grid_y, [-1, 1]) + tap_y

          # Sampling locations spread over each rotated anchor, expressed in
          # feature-map units.
          x_ctr, y_ctr, w, h, angle = [
              paddle.reshape(anchors[:, col], [-1, 1]) for col in range(5)
          ]
          x_ctr = x_ctr / stride
          y_ctr = y_ctr / stride
          dw = (w / stride) / ks
          dh = (h / stride) / ks
          cos, sin = paddle.cos(angle), paddle.sin(angle)
          x, y = dw * tap_x, dh * tap_y
          x_anchor = (cos * x - sin * y) + x_ctr
          y_anchor = (sin * x + cos * y) + y_ctr

          # Offset field; (y, x) order per tap.
          offset = paddle.stack(
              [y_anchor - y_conv, x_anchor - x_conv], axis=-1)
          offset = paddle.reshape(offset, [feat_h * feat_w, ks * ks * 2])
          offset = paddle.transpose(offset, [1, 0])
          return paddle.reshape(offset, [1, ks * ks * 2, feat_h, feat_w])

      def forward(self, x, refine_anchors, featmap_size, stride):
          """Apply the anchor-guided deformable conv and ReLU to ``x``."""
          offset = self.get_offset(refine_anchors, featmap_size, stride)
          return F.relu(self.align_conv(x, offset))
  175. @register
  176. class S2ANetHead(nn.Layer):
  177. """
  178. S2Anet head
  179. Args:
  180. stacked_convs (int): number of stacked_convs
  181. feat_in (int): input channels of feat
  182. feat_out (int): output channels of feat
  183. num_classes (int): num_classes
  184. anchor_strides (list): stride of anchors
  185. anchor_scales (list): scale of anchors
  186. anchor_ratios (list): ratios of anchors
  187. target_means (list): target_means
  188. target_stds (list): target_stds
  189. align_conv_type (str): align_conv_type ['Conv', 'AlignConv']
  190. align_conv_size (int): kernel size of align_conv
  191. use_sigmoid_cls (bool): use sigmoid_cls or not
  192. reg_loss_weight (list): loss weight for regression
  193. """
  194. __shared__ = ['num_classes']
  195. __inject__ = ['anchor_assign']
  def __init__(self,
               stacked_convs=2,
               feat_in=256,
               feat_out=256,
               num_classes=15,
               anchor_strides=[8, 16, 32, 64, 128],
               anchor_scales=[4],
               anchor_ratios=[1.0],
               target_means=0.0,
               target_stds=1.0,
               align_conv_type='AlignConv',
               align_conv_size=3,
               use_sigmoid_cls=True,
               anchor_assign=_get_class_default_kwargs(RBoxAssigner),
               reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1],
               cls_loss_weight=[1.1, 1.05],
               reg_loss_type='l1'):
      """Build the anchor generators and the FAM / ODM branches.

      The list-valued defaults are not mutated in this constructor
      (``anchor_strides`` is copied with ``list()`` / ``paddle.to_tensor``).

      Raises:
          ValueError: if ``align_conv_type`` is not 'AlignConv', 'Conv' or
              'DCN'.
      """
      super(S2ANetHead, self).__init__()
      if align_conv_type not in ['AlignConv', 'Conv', 'DCN']:
          # An explicit error instead of ``assert`` so the check survives -O.
          raise ValueError(
              "align_conv_type must be one of 'AlignConv', 'Conv', 'DCN', "
              "got {!r}".format(align_conv_type))
      self.stacked_convs = stacked_convs
      self.feat_in = feat_in
      self.feat_out = feat_out
      self.anchor_list = None
      self.anchor_scales = anchor_scales
      self.anchor_ratios = anchor_ratios
      self.anchor_strides = paddle.to_tensor(anchor_strides)
      # Each level's base anchor size equals its stride.
      self.anchor_base_sizes = list(anchor_strides)
      self.means = paddle.ones(shape=[5]) * target_means
      self.stds = paddle.ones(shape=[5]) * target_stds
      self.align_conv_type = align_conv_type
      self.align_conv_size = align_conv_size
      self.use_sigmoid_cls = use_sigmoid_cls
      # NOTE(review): without sigmoid this gives a single class channel,
      # which looks suspicious for a softmax classifier - confirm.
      self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1
      self.sampling = False
      self.anchor_assign = anchor_assign
      self.reg_loss_weight = reg_loss_weight
      self.cls_loss_weight = cls_loss_weight
      self.alpha = 1.0
      self.beta = 1.0
      self.reg_loss_type = reg_loss_type
      self.s2anet_head_out = None

      def conv3x3(in_channels):
          # 3x3 / stride 1 / pad 1 conv, N(0, 0.01) weights, zero bias.
          return nn.Conv2D(
              in_channels=in_channels,
              out_channels=self.feat_out,
              kernel_size=3,
              stride=1,
              padding=1,
              weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
              bias_attr=ParamAttr(initializer=Constant(0)))

      # anchor generators, one per level
      self.anchor_generators = nn.LayerList([
          S2ANetAnchorGenerator(anchor_base, anchor_scales, anchor_ratios)
          for anchor_base in self.anchor_base_sizes
      ])

      # FAM towers (sublayer names and creation order unchanged)
      self.fam_cls_convs = nn.Sequential()
      self.fam_reg_convs = nn.Sequential()
      for i in range(self.stacked_convs):
          chan_in = self.feat_in if i == 0 else self.feat_out
          self.fam_cls_convs.add_sublayer('fam_cls_conv_{}'.format(i),
                                          conv3x3(chan_in))
          self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i),
                                          nn.ReLU())
          self.fam_reg_convs.add_sublayer('fam_reg_conv_{}'.format(i),
                                          conv3x3(chan_in))
          self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i),
                                          nn.ReLU())

      self.fam_reg = nn.Conv2D(
          self.feat_out,
          5,
          1,
          weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
          bias_attr=ParamAttr(initializer=Constant(0)))
      # Classification bias such that sigmoid(bias) == prior_prob at init.
      prior_prob = 0.01
      bias_init = float(-np.log((1 - prior_prob) / prior_prob))
      self.fam_cls = nn.Conv2D(
          self.feat_out,
          self.cls_out_channels,
          1,
          weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
          bias_attr=ParamAttr(initializer=Constant(bias_init)))

      if self.align_conv_type == "AlignConv":
          self.align_conv = AlignConv(self.feat_out, self.feat_out,
                                      self.align_conv_size)
      elif self.align_conv_type == "Conv":
          self.align_conv = nn.Conv2D(
              self.feat_out,
              self.feat_out,
              self.align_conv_size,
              padding=(self.align_conv_size - 1) // 2,
              bias_attr=ParamAttr(initializer=Constant(0)))
      elif self.align_conv_type == "DCN":
          self.align_conv_offset = nn.Conv2D(
              self.feat_out,
              2 * self.align_conv_size**2,
              1,
              weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
              bias_attr=ParamAttr(initializer=Constant(0)))
          self.align_conv = paddle.vision.ops.DeformConv2D(
              self.feat_out,
              self.feat_out,
              self.align_conv_size,
              padding=(self.align_conv_size - 1) // 2,
              weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
              bias_attr=False)

      self.or_conv = conv3x3(self.feat_out)

      # ODM towers
      self.odm_cls_convs = nn.Sequential()
      self.odm_reg_convs = nn.Sequential()
      for i in range(self.stacked_convs):
          self.odm_cls_convs.add_sublayer('odm_cls_conv_{}'.format(i),
                                          conv3x3(self.feat_out))
          self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i),
                                          nn.ReLU())
          self.odm_reg_convs.add_sublayer('odm_reg_conv_{}'.format(i),
                                          conv3x3(self.feat_out))
          self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i),
                                          nn.ReLU())

      self.odm_cls = nn.Conv2D(
          self.feat_out,
          self.cls_out_channels,
          3,
          padding=1,
          weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
          bias_attr=ParamAttr(initializer=Constant(bias_init)))
      self.odm_reg = nn.Conv2D(
          self.feat_out,
          5,
          3,
          padding=1,
          weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
          bias_attr=ParamAttr(initializer=Constant(0)))

      # Per-forward caches. ``featmap_sizes_list`` is the attribute that
      # forward() and the loss functions use; ``featmap_sizes`` is kept for
      # backward compatibility.
      self.featmap_sizes = []
      self.featmap_sizes_list = []
      self.base_anchors_list = []
      self.refine_anchor_list = []
  def forward(self, feats):
      """Run FAM and ODM on every pyramid level.

      Args:
          feats (list[Tensor]): NCHW feature maps, one per level, in the
              order of ``anchor_strides``.

      Returns:
          tuple: (fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list).
              FAM entries and ODM cls entries are ``[N, H*W, C]``; ODM reg
              entries are ``[1, N*H*W, 5]``. The tuple is also stored in
              ``self.s2anet_head_out``.

      Side effects:
          Resets and refills ``self.featmap_sizes_list``,
          ``self.base_anchors_list`` and ``self.refine_anchor_list``.
      """
      fam_cls_branch_list, fam_reg_branch_list = [], []
      odm_cls_branch_list, odm_reg_branch_list = [], []
      self.featmap_sizes_list = []
      self.base_anchors_list = []
      self.refine_anchor_list = []

      for feat_idx, feat in enumerate(feats):
          level_stride = self.anchor_strides[feat_idx]

          # FAM classification: [N, CLS, H, W] -> [N, H, W, CLS] -> [N, HW, CLS]
          fam_cls = self.fam_cls(self.fam_cls_convs(feat))
          fam_cls = fam_cls.transpose([0, 2, 3, 1])
          fam_cls_branch_list.append(
              paddle.reshape(fam_cls,
                             [fam_cls.shape[0], -1, self.cls_out_channels]))

          # FAM regression: [N, 5, H, W] -> [N, H, W, 5] -> [N, HW, 5]
          fam_reg = self.fam_reg(self.fam_reg_convs(feat))
          fam_reg = fam_reg.transpose([0, 2, 3, 1])
          fam_reg_branch_list.append(
              paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5]))

          # Base anchors of this level, converted to rotated boxes.
          featmap_size = (paddle.shape(feat)[2], paddle.shape(feat)[3])
          self.featmap_sizes_list.append(featmap_size)
          init_anchors = paddle.to_tensor(
              self.anchor_generators[feat_idx](featmap_size, level_stride),
              dtype='float32')
          init_anchors = paddle.reshape(
              init_anchors, [featmap_size[0] * featmap_size[1], 4])
          init_anchors = self.rect2rbox(init_anchors)
          self.base_anchors_list.append(init_anchors)

          # Refine the anchors with the (NHWC) FAM deltas; detached during
          # training so no gradient flows back through the refined anchors.
          deltas = fam_reg.detach() if self.training else fam_reg
          refine_anchor = self.bbox_decode(deltas, init_anchors)
          self.refine_anchor_list.append(refine_anchor)

          if self.align_conv_type == 'AlignConv':
              align_feat = self.align_conv(feat,
                                           refine_anchor.clone(),
                                           featmap_size, level_stride)
          elif self.align_conv_type == 'DCN':
              align_feat = self.align_conv(feat, self.align_conv_offset(feat))
          elif self.align_conv_type == 'Conv':
              align_feat = self.align_conv(feat)
          or_feat = self.or_conv(align_feat)

          odm_reg_feat = self.odm_reg_convs(or_feat)
          odm_cls_feat = self.odm_cls_convs(or_feat)

          # ODM classification: [N, CLS, H, W] -> [N, H, W, CLS] -> [N, HW, CLS]
          odm_cls_score = self.odm_cls(odm_cls_feat).transpose([0, 2, 3, 1])
          num, height, width = odm_cls_score.shape[:3]
          odm_cls_branch_list.append(
              paddle.reshape(odm_cls_score,
                             [num, height * width, self.cls_out_channels]))

          # ODM regression: [N, 5, H, W] -> [N, H, W, 5] -> [1, N*H*W, 5]
          odm_bbox_pred = self.odm_reg(odm_reg_feat).transpose([0, 2, 3, 1])
          odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
          odm_reg_branch_list.append(paddle.unsqueeze(odm_bbox_pred, axis=0))

      self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list,
                              odm_cls_branch_list, odm_reg_branch_list)
      return self.s2anet_head_out
  def get_prediction(self, nms_pre=2000):
      """Decode the ODM outputs cached by the last ``forward`` call.

      Args:
          nms_pre (int): forwarded to ``get_bboxes``.

      Returns:
          tuple: (pred_scores, pred_bboxes) as returned by ``get_bboxes``.
      """
      head_out = self.s2anet_head_out
      odm_cls_branch_list, odm_reg_branch_list = head_out[2], head_out[3]
      pred_scores, pred_bboxes = self.get_bboxes(
          odm_cls_branch_list, odm_reg_branch_list, self.refine_anchor_list,
          nms_pre, self.cls_out_channels, self.use_sigmoid_cls)
      return pred_scores, pred_bboxes
  def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
      """Element-wise smooth L1 loss (no reduction).

      Args:
          pred: predicted values.
          label: targets with the same shape as ``pred`` (non-empty).
          delta: switch point between the quadratic and linear parts (> 0).

      Returns:
          Tensor: ``0.5 * d**2 / delta`` where ``d < delta``, else
          ``d - 0.5 * delta``, with ``d = |pred - label|``.
      """
      assert pred.shape == label.shape and label.numel() > 0
      assert delta > 0
      abs_err = paddle.abs(pred - label)
      quadratic = 0.5 * abs_err * abs_err / delta
      linear = abs_err - 0.5 * delta
      return paddle.where(abs_err < delta, quadratic, linear)
  def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'):
      """FAM classification and regression losses for one image.

      Args:
          fam_target (tuple): assigner output for the base anchors,
              (labels, label_weights, bbox_targets, bbox_weights,
              bbox_gt_bboxes, pos_inds, neg_inds), laid out level by level
              in the order of ``self.featmap_sizes_list``.
          s2anet_head_out (tuple): this image's head outputs
              (fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list).
          reg_loss_type (str): 'l1' (weighted smooth L1 on the deltas) or
              'gwd' (``self.gwd_loss`` on decoded rboxes vs. ground truth).

      Returns:
          tuple: (fam_cls_loss, fam_reg_loss).

      Raises:
          NotImplementedError: for any other ``reg_loss_type``. ('iou' used to
              be accepted but never produced a loss value.)
      """
      if reg_loss_type not in ('l1', 'gwd'):
          raise NotImplementedError(
              "reg_loss_type {!r} is not supported, expected 'l1' or "
              "'gwd'".format(reg_loss_type))
      (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
       pos_inds, neg_inds) = fam_target
      fam_cls_branch_list, fam_reg_branch_list, _, _ = s2anet_head_out

      num_total_samples = len(pos_inds) + len(
          neg_inds) if self.sampling else len(pos_inds)
      num_total_samples = max(1, num_total_samples)
      # Loop invariants, converted once instead of once per level.
      num_total_samples = paddle.to_tensor(
          num_total_samples, dtype='float32', stop_gradient=True)
      loss_weight = paddle.to_tensor(
          self.reg_loss_weight, dtype='float32', stop_gradient=True)

      fam_cls_losses = []
      fam_bbox_losses = []
      st_idx = 0
      for idx, feat_size in enumerate(self.featmap_sizes_list):
          feat_anchor_num = feat_size[0] * feat_size[1]
          ed_idx = st_idx + feat_anchor_num

          # step1: this level's slice of the targets
          feat_labels = labels[st_idx:ed_idx].reshape(-1)
          feat_label_weights = label_weights[st_idx:ed_idx].reshape(-1)
          feat_bbox_targets = bbox_targets[st_idx:ed_idx, :]
          feat_bbox_weights = bbox_weights[st_idx:ed_idx, :]

          # step2: classification loss. One-hot column 0 is dropped, so
          # label 0 contributes an all-zero target row.
          fam_cls_score = paddle.squeeze(fam_cls_branch_list[idx], axis=0)
          feat_labels_one_hot = F.one_hot(
              paddle.to_tensor(feat_labels), self.cls_out_channels + 1)
          feat_labels_one_hot = feat_labels_one_hot[:, 1:]
          feat_labels_one_hot.stop_gradient = True
          fam_cls = F.sigmoid_focal_loss(
              fam_cls_score,
              feat_labels_one_hot,
              normalizer=num_total_samples,
              reduction='none')
          feat_label_weights = np.repeat(
              feat_label_weights.reshape(feat_label_weights.shape[0], 1),
              self.cls_out_channels,
              axis=1)
          feat_label_weights = paddle.to_tensor(
              feat_label_weights, stop_gradient=True)
          fam_cls_losses.append(paddle.sum(fam_cls * feat_label_weights))

          # step3: regression loss
          fam_bbox_pred = paddle.squeeze(fam_reg_branch_list[idx], axis=0)
          fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
          feat_bbox_weights = paddle.to_tensor(
              feat_bbox_weights, stop_gradient=True)
          if reg_loss_type == 'l1':
              feat_bbox_targets = paddle.to_tensor(
                  feat_bbox_targets, dtype='float32', stop_gradient=True)
              feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
              fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
              fam_bbox = paddle.multiply(fam_bbox, loss_weight)
              fam_bbox = fam_bbox * feat_bbox_weights
              fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
          else:  # 'gwd'
              feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
              fam_bbox_decode = self.delta2rbox(self.base_anchors_list[idx],
                                                fam_bbox_pred)
              # Only this level's ground truth is needed. (The former
              # rbox_iou against all levels' ground truth was discarded
              # unused and cost O(level_anchors * all_anchors) memory.)
              bbox_gt_bboxes_level = paddle.to_tensor(
                  bbox_gt_bboxes[st_idx:ed_idx, :],
                  dtype=fam_bbox_decode.dtype,
                  place=fam_bbox_decode.place)
              bbox_gt_bboxes_level.stop_gradient = True
              fam_bbox_total = self.gwd_loss(fam_bbox_decode,
                                             bbox_gt_bboxes_level)
              fam_bbox_total = fam_bbox_total * feat_bbox_weights
              fam_bbox_total = paddle.sum(fam_bbox_total) / num_total_samples
          fam_bbox_losses.append(fam_bbox_total)
          st_idx = ed_idx

      fam_cls_loss = paddle.add_n(fam_cls_losses)
      fam_cls_loss_weight = paddle.to_tensor(
          self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
      fam_cls_loss = fam_cls_loss * fam_cls_loss_weight
      fam_reg_loss = paddle.add_n(fam_bbox_losses)
      return fam_cls_loss, fam_reg_loss
  def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'):
      """ODM classification and regression losses for one image.

      Args:
          odm_target (tuple): assigner output for the refined anchors,
              (labels, label_weights, bbox_targets, bbox_weights,
              bbox_gt_bboxes, pos_inds, neg_inds), laid out level by level
              in the order of ``self.featmap_sizes_list``.
          s2anet_head_out (tuple): this image's head outputs
              (fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list).
          reg_loss_type (str): 'l1' (weighted smooth L1 on the deltas) or
              'gwd' (``self.gwd_loss`` on decoded rboxes vs. ground truth).

      Returns:
          tuple: (odm_cls_loss, odm_reg_loss).

      Raises:
          NotImplementedError: for any other ``reg_loss_type``. ('iou' used to
              be accepted but never produced a loss value.)
      """
      if reg_loss_type not in ('l1', 'gwd'):
          raise NotImplementedError(
              "reg_loss_type {!r} is not supported, expected 'l1' or "
              "'gwd'".format(reg_loss_type))
      (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
       pos_inds, neg_inds) = odm_target
      _, _, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out

      num_total_samples = len(pos_inds) + len(
          neg_inds) if self.sampling else len(pos_inds)
      num_total_samples = max(1, num_total_samples)
      # Loop invariants, converted once instead of once per level.
      num_total_samples = paddle.to_tensor(
          num_total_samples, dtype='float32', stop_gradient=True)
      loss_weight = paddle.to_tensor(
          self.reg_loss_weight, dtype='float32', stop_gradient=True)

      odm_cls_losses = []
      odm_bbox_losses = []
      st_idx = 0
      for idx, feat_size in enumerate(self.featmap_sizes_list):
          feat_anchor_num = feat_size[0] * feat_size[1]
          ed_idx = st_idx + feat_anchor_num

          # step1: this level's slice of the targets
          feat_labels = labels[st_idx:ed_idx].reshape(-1)
          feat_label_weights = label_weights[st_idx:ed_idx].reshape(-1)
          feat_bbox_targets = bbox_targets[st_idx:ed_idx, :]
          feat_bbox_weights = bbox_weights[st_idx:ed_idx, :]

          # step2: classification loss. One-hot column 0 is dropped, so
          # label 0 contributes an all-zero target row.
          odm_cls_score = paddle.squeeze(odm_cls_branch_list[idx], axis=0)
          feat_labels_one_hot = F.one_hot(
              paddle.to_tensor(feat_labels), self.cls_out_channels + 1)
          feat_labels_one_hot = feat_labels_one_hot[:, 1:]
          feat_labels_one_hot.stop_gradient = True
          odm_cls = F.sigmoid_focal_loss(
              odm_cls_score,
              feat_labels_one_hot,
              normalizer=num_total_samples,
              reduction='none')
          feat_label_weights = np.repeat(
              feat_label_weights.reshape(feat_label_weights.shape[0], 1),
              self.cls_out_channels,
              axis=1)
          feat_label_weights = paddle.to_tensor(
              feat_label_weights, stop_gradient=True)
          odm_cls_losses.append(paddle.sum(odm_cls * feat_label_weights))

          # step3: regression loss
          odm_bbox_pred = paddle.squeeze(odm_reg_branch_list[idx], axis=0)
          odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
          feat_bbox_weights = paddle.to_tensor(
              feat_bbox_weights, stop_gradient=True)
          if reg_loss_type == 'l1':
              feat_bbox_targets = paddle.to_tensor(
                  feat_bbox_targets, dtype='float32', stop_gradient=True)
              feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
              odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
              odm_bbox = paddle.multiply(odm_bbox, loss_weight)
              odm_bbox = odm_bbox * feat_bbox_weights
              odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
          else:  # 'gwd'
              feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
              odm_bbox_decode = self.delta2rbox(self.refine_anchor_list[idx],
                                                odm_bbox_pred)
              # Only this level's ground truth is needed. (The former
              # rbox_iou against all levels' ground truth was discarded
              # unused and cost O(level_anchors * all_anchors) memory.)
              bbox_gt_bboxes_level = paddle.to_tensor(
                  bbox_gt_bboxes[st_idx:ed_idx, :],
                  dtype=odm_bbox_decode.dtype,
                  place=odm_bbox_decode.place)
              bbox_gt_bboxes_level.stop_gradient = True
              odm_bbox_total = self.gwd_loss(odm_bbox_decode,
                                             bbox_gt_bboxes_level)
              odm_bbox_total = odm_bbox_total * feat_bbox_weights
              odm_bbox_total = paddle.sum(odm_bbox_total) / num_total_samples
          odm_bbox_losses.append(odm_bbox_total)
          st_idx = ed_idx

      odm_cls_loss = paddle.add_n(odm_cls_losses)
      odm_cls_loss_weight = paddle.to_tensor(
          self.cls_loss_weight[1], dtype='float32', stop_gradient=True)
      odm_cls_loss = odm_cls_loss * odm_cls_loss_weight
      odm_reg_loss = paddle.add_n(odm_bbox_losses)
      return odm_cls_loss, odm_reg_loss
  def get_loss(self, inputs):
      """Compute FAM and ODM losses for a batch, one image at a time.

      Reads the caches filled by the preceding ``forward`` call
      (``self.s2anet_head_out``, ``self.base_anchors_list`` and
      ``self.refine_anchor_list``).

      Args:
          inputs (dict): batch data. Reads ``im_shape`` (for the batch size),
              ``gt_rbox`` (xc, yc, w, h, theta), ``gt_class`` and
              ``is_crowd``.

      Returns:
          dict: ``fam_cls_loss``, ``fam_reg_loss``, ``odm_cls_loss`` and
              ``odm_reg_loss``, each summed over the images for which
              ``anchor_assign`` returned targets.

      NOTE(review): ``paddle.add_n`` fails on an empty list, so this raises
      if ``anchor_assign`` returns None for every image - confirm whether
      that can happen.
      """
      fam_cls_loss_lst = []
      fam_reg_loss_lst = []
      odm_cls_loss_lst = []
      odm_reg_loss_lst = []

      # Neither anchor set depends on the image index: build the numpy
      # arrays once (the refine anchors need a device-to-host copy) instead
      # of once per image.
      anchors_list_all = np.concatenate(self.base_anchors_list)
      np_refine_anchors_list = paddle.concat(self.refine_anchor_list).numpy()
      np_refine_anchors_list = np.concatenate(np_refine_anchors_list)
      np_refine_anchors_list = np_refine_anchors_list.reshape(-1, 5)

      batch_size = inputs['im_shape'].shape[0]
      for im_id in range(batch_size):
          # data_format: (xc, yc, w, h, theta)
          gt_bboxes = inputs['gt_rbox'][im_id].numpy()
          # Labels are shifted by one; the loss functions drop one-hot
          # column 0, presumably the background - TODO confirm.
          gt_labels = inputs['gt_class'][im_id].numpy() + 1
          is_crowd = inputs['is_crowd'][im_id].numpy()

          # This image's slice of every branch output.
          im_s2anet_head_out = tuple([e[im_id] for e in branch]
                                     for branch in self.s2anet_head_out)

          # FAM. The assigner gets a fresh copy per image, as it did when the
          # array was rebuilt every iteration, in case it modifies its input.
          im_fam_target = self.anchor_assign(anchors_list_all.copy(),
                                             gt_bboxes, gt_labels, is_crowd)
          if im_fam_target is not None:
              im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
                  im_fam_target, im_s2anet_head_out, self.reg_loss_type)
              fam_cls_loss_lst.append(im_fam_cls_loss)
              fam_reg_loss_lst.append(im_fam_reg_loss)

          # ODM
          im_odm_target = self.anchor_assign(np_refine_anchors_list.copy(),
                                             gt_bboxes, gt_labels, is_crowd)
          if im_odm_target is not None:
              im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
                  im_odm_target, im_s2anet_head_out, self.reg_loss_type)
              odm_cls_loss_lst.append(im_odm_cls_loss)
              odm_reg_loss_lst.append(im_odm_reg_loss)

      return {
          'fam_cls_loss': paddle.add_n(fam_cls_loss_lst),
          'fam_reg_loss': paddle.add_n(fam_reg_loss_lst),
          'odm_cls_loss': paddle.add_n(odm_cls_loss_lst),
          'odm_reg_loss': paddle.add_n(odm_reg_loss_lst)
      }
  713. def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
  714. cls_out_channels, use_sigmoid_cls):
  715. assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
  716. mlvl_bboxes = []
  717. mlvl_scores = []
  718. idx = 0
  719. for cls_score, bbox_pred, anchors in zip(cls_score_list,
  720. bbox_pred_list, mlvl_anchors):
  721. cls_score = paddle.reshape(cls_score, [-1, cls_out_channels])
  722. if use_sigmoid_cls:
  723. scores = F.sigmoid(cls_score)
  724. else:
  725. scores = F.softmax(cls_score, axis=-1)
  726. # bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5)
  727. bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0])
  728. bbox_pred = paddle.reshape(bbox_pred, [-1, 5])
  729. anchors = paddle.reshape(anchors, [-1, 5])
  730. if scores.shape[0] > nms_pre:
  731. # Get maximum scores for foreground classes.
  732. if use_sigmoid_cls:
  733. max_scores = paddle.max(scores, axis=1)
  734. else:
  735. max_scores = paddle.max(scores[:, 1:], axis=1)
  736. topk_val, topk_inds = paddle.topk(max_scores, nms_pre)
  737. anchors = paddle.gather(anchors, topk_inds)
  738. bbox_pred = paddle.gather(bbox_pred, topk_inds)
  739. scores = paddle.gather(scores, topk_inds)
  740. bbox_delta = paddle.reshape(bbox_pred, [-1, 5])
  741. bboxes = self.delta2rbox(anchors, bbox_delta)
  742. mlvl_bboxes.append(bboxes)
  743. mlvl_scores.append(scores)
  744. idx += 1
  745. mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
  746. mlvl_scores = paddle.concat(mlvl_scores)
  747. return mlvl_scores, mlvl_bboxes
  748. def rect2rbox(self, bboxes):
  749. """
  750. :param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
  751. :return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
  752. """
  753. bboxes = paddle.reshape(bboxes, [-1, 4])
  754. num_boxes = paddle.shape(bboxes)[0]
  755. x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
  756. y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
  757. edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
  758. edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])
  759. rbox_w = paddle.maximum(edges1, edges2)
  760. rbox_h = paddle.minimum(edges1, edges2)
  761. # set angle
  762. inds = edges1 < edges2
  763. inds = paddle.cast(inds, 'int32')
  764. rboxes_angle = inds * np.pi / 2.0
  765. rboxes = paddle.stack(
  766. (x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
  767. return rboxes
  768. # deltas to rbox
  769. def delta2rbox(self, rrois, deltas, wh_ratio_clip=1e-6):
  770. """
  771. :param rrois: (cx, cy, w, h, theta)
  772. :param deltas: (dx, dy, dw, dh, dtheta)
  773. :param means: means of anchor
  774. :param stds: stds of anchor
  775. :param wh_ratio_clip: clip threshold of wh_ratio
  776. :return:
  777. """
  778. deltas = paddle.reshape(deltas, [-1, 5])
  779. rrois = paddle.reshape(rrois, [-1, 5])
  780. # fix dy2st bug denorm_deltas = deltas * self.stds + self.means
  781. denorm_deltas = paddle.add(
  782. paddle.multiply(deltas, self.stds), self.means)
  783. dx = denorm_deltas[:, 0]
  784. dy = denorm_deltas[:, 1]
  785. dw = denorm_deltas[:, 2]
  786. dh = denorm_deltas[:, 3]
  787. dangle = denorm_deltas[:, 4]
  788. max_ratio = np.abs(np.log(wh_ratio_clip))
  789. dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
  790. dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
  791. rroi_x = rrois[:, 0]
  792. rroi_y = rrois[:, 1]
  793. rroi_w = rrois[:, 2]
  794. rroi_h = rrois[:, 3]
  795. rroi_angle = rrois[:, 4]
  796. gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
  797. rroi_angle) + rroi_x
  798. gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
  799. rroi_angle) + rroi_y
  800. gw = rroi_w * dw.exp()
  801. gh = rroi_h * dh.exp()
  802. ga = np.pi * dangle + rroi_angle
  803. ga = (ga + np.pi / 4) % np.pi - np.pi / 4
  804. ga = paddle.to_tensor(ga)
  805. gw = paddle.to_tensor(gw, dtype='float32')
  806. gh = paddle.to_tensor(gh, dtype='float32')
  807. bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
  808. return bboxes
  809. def bbox_decode(self, bbox_preds, anchors):
  810. """decode bbox from deltas
  811. Args:
  812. bbox_preds: [N,H,W,5]
  813. anchors: [H*W,5]
  814. return:
  815. bboxes: [N,H,W,5]
  816. """
  817. num_imgs, H, W, _ = bbox_preds.shape
  818. bbox_delta = paddle.reshape(bbox_preds, [-1, 5])
  819. bboxes = self.delta2rbox(anchors, bbox_delta)
  820. return bboxes
  821. def trace(self, A):
  822. tr = paddle.diagonal(A, axis1=-2, axis2=-1)
  823. tr = paddle.sum(tr, axis=-1)
  824. return tr
  825. def sqrt_newton_schulz_autograd(self, A, numIters):
  826. A_shape = A.shape
  827. batchSize = A_shape[0]
  828. dim = A_shape[1]
  829. normA = A * A
  830. normA = paddle.sum(normA, axis=1)
  831. normA = paddle.sum(normA, axis=1)
  832. normA = paddle.sqrt(normA)
  833. normA1 = normA.reshape([batchSize, 1, 1])
  834. Y = paddle.divide(A, paddle.expand_as(normA1, A))
  835. I = paddle.eye(dim, dim).reshape([1, dim, dim])
  836. l0 = []
  837. for i in range(batchSize):
  838. l0.append(I)
  839. I = paddle.concat(l0, axis=0)
  840. I.stop_gradient = False
  841. Z = paddle.eye(dim, dim).reshape([1, dim, dim])
  842. l1 = []
  843. for i in range(batchSize):
  844. l1.append(Z)
  845. Z = paddle.concat(l1, axis=0)
  846. Z.stop_gradient = False
  847. for i in range(numIters):
  848. T = 0.5 * (3.0 * I - Z.bmm(Y))
  849. Y = Y.bmm(T)
  850. Z = T.bmm(Z)
  851. sA = Y * paddle.sqrt(normA1).reshape([batchSize, 1, 1])
  852. sA = paddle.expand_as(sA, A)
  853. return sA
  854. def wasserstein_distance_sigma(sigma1, sigma2):
  855. wasserstein_distance_item2 = paddle.matmul(
  856. sigma1, sigma1) + paddle.matmul(
  857. sigma2, sigma2) - 2 * self.sqrt_newton_schulz_autograd(
  858. paddle.matmul(
  859. paddle.matmul(sigma1, paddle.matmul(sigma2, sigma2)),
  860. sigma1), 10)
  861. wasserstein_distance_item2 = self.trace(wasserstein_distance_item2)
  862. return wasserstein_distance_item2
  863. def xywhr2xyrs(self, xywhr):
  864. xywhr = paddle.reshape(xywhr, [-1, 5])
  865. xy = xywhr[:, :2]
  866. wh = paddle.clip(xywhr[:, 2:4], min=1e-7, max=1e7)
  867. r = xywhr[:, 4]
  868. cos_r = paddle.cos(r)
  869. sin_r = paddle.sin(r)
  870. R = paddle.stack(
  871. (cos_r, -sin_r, sin_r, cos_r), axis=-1).reshape([-1, 2, 2])
  872. S = 0.5 * paddle.nn.functional.diag_embed(wh)
  873. return xy, R, S
  874. def gwd_loss(self,
  875. pred,
  876. target,
  877. fun='log',
  878. tau=1.0,
  879. alpha=1.0,
  880. normalize=False):
  881. xy_p, R_p, S_p = self.xywhr2xyrs(pred)
  882. xy_t, R_t, S_t = self.xywhr2xyrs(target)
  883. xy_distance = (xy_p - xy_t).square().sum(axis=-1)
  884. Sigma_p = R_p.matmul(S_p.square()).matmul(R_p.transpose([0, 2, 1]))
  885. Sigma_t = R_t.matmul(S_t.square()).matmul(R_t.transpose([0, 2, 1]))
  886. whr_distance = paddle.diagonal(
  887. S_p, axis1=-2, axis2=-1).square().sum(axis=-1)
  888. whr_distance = whr_distance + paddle.diagonal(
  889. S_t, axis1=-2, axis2=-1).square().sum(axis=-1)
  890. _t = Sigma_p.matmul(Sigma_t)
  891. _t_tr = paddle.diagonal(_t, axis1=-2, axis2=-1).sum(axis=-1)
  892. _t_det_sqrt = paddle.diagonal(S_p, axis1=-2, axis2=-1).prod(axis=-1)
  893. _t_det_sqrt = _t_det_sqrt * paddle.diagonal(
  894. S_t, axis1=-2, axis2=-1).prod(axis=-1)
  895. whr_distance = whr_distance + (-2) * (
  896. (_t_tr + 2 * _t_det_sqrt).clip(0).sqrt())
  897. distance = (xy_distance + alpha * alpha * whr_distance).clip(0)
  898. if normalize:
  899. wh_p = pred[..., 2:4].clip(min=1e-7, max=1e7)
  900. wh_t = target[..., 2:4].clip(min=1e-7, max=1e7)
  901. scale = ((wh_p.log() + wh_t.log()).sum(dim=-1) / 4).exp()
  902. distance = distance / scale
  903. if fun == 'log':
  904. distance = paddle.log1p(distance)
  905. if tau >= 1.0:
  906. return 1 - 1 / (tau + distance)
  907. return distance