s2anet_head.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle import ParamAttr
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddle.nn.initializer import Normal, Constant
  19. from paddlex.ppdet.core.workspace import register
  20. from paddlex.ppdet.modeling import ops
  21. from paddlex.ppdet.modeling import bbox_utils
  22. from paddlex.ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
  23. import numpy as np
  24. class S2ANetAnchorGenerator(nn.Layer):
  25. """
  26. AnchorGenerator by paddle
  27. """
  28. def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
  29. super(S2ANetAnchorGenerator, self).__init__()
  30. self.base_size = base_size
  31. self.scales = paddle.to_tensor(scales)
  32. self.ratios = paddle.to_tensor(ratios)
  33. self.scale_major = scale_major
  34. self.ctr = ctr
  35. self.base_anchors = self.gen_base_anchors()
  36. @property
  37. def num_base_anchors(self):
  38. return self.base_anchors.shape[0]
  39. def gen_base_anchors(self):
  40. w = self.base_size
  41. h = self.base_size
  42. if self.ctr is None:
  43. x_ctr = 0.5 * (w - 1)
  44. y_ctr = 0.5 * (h - 1)
  45. else:
  46. x_ctr, y_ctr = self.ctr
  47. h_ratios = paddle.sqrt(self.ratios)
  48. w_ratios = 1 / h_ratios
  49. if self.scale_major:
  50. ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
  51. hs = (h * h_ratios[:] * self.scales[:]).reshape([-1])
  52. else:
  53. ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
  54. hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
  55. base_anchors = paddle.stack(
  56. [
  57. x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
  58. x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
  59. ],
  60. axis=-1)
  61. base_anchors = paddle.round(base_anchors)
  62. return base_anchors
  63. def _meshgrid(self, x, y, row_major=True):
  64. yy, xx = paddle.meshgrid(y, x)
  65. yy = yy.reshape([-1])
  66. xx = xx.reshape([-1])
  67. if row_major:
  68. return xx, yy
  69. else:
  70. return yy, xx
  71. def forward(self, featmap_size, stride=16):
  72. # featmap_size*stride project it to original area
  73. feat_h = featmap_size[0]
  74. feat_w = featmap_size[1]
  75. shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
  76. shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
  77. shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
  78. shifts = paddle.stack(
  79. [shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
  80. all_anchors = self.base_anchors[:, :] + shifts[:, :]
  81. all_anchors = all_anchors.reshape([feat_h * feat_w, 4])
  82. return all_anchors
  83. def valid_flags(self, featmap_size, valid_size):
  84. feat_h, feat_w = featmap_size
  85. valid_h, valid_w = valid_size
  86. assert valid_h <= feat_h and valid_w <= feat_w
  87. valid_x = paddle.zeros([feat_w], dtype='int32')
  88. valid_y = paddle.zeros([feat_h], dtype='int32')
  89. valid_x[:valid_w] = 1
  90. valid_y[:valid_h] = 1
  91. valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
  92. valid = valid_xx & valid_yy
  93. valid = paddle.reshape(valid, [-1, 1])
  94. valid = paddle.expand(valid, [-1, self.num_base_anchors]).reshape([-1])
  95. return valid
  96. class AlignConv(nn.Layer):
  97. def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
  98. super(AlignConv, self).__init__()
  99. self.kernel_size = kernel_size
  100. self.align_conv = paddle.vision.ops.DeformConv2D(
  101. in_channels,
  102. out_channels,
  103. kernel_size=self.kernel_size,
  104. padding=(self.kernel_size - 1) // 2,
  105. groups=groups,
  106. weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
  107. bias_attr=None)
  108. @paddle.no_grad()
  109. def get_offset(self, anchors, featmap_size, stride):
  110. """
  111. Args:
  112. anchors: [M,5] xc,yc,w,h,angle
  113. featmap_size: (feat_h, feat_w)
  114. stride: 8
  115. Returns:
  116. """
  117. anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5)
  118. dtype = anchors.dtype
  119. feat_h = featmap_size[0]
  120. feat_w = featmap_size[1]
  121. pad = (self.kernel_size - 1) // 2
  122. idx = paddle.arange(-pad, pad + 1, dtype=dtype)
  123. yy, xx = paddle.meshgrid(idx, idx)
  124. xx = paddle.reshape(xx, [-1])
  125. yy = paddle.reshape(yy, [-1])
  126. # get sampling locations of default conv
  127. xc = paddle.arange(0, feat_w, dtype=dtype)
  128. yc = paddle.arange(0, feat_h, dtype=dtype)
  129. yc, xc = paddle.meshgrid(yc, xc)
  130. xc = paddle.reshape(xc, [-1, 1])
  131. yc = paddle.reshape(yc, [-1, 1])
  132. x_conv = xc + xx
  133. y_conv = yc + yy
  134. # get sampling locations of anchors
  135. # x_ctr, y_ctr, w, h, a = np.unbind(anchors, dim=1)
  136. x_ctr = anchors[:, 0]
  137. y_ctr = anchors[:, 1]
  138. w = anchors[:, 2]
  139. h = anchors[:, 3]
  140. a = anchors[:, 4]
  141. x_ctr = paddle.reshape(x_ctr, [-1, 1])
  142. y_ctr = paddle.reshape(y_ctr, [-1, 1])
  143. w = paddle.reshape(w, [-1, 1])
  144. h = paddle.reshape(h, [-1, 1])
  145. a = paddle.reshape(a, [-1, 1])
  146. x_ctr = x_ctr / stride
  147. y_ctr = y_ctr / stride
  148. w_s = w / stride
  149. h_s = h / stride
  150. cos, sin = paddle.cos(a), paddle.sin(a)
  151. dw, dh = w_s / self.kernel_size, h_s / self.kernel_size
  152. x, y = dw * xx, dh * yy
  153. xr = cos * x - sin * y
  154. yr = sin * x + cos * y
  155. x_anchor, y_anchor = xr + x_ctr, yr + y_ctr
  156. # get offset filed
  157. offset_x = x_anchor - x_conv
  158. offset_y = y_anchor - y_conv
  159. offset = paddle.stack([offset_y, offset_x], axis=-1)
  160. offset = paddle.reshape(
  161. offset,
  162. [feat_h * feat_w, self.kernel_size * self.kernel_size * 2])
  163. offset = paddle.transpose(offset, [1, 0])
  164. offset = paddle.reshape(
  165. offset,
  166. [1, self.kernel_size * self.kernel_size * 2, feat_h, feat_w])
  167. return offset
  168. def forward(self, x, refine_anchors, featmap_size, stride):
  169. offset = self.get_offset(refine_anchors, featmap_size, stride)
  170. x = F.relu(self.align_conv(x, offset))
  171. return x
  172. @register
  173. class S2ANetHead(nn.Layer):
  174. """
  175. S2Anet head
  176. Args:
  177. stacked_convs (int): number of stacked_convs
  178. feat_in (int): input channels of feat
  179. feat_out (int): output channels of feat
  180. num_classes (int): num_classes
  181. anchor_strides (list): stride of anchors
  182. anchor_scales (list): scale of anchors
  183. anchor_ratios (list): ratios of anchors
  184. target_means (list): target_means
  185. target_stds (list): target_stds
  186. align_conv_type (str): align_conv_type ['Conv', 'AlignConv']
  187. align_conv_size (int): kernel size of align_conv
  188. use_sigmoid_cls (bool): use sigmoid_cls or not
  189. reg_loss_weight (list): loss weight for regression
  190. """
  191. __shared__ = ['num_classes']
  192. __inject__ = ['anchor_assign']
  193. def __init__(self,
  194. stacked_convs=2,
  195. feat_in=256,
  196. feat_out=256,
  197. num_classes=15,
  198. anchor_strides=[8, 16, 32, 64, 128],
  199. anchor_scales=[4],
  200. anchor_ratios=[1.0],
  201. target_means=0.0,
  202. target_stds=1.0,
  203. align_conv_type='AlignConv',
  204. align_conv_size=3,
  205. use_sigmoid_cls=True,
  206. anchor_assign=RBoxAssigner().__dict__,
  207. reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1],
  208. cls_loss_weight=[1.1, 1.05],
  209. reg_loss_type='l1',
  210. is_training=True):
  211. super(S2ANetHead, self).__init__()
  212. self.stacked_convs = stacked_convs
  213. self.feat_in = feat_in
  214. self.feat_out = feat_out
  215. self.anchor_list = None
  216. self.anchor_scales = anchor_scales
  217. self.anchor_ratios = anchor_ratios
  218. self.anchor_strides = anchor_strides
  219. self.anchor_strides = paddle.to_tensor(anchor_strides)
  220. self.anchor_base_sizes = list(anchor_strides)
  221. self.means = paddle.ones(shape=[5]) * target_means
  222. self.stds = paddle.ones(shape=[5]) * target_stds
  223. assert align_conv_type in ['AlignConv', 'Conv', 'DCN']
  224. self.align_conv_type = align_conv_type
  225. self.align_conv_size = align_conv_size
  226. self.use_sigmoid_cls = use_sigmoid_cls
  227. self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1
  228. self.sampling = False
  229. self.anchor_assign = anchor_assign
  230. self.reg_loss_weight = reg_loss_weight
  231. self.cls_loss_weight = cls_loss_weight
  232. self.alpha = 1.0
  233. self.beta = 1.0
  234. self.reg_loss_type = reg_loss_type
  235. self.is_training = is_training
  236. self.s2anet_head_out = None
  237. # anchor
  238. self.anchor_generators = []
  239. for anchor_base in self.anchor_base_sizes:
  240. self.anchor_generators.append(
  241. S2ANetAnchorGenerator(anchor_base, anchor_scales,
  242. anchor_ratios))
  243. self.anchor_generators = nn.LayerList(self.anchor_generators)
  244. self.fam_cls_convs = nn.Sequential()
  245. self.fam_reg_convs = nn.Sequential()
  246. for i in range(self.stacked_convs):
  247. chan_in = self.feat_in if i == 0 else self.feat_out
  248. self.fam_cls_convs.add_sublayer(
  249. 'fam_cls_conv_{}'.format(i),
  250. nn.Conv2D(
  251. in_channels=chan_in,
  252. out_channels=self.feat_out,
  253. kernel_size=3,
  254. padding=1,
  255. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  256. bias_attr=ParamAttr(initializer=Constant(0))))
  257. self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i),
  258. nn.ReLU())
  259. self.fam_reg_convs.add_sublayer(
  260. 'fam_reg_conv_{}'.format(i),
  261. nn.Conv2D(
  262. in_channels=chan_in,
  263. out_channels=self.feat_out,
  264. kernel_size=3,
  265. padding=1,
  266. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  267. bias_attr=ParamAttr(initializer=Constant(0))))
  268. self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i),
  269. nn.ReLU())
  270. self.fam_reg = nn.Conv2D(
  271. self.feat_out,
  272. 5,
  273. 1,
  274. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  275. bias_attr=ParamAttr(initializer=Constant(0)))
  276. prior_prob = 0.01
  277. bias_init = float(-np.log((1 - prior_prob) / prior_prob))
  278. self.fam_cls = nn.Conv2D(
  279. self.feat_out,
  280. self.cls_out_channels,
  281. 1,
  282. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  283. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  284. if self.align_conv_type == "AlignConv":
  285. self.align_conv = AlignConv(self.feat_out, self.feat_out,
  286. self.align_conv_size)
  287. elif self.align_conv_type == "Conv":
  288. self.align_conv = nn.Conv2D(
  289. self.feat_out,
  290. self.feat_out,
  291. self.align_conv_size,
  292. padding=(self.align_conv_size - 1) // 2,
  293. bias_attr=ParamAttr(initializer=Constant(0)))
  294. elif self.align_conv_type == "DCN":
  295. self.align_conv_offset = nn.Conv2D(
  296. self.feat_out,
  297. 2 * self.align_conv_size**2,
  298. 1,
  299. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  300. bias_attr=ParamAttr(initializer=Constant(0)))
  301. self.align_conv = paddle.vision.ops.DeformConv2D(
  302. self.feat_out,
  303. self.feat_out,
  304. self.align_conv_size,
  305. padding=(self.align_conv_size - 1) // 2,
  306. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  307. bias_attr=False)
  308. self.or_conv = nn.Conv2D(
  309. self.feat_out,
  310. self.feat_out,
  311. kernel_size=3,
  312. padding=1,
  313. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  314. bias_attr=ParamAttr(initializer=Constant(0)))
  315. # ODM
  316. self.odm_cls_convs = nn.Sequential()
  317. self.odm_reg_convs = nn.Sequential()
  318. for i in range(self.stacked_convs):
  319. ch_in = self.feat_out
  320. # ch_in = int(self.feat_out / 8) if i == 0 else self.feat_out
  321. self.odm_cls_convs.add_sublayer(
  322. 'odm_cls_conv_{}'.format(i),
  323. nn.Conv2D(
  324. in_channels=ch_in,
  325. out_channels=self.feat_out,
  326. kernel_size=3,
  327. stride=1,
  328. padding=1,
  329. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  330. bias_attr=ParamAttr(initializer=Constant(0))))
  331. self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i),
  332. nn.ReLU())
  333. self.odm_reg_convs.add_sublayer(
  334. 'odm_reg_conv_{}'.format(i),
  335. nn.Conv2D(
  336. in_channels=self.feat_out,
  337. out_channels=self.feat_out,
  338. kernel_size=3,
  339. stride=1,
  340. padding=1,
  341. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  342. bias_attr=ParamAttr(initializer=Constant(0))))
  343. self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i),
  344. nn.ReLU())
  345. self.odm_cls = nn.Conv2D(
  346. self.feat_out,
  347. self.cls_out_channels,
  348. 3,
  349. padding=1,
  350. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  351. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  352. self.odm_reg = nn.Conv2D(
  353. self.feat_out,
  354. 5,
  355. 3,
  356. padding=1,
  357. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  358. bias_attr=ParamAttr(initializer=Constant(0)))
  359. self.featmap_sizes = []
  360. self.base_anchors_list = []
  361. self.refine_anchor_list = []
  362. def forward(self, feats):
  363. fam_reg_branch_list = []
  364. fam_cls_branch_list = []
  365. odm_reg_branch_list = []
  366. odm_cls_branch_list = []
  367. self.featmap_sizes_list = []
  368. self.base_anchors_list = []
  369. self.refine_anchor_list = []
  370. for feat_idx in range(len(feats)):
  371. feat = feats[feat_idx]
  372. fam_cls_feat = self.fam_cls_convs(feat)
  373. fam_cls = self.fam_cls(fam_cls_feat)
  374. # [N, CLS, H, W] --> [N, H, W, CLS]
  375. fam_cls = fam_cls.transpose([0, 2, 3, 1])
  376. fam_cls_reshape = paddle.reshape(
  377. fam_cls, [fam_cls.shape[0], -1, self.cls_out_channels])
  378. fam_cls_branch_list.append(fam_cls_reshape)
  379. fam_reg_feat = self.fam_reg_convs(feat)
  380. fam_reg = self.fam_reg(fam_reg_feat)
  381. # [N, 5, H, W] --> [N, H, W, 5]
  382. fam_reg = fam_reg.transpose([0, 2, 3, 1])
  383. fam_reg_reshape = paddle.reshape(fam_reg,
  384. [fam_reg.shape[0], -1, 5])
  385. fam_reg_branch_list.append(fam_reg_reshape)
  386. # prepare anchor
  387. featmap_size = (paddle.shape(feat)[2], paddle.shape(feat)[3])
  388. self.featmap_sizes_list.append(featmap_size)
  389. init_anchors = self.anchor_generators[feat_idx](
  390. featmap_size, self.anchor_strides[feat_idx])
  391. init_anchors = paddle.to_tensor(init_anchors, dtype='float32')
  392. NA = featmap_size[0] * featmap_size[1]
  393. init_anchors = paddle.reshape(init_anchors, [NA, 4])
  394. init_anchors = self.rect2rbox(init_anchors)
  395. self.base_anchors_list.append(init_anchors)
  396. if self.is_training:
  397. refine_anchor = self.bbox_decode(fam_reg.detach(),
  398. init_anchors)
  399. else:
  400. fam_reg1 = fam_reg.clone()
  401. fam_reg1.stop_gradient = True
  402. refine_anchor = self.bbox_decode(fam_reg1, init_anchors)
  403. self.refine_anchor_list.append(refine_anchor)
  404. if self.align_conv_type == 'AlignConv':
  405. align_feat = self.align_conv(feat,
  406. refine_anchor.clone(),
  407. featmap_size,
  408. self.anchor_strides[feat_idx])
  409. elif self.align_conv_type == 'DCN':
  410. align_offset = self.align_conv_offset(feat)
  411. align_feat = self.align_conv(feat, align_offset)
  412. elif self.align_conv_type == 'Conv':
  413. align_feat = self.align_conv(feat)
  414. or_feat = self.or_conv(align_feat)
  415. odm_reg_feat = or_feat
  416. odm_cls_feat = or_feat
  417. odm_reg_feat = self.odm_reg_convs(odm_reg_feat)
  418. odm_cls_feat = self.odm_cls_convs(odm_cls_feat)
  419. odm_cls_score = self.odm_cls(odm_cls_feat)
  420. # [N, CLS, H, W] --> [N, H, W, CLS]
  421. odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1])
  422. odm_cls_score_shape = odm_cls_score.shape
  423. odm_cls_score_reshape = paddle.reshape(odm_cls_score, [
  424. odm_cls_score_shape[0], odm_cls_score_shape[1] *
  425. odm_cls_score_shape[2], self.cls_out_channels
  426. ])
  427. odm_cls_branch_list.append(odm_cls_score_reshape)
  428. odm_bbox_pred = self.odm_reg(odm_reg_feat)
  429. # [N, 5, H, W] --> [N, H, W, 5]
  430. odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1])
  431. odm_bbox_pred_reshape = paddle.reshape(odm_bbox_pred, [-1, 5])
  432. odm_bbox_pred_reshape = paddle.unsqueeze(
  433. odm_bbox_pred_reshape, axis=0)
  434. odm_reg_branch_list.append(odm_bbox_pred_reshape)
  435. self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list,
  436. odm_cls_branch_list, odm_reg_branch_list)
  437. return self.s2anet_head_out
  438. def get_prediction(self, nms_pre=2000):
  439. refine_anchors = self.refine_anchor_list
  440. fam_cls_branch_list = self.s2anet_head_out[0]
  441. fam_reg_branch_list = self.s2anet_head_out[1]
  442. odm_cls_branch_list = self.s2anet_head_out[2]
  443. odm_reg_branch_list = self.s2anet_head_out[3]
  444. pred_scores, pred_bboxes = self.get_bboxes(
  445. odm_cls_branch_list, odm_reg_branch_list, refine_anchors, nms_pre,
  446. self.cls_out_channels, self.use_sigmoid_cls)
  447. return pred_scores, pred_bboxes
  448. def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
  449. """
  450. Args:
  451. pred: pred score
  452. label: label
  453. delta: delta
  454. Returns: loss
  455. """
  456. assert pred.shape == label.shape and label.numel() > 0
  457. assert delta > 0
  458. diff = paddle.abs(pred - label)
  459. loss = paddle.where(diff < delta, 0.5 * diff * diff / delta,
  460. diff - 0.5 * delta)
  461. return loss
  462. def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'):
  463. (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
  464. pos_inds, neg_inds) = fam_target
  465. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
  466. fam_cls_losses = []
  467. fam_bbox_losses = []
  468. st_idx = 0
  469. num_total_samples = len(pos_inds) + len(
  470. neg_inds) if self.sampling else len(pos_inds)
  471. num_total_samples = max(1, num_total_samples)
  472. for idx, feat_size in enumerate(self.featmap_sizes_list):
  473. feat_anchor_num = feat_size[0] * feat_size[1]
  474. # step1: get data
  475. feat_labels = labels[st_idx:st_idx + feat_anchor_num]
  476. feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
  477. feat_bbox_targets = bbox_targets[st_idx:st_idx +
  478. feat_anchor_num, :]
  479. feat_bbox_weights = bbox_weights[st_idx:st_idx +
  480. feat_anchor_num, :]
  481. # step2: calc cls loss
  482. feat_labels = feat_labels.reshape(-1)
  483. feat_label_weights = feat_label_weights.reshape(-1)
  484. fam_cls_score = fam_cls_branch_list[idx]
  485. fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
  486. fam_cls_score1 = fam_cls_score
  487. feat_labels = paddle.to_tensor(feat_labels)
  488. feat_labels_one_hot = paddle.nn.functional.one_hot(
  489. feat_labels, self.cls_out_channels + 1)
  490. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  491. feat_labels_one_hot.stop_gradient = True
  492. num_total_samples = paddle.to_tensor(
  493. num_total_samples, dtype='float32', stop_gradient=True)
  494. fam_cls = F.sigmoid_focal_loss(
  495. fam_cls_score1,
  496. feat_labels_one_hot,
  497. normalizer=num_total_samples,
  498. reduction='none')
  499. feat_label_weights = feat_label_weights.reshape(
  500. feat_label_weights.shape[0], 1)
  501. feat_label_weights = np.repeat(
  502. feat_label_weights, self.cls_out_channels, axis=1)
  503. feat_label_weights = paddle.to_tensor(
  504. feat_label_weights, stop_gradient=True)
  505. fam_cls = fam_cls * feat_label_weights
  506. fam_cls_total = paddle.sum(fam_cls)
  507. fam_cls_losses.append(fam_cls_total)
  508. # step3: regression loss
  509. feat_bbox_targets = paddle.to_tensor(
  510. feat_bbox_targets, dtype='float32', stop_gradient=True)
  511. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  512. fam_bbox_pred = fam_reg_branch_list[idx]
  513. fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
  514. fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
  515. fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
  516. loss_weight = paddle.to_tensor(
  517. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  518. fam_bbox = paddle.multiply(fam_bbox, loss_weight)
  519. feat_bbox_weights = paddle.to_tensor(
  520. feat_bbox_weights, stop_gradient=True)
  521. if reg_loss_type == 'l1':
  522. fam_bbox = fam_bbox * feat_bbox_weights
  523. fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
  524. elif reg_loss_type == 'iou' or reg_loss_type == 'gwd':
  525. fam_bbox = paddle.sum(fam_bbox, axis=-1)
  526. feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
  527. try:
  528. from rbox_iou_ops import rbox_iou
  529. except Exception as e:
  530. print("import custom_ops error, try install rbox_iou_ops " \
  531. "following ppdet/ext_op/README.md", e)
  532. sys.stdout.flush()
  533. sys.exit(-1)
  534. # calc iou
  535. fam_bbox_decode = self.delta2rbox(self.base_anchors_list[idx],
  536. fam_bbox_pred)
  537. bbox_gt_bboxes = paddle.to_tensor(
  538. bbox_gt_bboxes,
  539. dtype=fam_bbox_decode.dtype,
  540. place=fam_bbox_decode.place)
  541. bbox_gt_bboxes.stop_gradient = True
  542. iou = rbox_iou(fam_bbox_decode, bbox_gt_bboxes)
  543. iou = paddle.diag(iou)
  544. if reg_loss_type == 'gwd':
  545. bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx +
  546. feat_anchor_num, :]
  547. fam_bbox_total = self.gwd_loss(fam_bbox_decode,
  548. bbox_gt_bboxes_level)
  549. fam_bbox_total = fam_bbox_total * feat_bbox_weights
  550. fam_bbox_total = paddle.sum(
  551. fam_bbox_total) / num_total_samples
  552. fam_bbox_losses.append(fam_bbox_total)
  553. st_idx += feat_anchor_num
  554. fam_cls_loss = paddle.add_n(fam_cls_losses)
  555. fam_cls_loss_weight = paddle.to_tensor(
  556. self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
  557. fam_cls_loss = fam_cls_loss * fam_cls_loss_weight
  558. fam_reg_loss = paddle.add_n(fam_bbox_losses)
  559. return fam_cls_loss, fam_reg_loss
  560. def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'):
  561. (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
  562. pos_inds, neg_inds) = odm_target
  563. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
  564. odm_cls_losses = []
  565. odm_bbox_losses = []
  566. st_idx = 0
  567. num_total_samples = len(pos_inds) + len(
  568. neg_inds) if self.sampling else len(pos_inds)
  569. num_total_samples = max(1, num_total_samples)
  570. for idx, feat_size in enumerate(self.featmap_sizes_list):
  571. feat_anchor_num = feat_size[0] * feat_size[1]
  572. # step1: get data
  573. feat_labels = labels[st_idx:st_idx + feat_anchor_num]
  574. feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
  575. feat_bbox_targets = bbox_targets[st_idx:st_idx +
  576. feat_anchor_num, :]
  577. feat_bbox_weights = bbox_weights[st_idx:st_idx +
  578. feat_anchor_num, :]
  579. # step2: calc cls loss
  580. feat_labels = feat_labels.reshape(-1)
  581. feat_label_weights = feat_label_weights.reshape(-1)
  582. odm_cls_score = odm_cls_branch_list[idx]
  583. odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
  584. odm_cls_score1 = odm_cls_score
  585. feat_labels = paddle.to_tensor(feat_labels)
  586. feat_labels_one_hot = paddle.nn.functional.one_hot(
  587. feat_labels, self.cls_out_channels + 1)
  588. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  589. feat_labels_one_hot.stop_gradient = True
  590. num_total_samples = paddle.to_tensor(
  591. num_total_samples, dtype='float32', stop_gradient=True)
  592. odm_cls = F.sigmoid_focal_loss(
  593. odm_cls_score1,
  594. feat_labels_one_hot,
  595. normalizer=num_total_samples,
  596. reduction='none')
  597. feat_label_weights = feat_label_weights.reshape(
  598. feat_label_weights.shape[0], 1)
  599. feat_label_weights = np.repeat(
  600. feat_label_weights, self.cls_out_channels, axis=1)
  601. feat_label_weights = paddle.to_tensor(feat_label_weights)
  602. feat_label_weights.stop_gradient = True
  603. odm_cls = odm_cls * feat_label_weights
  604. odm_cls_total = paddle.sum(odm_cls)
  605. odm_cls_losses.append(odm_cls_total)
  606. # # step3: regression loss
  607. feat_bbox_targets = paddle.to_tensor(
  608. feat_bbox_targets, dtype='float32')
  609. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  610. feat_bbox_targets.stop_gradient = True
  611. odm_bbox_pred = odm_reg_branch_list[idx]
  612. odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
  613. odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
  614. odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
  615. loss_weight = paddle.to_tensor(
  616. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  617. odm_bbox = paddle.multiply(odm_bbox, loss_weight)
  618. feat_bbox_weights = paddle.to_tensor(
  619. feat_bbox_weights, stop_gradient=True)
  620. if reg_loss_type == 'l1':
  621. odm_bbox = odm_bbox * feat_bbox_weights
  622. odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
  623. elif reg_loss_type == 'iou' or reg_loss_type == 'gwd':
  624. odm_bbox = paddle.sum(odm_bbox, axis=-1)
  625. feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
  626. try:
  627. from rbox_iou_ops import rbox_iou
  628. except Exception as e:
  629. print("import custom_ops error, try install rbox_iou_ops " \
  630. "following ppdet/ext_op/README.md", e)
  631. sys.stdout.flush()
  632. sys.exit(-1)
  633. # calc iou
  634. odm_bbox_decode = self.delta2rbox(self.refine_anchor_list[idx],
  635. odm_bbox_pred)
  636. bbox_gt_bboxes = paddle.to_tensor(
  637. bbox_gt_bboxes,
  638. dtype=odm_bbox_decode.dtype,
  639. place=odm_bbox_decode.place)
  640. bbox_gt_bboxes.stop_gradient = True
  641. iou = rbox_iou(odm_bbox_decode, bbox_gt_bboxes)
  642. iou = paddle.diag(iou)
  643. if reg_loss_type == 'gwd':
  644. bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx +
  645. feat_anchor_num, :]
  646. odm_bbox_total = self.gwd_loss(odm_bbox_decode,
  647. bbox_gt_bboxes_level)
  648. odm_bbox_total = odm_bbox_total * feat_bbox_weights
  649. odm_bbox_total = paddle.sum(
  650. odm_bbox_total) / num_total_samples
  651. odm_bbox_losses.append(odm_bbox_total)
  652. st_idx += feat_anchor_num
  653. odm_cls_loss = paddle.add_n(odm_cls_losses)
  654. odm_cls_loss_weight = paddle.to_tensor(
  655. self.cls_loss_weight[1], dtype='float32', stop_gradient=True)
  656. odm_cls_loss = odm_cls_loss * odm_cls_loss_weight
  657. odm_reg_loss = paddle.add_n(odm_bbox_losses)
  658. return odm_cls_loss, odm_reg_loss
  def get_loss(self, inputs):
      """Compute the FAM and ODM losses for a batch.

      For each image, targets are assigned with ``self.anchor_assign`` on the
      base anchors (FAM) and on the refined anchors (ODM). The per-image
      losses returned by ``get_fam_loss`` / ``get_odm_loss`` are then summed
      over the batch with ``paddle.add_n``.

      Args:
          inputs (dict): batch inputs. Reads 'im_shape' (only its first
              dimension is used, as the number of images) and, per image,
              'gt_rbox' (boxes as (xc, yc, w, h, theta)), 'gt_class' and
              'is_crowd'.

      Returns:
          dict: 'fam_cls_loss', 'fam_reg_loss', 'odm_cls_loss' and
          'odm_reg_loss', each the sum of the per-image losses.
      """
      fam_cls_loss_lst = []
      fam_reg_loss_lst = []
      odm_cls_loss_lst = []
      odm_reg_loss_lst = []

      # Neither anchor set depends on im_id, so build both once instead of
      # once per image (the refined set also costs a device->host copy).
      # NOTE(review): assumes self.anchor_assign does not modify the anchor
      # arrays it receives in place -- confirm.
      anchors_list_all = np.concatenate(self.base_anchors_list)
      # NOTE(review): every entry of self.refine_anchor_list is used for every
      # image; if refined anchors are per-image this is only correct for a
      # batch size of 1 -- confirm.
      np_refine_anchors_list = paddle.concat(
          self.refine_anchor_list).numpy().reshape(-1, 5)

      num_images = inputs['im_shape'].shape[0]
      for im_id in range(num_images):
          # data_format: (xc, yc, w, h, theta)
          gt_bboxes = inputs['gt_rbox'][im_id].numpy()
          gt_labels = inputs['gt_class'][im_id].numpy()
          is_crowd = inputs['is_crowd'][im_id].numpy()
          # Shift class ids up by one (presumably to reserve 0 for
          # background -- confirm against anchor_assign).
          gt_labels = gt_labels + 1

          # This image's per-level outputs for the four head branches,
          # in order: fam_cls, fam_reg, odm_cls, odm_reg.
          im_s2anet_head_out = tuple(
              [feat[im_id] for feat in self.s2anet_head_out[branch]]
              for branch in range(4))

          # FAM
          im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes,
                                             gt_labels, is_crowd)
          if im_fam_target is not None:
              im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
                  im_fam_target, im_s2anet_head_out, self.reg_loss_type)
              fam_cls_loss_lst.append(im_fam_cls_loss)
              fam_reg_loss_lst.append(im_fam_reg_loss)

          # ODM
          im_odm_target = self.anchor_assign(np_refine_anchors_list,
                                             gt_bboxes, gt_labels, is_crowd)
          if im_odm_target is not None:
              im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
                  im_odm_target, im_s2anet_head_out, self.reg_loss_type)
              odm_cls_loss_lst.append(im_odm_cls_loss)
              odm_reg_loss_lst.append(im_odm_reg_loss)

      # NOTE(review): if anchor_assign returned None for every image these
      # lists are empty; confirm paddle.add_n is acceptable in that case.
      fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
      fam_reg_loss = paddle.add_n(fam_reg_loss_lst)
      odm_cls_loss = paddle.add_n(odm_cls_loss_lst)
      odm_reg_loss = paddle.add_n(odm_reg_loss_lst)
      return {
          'fam_cls_loss': fam_cls_loss,
          'fam_reg_loss': fam_reg_loss,
          'odm_cls_loss': odm_cls_loss,
          'odm_reg_loss': odm_reg_loss
      }
  def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
                 cls_out_channels, use_sigmoid_cls):
      """Decode per-level head outputs into rotated boxes and class scores.

      Args:
          cls_score_list: per-level classification logits, each reshaped to
              [-1, cls_out_channels].
          bbox_pred_list: per-level box deltas, each transposed with perm
              [1, 2, 0] and reshaped to [-1, 5] (presumably [5, H, W] input).
          mlvl_anchors: per-level anchors, each reshaped to [-1, 5].
          nms_pre (int): if > 0, keep at most this many candidates per level
              (ranked by best foreground score) before decoding; a value
              <= 0 keeps every candidate.
          cls_out_channels (int): number of classification outputs per anchor.
          use_sigmoid_cls (bool): sigmoid scores if True, else softmax (and
              column 0 is then ignored when ranking candidates).

      Returns:
          tuple: (mlvl_scores [M, cls_out_channels], mlvl_bboxes [M, 5]),
          concatenated over all levels.
      """
      assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
      mlvl_bboxes = []
      mlvl_scores = []
      for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list,
                                               mlvl_anchors):
          cls_score = paddle.reshape(cls_score, [-1, cls_out_channels])
          if use_sigmoid_cls:
              scores = F.sigmoid(cls_score)
          else:
              scores = F.softmax(cls_score, axis=-1)
          # Put the box channels last, then flatten to one row per anchor.
          bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0])
          bbox_pred = paddle.reshape(bbox_pred, [-1, 5])
          anchors = paddle.reshape(anchors, [-1, 5])
          # Guard on nms_pre > 0: a non-positive value means "keep all" and
          # must not reach paddle.topk as k.
          if nms_pre > 0 and scores.shape[0] > nms_pre:
              # Rank candidates by their best foreground score.
              if use_sigmoid_cls:
                  max_scores = paddle.max(scores, axis=1)
              else:
                  # Column 0 is skipped (presumably the background class).
                  max_scores = paddle.max(scores[:, 1:], axis=1)
              _, topk_inds = paddle.topk(max_scores, nms_pre)
              anchors = paddle.gather(anchors, topk_inds)
              bbox_pred = paddle.gather(bbox_pred, topk_inds)
              scores = paddle.gather(scores, topk_inds)
          mlvl_bboxes.append(self.delta2rbox(anchors, bbox_pred))
          mlvl_scores.append(scores)
      mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
      mlvl_scores = paddle.concat(mlvl_scores)
      return mlvl_scores, mlvl_bboxes
  def rect2rbox(self, bboxes):
      """Convert axis-aligned boxes to rotated boxes.

      :param bboxes: boxes reshaped to (n, 4) as (xmin, ymin, xmax, ymax)
      :return: rboxes: shape (n, 5) as (x_ctr, y_ctr, w, h, angle), where w is
          the longer edge, h the shorter one, and angle is pi/2 when the box
          is taller than it is wide (else 0). The result keeps the dtype of
          ``bboxes``.
      """
      bboxes = paddle.reshape(bboxes, [-1, 4])
      x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
      y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
      edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
      edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])
      rbox_w = paddle.maximum(edges1, edges2)
      rbox_h = paddle.minimum(edges1, edges2)
      # Cast the "taller than wide" mask straight to the box dtype so the
      # angle column matches the other columns passed to paddle.stack.
      is_tall = paddle.cast(edges1 < edges2, bboxes.dtype)
      rboxes_angle = is_tall * (np.pi / 2.0)
      rboxes = paddle.stack(
          (x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
      return rboxes
  # deltas to rbox
  def delta2rbox(self, rrois, deltas, wh_ratio_clip=1e-6):
      """Apply regression deltas to rotated reference boxes.

      The deltas are first de-normalised as ``deltas * self.stds + self.means``.
      The centre offset (dx * w, dy * h) is rotated by the reference angle,
      w and h are scaled by exp(dw) and exp(dh), and the angle becomes
      ``pi * dtheta + theta``, wrapped via ``(a + pi/4) % pi - pi/4``.

      :param rrois: reference boxes, reshaped to (-1, 5) as (cx, cy, w, h, theta)
      :param deltas: deltas, reshaped to (-1, 5) as (dx, dy, dw, dh, dtheta)
      :param wh_ratio_clip: dw and dh are clipped to [-|log(wh_ratio_clip)|,
          |log(wh_ratio_clip)|]
      :return: decoded boxes of shape (-1, 5) as (cx, cy, w, h, theta)
      """
      deltas = paddle.reshape(deltas, [-1, 5])
      rrois = paddle.reshape(rrois, [-1, 5])
      # Written with explicit ops instead of `deltas * self.stds + self.means`
      # to work around a dynamic-to-static (dy2st) conversion bug.
      denorm_deltas = paddle.add(
          paddle.multiply(deltas, self.stds), self.means)
      dx = denorm_deltas[:, 0]
      dy = denorm_deltas[:, 1]
      dw = denorm_deltas[:, 2]
      dh = denorm_deltas[:, 3]
      dangle = denorm_deltas[:, 4]
      max_ratio = np.abs(np.log(wh_ratio_clip))
      dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
      dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
      rroi_x = rrois[:, 0]
      rroi_y = rrois[:, 1]
      rroi_w = rrois[:, 2]
      rroi_h = rrois[:, 3]
      rroi_angle = rrois[:, 4]
      gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
          rroi_angle) + rroi_x
      gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
          rroi_angle) + rroi_y
      gw = rroi_w * dw.exp()
      gh = rroi_h * dh.exp()
      ga = np.pi * dangle + rroi_angle
      ga = (ga + np.pi / 4) % np.pi - np.pi / 4
      # These are already tensors. paddle.to_tensor() copies them with
      # stop_gradient=True by default, which detached w/h/angle (but not x/y)
      # from autograd. paddle.cast keeps the float32 dtype and the gradient.
      gw = paddle.cast(gw, 'float32')
      gh = paddle.cast(gh, 'float32')
      bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
      return bboxes
  def bbox_decode(self, bbox_preds, anchors):
      """Decode box deltas into rotated boxes via ``delta2rbox``.

      Args:
          bbox_preds: deltas, presumably [N, H, W, 5]; flattened to [-1, 5].
          anchors: reference boxes, presumably [H*W, 5].

      Returns:
          bboxes: [N*H*W, 5] as (cx, cy, w, h, theta). The result is flat,
          not [N, H, W, 5]; callers reshape it if they need the grid layout.

      NOTE(review): delta2rbox combines anchors and deltas elementwise after
      flattening both to [-1, 5], so their row counts must match (or
      broadcast). With [H*W, 5] anchors that presumably means N == 1 --
      confirm with callers.
      """
      bbox_delta = paddle.reshape(bbox_preds, [-1, 5])
      return self.delta2rbox(anchors, bbox_delta)
  def trace(self, A):
      """Batched trace: sum of the diagonal over the last two axes of ``A``."""
      diag = paddle.diagonal(A, axis1=-2, axis2=-1)
      return diag.sum(axis=-1)
  def sqrt_newton_schulz_autograd(self, A, numIters):
      """Differentiable batched matrix square root by Newton-Schulz iteration.

      Each matrix is divided by its Frobenius norm. Starting from
      Y = A / ||A|| and Z = I, the coupled update
          T = 0.5 * (3I - Z Y);  Y <- Y T;  Z <- T Z
      runs ``numIters`` times, and Y is then rescaled by sqrt(||A||).
      Convergence presumably requires suitable A (e.g. symmetric positive
      semi-definite) -- confirm with callers.

      Args:
          A: batch of square matrices, shape [B, dim, dim].
          numIters (int): number of iterations.

      Returns:
          Tensor of shape [B, dim, dim] approximating the square root of A.
      """
      batch = A.shape[0]
      dim = A.shape[1]
      # Frobenius norm of each matrix, reshaped to [B, 1, 1] for broadcasting.
      fro_norm = paddle.sqrt(paddle.sum(paddle.sum(A * A, axis=1), axis=1))
      fro_norm = fro_norm.reshape([batch, 1, 1])
      Y = paddle.divide(A, paddle.expand_as(fro_norm, A))

      def batched_identity():
          # [B, dim, dim] stack of identity matrices, with gradient tracking.
          eye = paddle.eye(dim, dim).reshape([1, dim, dim])
          ident = paddle.concat([eye] * batch, axis=0)
          ident.stop_gradient = False
          return ident

      I = batched_identity()
      Z = batched_identity()
      for _ in range(numIters):
          T = 0.5 * (3.0 * I - Z.bmm(Y))
          Y = Y.bmm(T)
          Z = T.bmm(Z)
      root = Y * paddle.sqrt(fro_norm).reshape([batch, 1, 1])
      return paddle.expand_as(root, A)
  def wasserstein_distance_sigma(self, sigma1, sigma2, num_iters=10):
      """Covariance term of the squared 2-Wasserstein distance.

      Computes trace(S1 + S2 - 2 * sqrtm(sigma1 @ S2 @ sigma1)), where
      S1 = sigma1 @ sigma1 and S2 = sigma2 @ sigma2. The matrix square root
      comes from ``sqrt_newton_schulz_autograd``. This is the Gaussian W2
      covariance term when sigma1 / sigma2 are the symmetric square roots of
      the two covariances (presumably how callers use it -- confirm).

      Args:
          sigma1, sigma2: batched square matrices [B, d, d], as required by
              ``sqrt_newton_schulz_autograd``.
          num_iters (int): Newton-Schulz iterations; defaults to 10, the
              value previously hard-coded.

      Returns:
          Tensor with one trace value per batch element.
      """
      # The method was previously declared without `self` even though it
      # uses it, so every call through an instance raised TypeError.
      cov1 = paddle.matmul(sigma1, sigma1)
      cov2 = paddle.matmul(sigma2, sigma2)
      cross = paddle.matmul(paddle.matmul(sigma1, cov2), sigma1)
      wasserstein_distance_item2 = cov1 + cov2 - 2 * (
          self.sqrt_newton_schulz_autograd(cross, num_iters))
      return self.trace(wasserstein_distance_item2)
  def xywhr2xyrs(self, xywhr):
      """Split rotated boxes into centre, rotation and half-size matrices.

      Args:
          xywhr: boxes reshaped to [-1, 5] as (x, y, w, h, r); r is passed to
              cos/sin as an angle.

      Returns:
          tuple: centre [N, 2]; rotation matrices [N, 2, 2] laid out as
          [[cos r, -sin r], [sin r, cos r]]; diagonal matrices [N, 2, 2] with
          w/2 and h/2 on the diagonal (w, h clipped to [1e-7, 1e7]).
      """
      boxes = paddle.reshape(xywhr, [-1, 5])
      center = boxes[:, :2]
      size = paddle.clip(boxes[:, 2:4], min=1e-7, max=1e7)
      angle = boxes[:, 4]
      cos_a = paddle.cos(angle)
      sin_a = paddle.sin(angle)
      rotation = paddle.stack((cos_a, -sin_a, sin_a, cos_a), axis=-1)
      rotation = rotation.reshape([-1, 2, 2])
      half_size = 0.5 * paddle.nn.functional.diag_embed(size)
      return center, rotation, half_size
  def gwd_loss(self,
               pred,
               target,
               fun='log',
               tau=1.0,
               alpha=1.0,
               normalize=False):
      """Gaussian Wasserstein distance (GWD) loss between rotated boxes.

      Each box (x, y, w, h, r) is treated as a 2-D Gaussian with mean (x, y)
      and covariance Sigma = R S^2 R^T. Here R is the rotation by r and
      S = diag(w, h) / 2, as built by ``xywhr2xyrs``.

      Args:
          pred: predicted boxes, reshaped to [-1, 5] as (x, y, w, h, r).
          target: target boxes in the same layout as ``pred``.
          fun (str): 'log' maps the distance through log1p; any other value
              leaves it unchanged.
          tau (float): if >= 1.0 the result is 1 - 1 / (tau + distance);
              otherwise the (possibly log1p-mapped) distance is returned.
          alpha (float): weight of the covariance term, applied as alpha**2.
          normalize (bool): if True, divide the distance by
              (w_p * h_p * w_t * h_t) ** 0.25.

      Returns:
          Tensor of shape [num_boxes] with the per-box loss.
      """
      xy_p, R_p, S_p = self.xywhr2xyrs(pred)
      xy_t, R_t, S_t = self.xywhr2xyrs(target)

      # Squared distance between the Gaussian means.
      xy_distance = (xy_p - xy_t).square().sum(axis=-1)

      Sigma_p = R_p.matmul(S_p.square()).matmul(R_p.transpose([0, 2, 1]))
      Sigma_t = R_t.matmul(S_t.square()).matmul(R_t.transpose([0, 2, 1]))

      # tr(Sigma_p) + tr(Sigma_t); R is a rotation, so tr(Sigma) = tr(S^2).
      whr_distance = paddle.diagonal(
          S_p, axis1=-2, axis2=-1).square().sum(axis=-1)
      whr_distance = whr_distance + paddle.diagonal(
          S_t, axis1=-2, axis2=-1).square().sum(axis=-1)

      # For 2x2 matrices, tr((Sigma_p^1/2 Sigma_t Sigma_p^1/2)^1/2) equals
      # sqrt(tr(Sigma_p Sigma_t) + 2 * sqrt(det(Sigma_p) * det(Sigma_t))).
      # Also sqrt(det(Sigma_p) * det(Sigma_t)) = det(S_p) * det(S_t).
      _t = Sigma_p.matmul(Sigma_t)
      _t_tr = paddle.diagonal(_t, axis1=-2, axis2=-1).sum(axis=-1)
      _t_det_sqrt = paddle.diagonal(S_p, axis1=-2, axis2=-1).prod(axis=-1)
      _t_det_sqrt = _t_det_sqrt * paddle.diagonal(
          S_t, axis1=-2, axis2=-1).prod(axis=-1)
      whr_distance = whr_distance + (-2) * (
          (_t_tr + 2 * _t_det_sqrt).clip(0).sqrt())

      distance = (xy_distance + alpha * alpha * whr_distance).clip(0)

      if normalize:
          wh_p = pred[..., 2:4].clip(min=1e-7, max=1e7)
          wh_t = target[..., 2:4].clip(min=1e-7, max=1e7)
          # Paddle reductions take `axis`. The previous `dim=-1` is PyTorch's
          # spelling and is rejected by paddle's sum (Paddle 2.x).
          scale = ((wh_p.log() + wh_t.log()).sum(axis=-1) / 4).exp()
          distance = distance / scale

      if fun == 'log':
          distance = paddle.log1p(distance)

      if tau >= 1.0:
          return 1 - 1 / (tau + distance)

      return distance