# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import six
import numpy as np
from numbers import Integral

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle import to_tensor
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant, XavierUniform
from paddle.regularizer import L2Decay

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.modeling.bbox_utils import delta2bbox
from . import ops
from .initializer import xavier_uniform_, constant_

from paddle.vision.ops import DeformConv2D


def _to_list(l):
    if isinstance(l, (list, tuple)):
        return list(l)
    return [l]


class DeformableConvV2(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 skip_quant=False,
                 dcn_bias_regularizer=L2Decay(0.),
                 dcn_bias_lr_scale=2.):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        if lr_scale == 1 and regularizer is None:
            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
                regularizer=regularizer)
        self.conv_offset = nn.Conv2D(
            in_channels,
            3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            self.conv_offset.skip_quant = True

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                initializer=Constant(value=0),
                regularizer=dcn_bias_regularizer,
                learning_rate=dcn_bias_lr_scale)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False
        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

    def forward(self, x):
        offset_mask = self.conv_offset(x)
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        y = self.conv_dcn(x, offset, mask=mask)
        return y
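
# Usage sketch (illustrative, not part of the library): DeformableConvV2
# predicts its own 3*k*k channels internally (2*k*k offsets + k*k mask), so it
# is called like a plain conv. Shapes below are assumptions for the example.
#
#   dcn = DeformableConvV2(in_channels=16, out_channels=32, kernel_size=3)
#   x = paddle.rand([2, 16, 64, 64])
#   y = dcn(x)  # [2, 32, 64, 64]; sigmoid-gated modulated deform conv inside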


class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 norm_type='bn',
                 norm_decay=0.,
                 norm_groups=32,
                 use_dcn=False,
                 bias_on=False,
                 lr_scale=1.,
                 freeze_norm=False,
                 initializer=Normal(
                     mean=0., std=0.01),
                 skip_quant=False,
                 dcn_lr_scale=2.,
                 dcn_regularizer=L2Decay(0.)):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]

        if bias_on:
            bias_attr = ParamAttr(
                initializer=Constant(value=0.), learning_rate=lr_scale)
        else:
            bias_attr = False

        if not use_dcn:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=bias_attr)
            if skip_quant:
                self.conv.skip_quant = True
        else:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            self.conv = DeformableConvV2(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=True,
                lr_scale=dcn_lr_scale,
                regularizer=dcn_regularizer,
                dcn_bias_regularizer=dcn_regularizer,
                dcn_bias_lr_scale=dcn_lr_scale,
                skip_quant=skip_quant)

        norm_lr = 0. if freeze_norm else 1.
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        if norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)
        else:
            self.norm = None

    def forward(self, inputs):
        out = self.conv(inputs)
        if self.norm is not None:
            out = self.norm(out)
        return out
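
# Usage sketch (shapes are illustrative assumptions): a conv + norm block,
# the building unit reused by LiteConv and several heads below.
#
#   conv_bn = ConvNormLayer(ch_in=16, ch_out=32, filter_size=3, stride=1,
#                           norm_type='bn')
#   y = conv_bn(paddle.rand([2, 16, 64, 64]))  # [2, 32, 64, 64]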


class LiteConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 with_act=True,
                 norm_type='sync_bn',
                 name=None):
        super(LiteConv, self).__init__()
        self.lite_conv = nn.Sequential()
        conv1 = ConvNormLayer(
            in_channels,
            in_channels,
            filter_size=5,
            stride=stride,
            groups=in_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv2 = ConvNormLayer(
            in_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv3 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv4 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=5,
            stride=stride,
            groups=out_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv_list = [conv1, conv2, conv3, conv4]
        self.lite_conv.add_sublayer('conv1', conv1)
        self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
        self.lite_conv.add_sublayer('conv2', conv2)
        if with_act:
            self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
        self.lite_conv.add_sublayer('conv3', conv3)
        self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
        self.lite_conv.add_sublayer('conv4', conv4)
        if with_act:
            self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())

    def forward(self, inputs):
        out = self.lite_conv(inputs)
        return out


class DropBlock(nn.Layer):
    def __init__(self, block_size, keep_prob, name=None, data_format='NCHW'):
        """
        DropBlock layer, see https://arxiv.org/abs/1810.12890

        Args:
            block_size (int): block size
            keep_prob (float): keep probability
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        if not self.training or self.keep_prob == 1:
            return x
        else:
            gamma = (1. - self.keep_prob) / (self.block_size**2)
            if self.data_format == 'NCHW':
                shape = x.shape[2:]
            else:
                shape = x.shape[1:3]
            for s in shape:
                gamma *= s / (s - self.block_size + 1)

            matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype)
            mask_inv = F.max_pool2d(
                matrix,
                self.block_size,
                stride=1,
                padding=self.block_size // 2,
                data_format=self.data_format)
            mask = 1. - mask_inv
            y = x * mask * (mask.numel() / mask.sum())
            return y
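
# Note on `gamma` above: it follows Eq. (1) of the DropBlock paper. The
# Bernoulli rate for sampling block centers is scaled so that the expected
# fraction of dropped activations matches (1 - keep_prob):
#
#   gamma = (1 - keep_prob) / block_size^2
#           * s / (s - block_size + 1)   for each spatial dim of size s
#
# The max_pool2d call then expands every sampled center into a
# block_size x block_size dropped region, and the final rescaling keeps the
# expected activation magnitude unchanged.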


@register
@serializable
class AnchorGeneratorSSD(object):
    def __init__(self,
                 steps=[8, 16, 32, 64, 100, 300],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.],
                                [2.]],
                 min_ratio=15,
                 max_ratio=90,
                 base_size=300,
                 min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
                 max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
                 offset=0.5,
                 flip=True,
                 clip=False,
                 min_max_aspect_ratios_order=False):
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.base_size = base_size
        self.min_sizes = min_sizes
        self.max_sizes = max_sizes
        self.offset = offset
        self.flip = flip
        self.clip = clip
        self.min_max_aspect_ratios_order = min_max_aspect_ratios_order

        if self.min_sizes == [] and self.max_sizes == []:
            num_layer = len(aspect_ratios)
            step = int(
                math.floor(((self.max_ratio - self.min_ratio)) /
                           (num_layer - 2)))
            for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
                                         step):
                self.min_sizes.append(self.base_size * ratio / 100.)
                self.max_sizes.append(self.base_size * (ratio + step) / 100.)
            self.min_sizes = [self.base_size * .10] + self.min_sizes
            self.max_sizes = [self.base_size * .20] + self.max_sizes

        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(
                    len(_to_list(min_size)) + len(_to_list(max_size)))
            else:
                self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
                    _to_list(min_size)) + len(_to_list(max_size)))

    def __call__(self, inputs, image):
        boxes = []
        for input, min_size, max_size, aspect_ratio, step in zip(
                inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
                self.steps):
            box, _ = ops.prior_box(
                input=input,
                image=image,
                min_sizes=_to_list(min_size),
                max_sizes=_to_list(max_size),
                aspect_ratios=aspect_ratio,
                flip=self.flip,
                clip=self.clip,
                steps=[step, step],
                offset=self.offset,
                min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
            boxes.append(paddle.reshape(box, [-1, 4]))
        return boxes
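
# Worked example (assumed defaults): if min_sizes/max_sizes are passed as
# empty lists, they are derived from the ratio range. With min_ratio=15,
# max_ratio=90, base_size=300 and 6 levels: step = floor(75 / 4) = 18, the
# ratio loop yields 15, 33, 51, 69, 87, giving
#   min_sizes = [30, 45, 99, 153, 207, 261]
#   max_sizes = [60, 99, 153, 207, 261, 315]
# where the leading 30/60 pair comes from 0.10/0.20 * base_size.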


@register
@serializable
class RCNNBox(object):
    __shared__ = ['num_classes', 'export_onnx']

    def __init__(self,
                 prior_box_var=[10., 10., 5., 5.],
                 code_type="decode_center_size",
                 box_normalized=False,
                 num_classes=80,
                 export_onnx=False):
        super(RCNNBox, self).__init__()
        self.prior_box_var = prior_box_var
        self.code_type = code_type
        self.box_normalized = box_normalized
        self.num_classes = num_classes
        self.export_onnx = export_onnx

    def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
        bbox_pred = bbox_head_out[0]
        cls_prob = bbox_head_out[1]
        roi = rois[0]
        rois_num = rois[1]

        if self.export_onnx:
            onnx_rois_num_per_im = rois_num[0]
            origin_shape = paddle.expand(im_shape[0, :],
                                         [onnx_rois_num_per_im, 2])
        else:
            origin_shape_list = []
            if isinstance(roi, list):
                batch_size = len(roi)
            else:
                batch_size = paddle.slice(
                    paddle.shape(im_shape), [0], [0], [1])

            # bbox_pred.shape: [N, C*4]
            for idx in range(batch_size):
                rois_num_per_im = rois_num[idx]
                expand_im_shape = paddle.expand(im_shape[idx, :],
                                                [rois_num_per_im, 2])
                origin_shape_list.append(expand_im_shape)

            origin_shape = paddle.concat(origin_shape_list)

        # bbox_pred.shape: [N, C*4]
        # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
        bbox = paddle.concat(roi)
        bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
        scores = cls_prob[:, :-1]

        # bbox.shape: [N, C, 4]
        # bbox.shape[1] must be equal to scores.shape[1]
        total_num = bbox.shape[0]
        bbox_dim = bbox.shape[-1]
        bbox = paddle.expand(bbox, [total_num, self.num_classes, bbox_dim])

        origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
        origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
        zeros = paddle.zeros_like(origin_h)
        x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
        bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        bboxes = (bbox, rois_num)
        return bboxes, scores


@register
@serializable
class MultiClassNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=True,
                 nms_eta=1.0,
                 return_index=False,
                 return_rois_num=True,
                 trt=False):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.return_index = return_index
        self.return_rois_num = return_rois_num
        self.trt = trt

    def __call__(self, bboxes, score, background_label=-1):
        """
        bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
                                         [N, M, 4], N is the batch size and M
                                         is the number of bboxes
                                      2. (List[Tensor]) bboxes and bbox_num,
                                         bboxes have shape of [M, C, 4], C
                                         is the class number and bbox_num means
                                         the number of bboxes of each batch with
                                         shape [N,]
        score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
        background_label (int): Index of the background label to ignore, e.g.
                                num_classes for RCNN and -1 for YOLO.
        """
        kwargs = self.__dict__.copy()
        if isinstance(bboxes, tuple):
            bboxes, bbox_num = bboxes
            kwargs.update({'rois_num': bbox_num})
        if background_label > -1:
            kwargs.update({'background_label': background_label})
        kwargs.pop('trt')
        # TODO(wangxinxin08): paddle version should be develop or 2.3 and above to run nms on tensorrt
        if self.trt and (int(paddle.version.major) == 0 or
                         (int(paddle.version.major) >= 2 and
                          int(paddle.version.minor) >= 3)):
            # TODO(wangxinxin08): tricky switch to run nms on tensorrt
            kwargs.update({'nms_eta': 1.1})
            bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
            mask = paddle.slice(bbox, [-1], [0], [1]) != -1
            bbox = paddle.masked_select(bbox, mask).reshape((-1, 6))
            return bbox, bbox_num, None
        else:
            return ops.multiclass_nms(bboxes, score, **kwargs)
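
# Usage sketch (shapes are illustrative assumptions; the exact return tuple
# depends on return_index/return_rois_num in ops.multiclass_nms):
#
#   nms = MultiClassNMS(score_threshold=0.05, keep_top_k=100,
#                       nms_threshold=0.5)
#   bboxes = paddle.rand([1, 1000, 4])   # [N, M, 4]
#   scores = paddle.rand([1, 80, 1000])  # [N, C, M]
#   res = nms(bboxes, scores)  # kept detections as [K, 6]:
#                              # label, score, x1, y1, x2, y2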


@register
@serializable
class MatrixNMS(object):
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label

    def __call__(self, bbox, score, *args):
        return ops.matrix_nms(
            bboxes=bbox,
            scores=score,
            score_threshold=self.score_threshold,
            post_threshold=self.post_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            use_gaussian=self.use_gaussian,
            gaussian_sigma=self.gaussian_sigma,
            background_label=self.background_label,
            normalized=self.normalized)


@register
@serializable
class YOLOBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 conf_thresh=0.005,
                 downsample_ratio=32,
                 clip_bbox=True,
                 scale_x_y=1.):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio
        self.clip_bbox = clip_bbox
        self.scale_x_y = scale_x_y

    def __call__(self,
                 yolo_head_out,
                 anchors,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes_list = []
        scores_list = []
        origin_shape = im_shape / scale_factor
        origin_shape = paddle.cast(origin_shape, 'int32')
        for i, head_out in enumerate(yolo_head_out):
            boxes, scores = paddle.vision.ops.yolo_box(
                head_out,
                origin_shape,
                anchors[i],
                self.num_classes,
                self.conf_thresh,
                self.downsample_ratio // 2**i,
                self.clip_bbox,
                scale_x_y=self.scale_x_y)
            boxes_list.append(boxes)
            scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
        yolo_boxes = paddle.concat(boxes_list, axis=1)
        yolo_scores = paddle.concat(scores_list, axis=2)
        return yolo_boxes, yolo_scores
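
# Note: each successive head level halves the downsample ratio
# (32, 16, 8, ... via `downsample_ratio // 2**i`), matching YOLOv3-style
# multi-scale heads; per-level boxes are concatenated along the box axis and
# scores along the last axis, so the result covers all levels at once.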


@register
@serializable
class SSDBox(object):
    def __init__(self,
                 is_normalized=True,
                 prior_box_var=[0.1, 0.1, 0.2, 0.2],
                 use_fuse_decode=False):
        self.is_normalized = is_normalized
        self.norm_delta = float(not self.is_normalized)
        self.prior_box_var = prior_box_var
        self.use_fuse_decode = use_fuse_decode

    def __call__(self,
                 preds,
                 prior_boxes,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes, scores = preds
        boxes = paddle.concat(boxes, axis=1)
        prior_boxes = paddle.concat(prior_boxes)
        if self.use_fuse_decode:
            output_boxes = ops.box_coder(
                prior_boxes,
                self.prior_box_var,
                boxes,
                code_type="decode_center_size",
                box_normalized=self.is_normalized)
        else:
            pb_w = prior_boxes[:, 2] - prior_boxes[:, 0] + self.norm_delta
            pb_h = prior_boxes[:, 3] - prior_boxes[:, 1] + self.norm_delta
            pb_x = prior_boxes[:, 0] + pb_w * 0.5
            pb_y = prior_boxes[:, 1] + pb_h * 0.5
            out_x = pb_x + boxes[:, :, 0] * pb_w * self.prior_box_var[0]
            out_y = pb_y + boxes[:, :, 1] * pb_h * self.prior_box_var[1]
            out_w = paddle.exp(boxes[:, :, 2] * self.prior_box_var[2]) * pb_w
            out_h = paddle.exp(boxes[:, :, 3] * self.prior_box_var[3]) * pb_h
            output_boxes = paddle.stack(
                [
                    out_x - out_w / 2., out_y - out_h / 2., out_x + out_w / 2.,
                    out_y + out_h / 2.
                ],
                axis=-1)

        if self.is_normalized:
            h = (im_shape[:, 0] / scale_factor[:, 0]).unsqueeze(-1)
            w = (im_shape[:, 1] / scale_factor[:, 1]).unsqueeze(-1)
            im_shape = paddle.stack([w, h, w, h], axis=-1)
            output_boxes *= im_shape
        else:
            output_boxes[..., -2:] -= 1.0
        output_scores = F.softmax(paddle.concat(
            scores, axis=1)).transpose([0, 2, 1])
        return output_boxes, output_scores
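
# The manual branch above is the standard SSD center-size decoding:
#
#   out_x = pb_x + dx * pb_w * var[0]      out_w = pb_w * exp(dw * var[2])
#   out_y = pb_y + dy * pb_h * var[1]      out_h = pb_h * exp(dh * var[3])
#
# with corners recovered as (out_x - out_w/2, ..., out_y + out_h/2); it is
# the eager-mode equivalent of the fused ops.box_coder call in the other
# branch.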


@register
@serializable
class FCOSBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=80):
        super(FCOSBox, self).__init__()
        self.num_classes = num_classes

    def _merge_hw(self, inputs, ch_type="channel_first"):
        """
        Merge h and w of the feature map into one dimension.
        Args:
            inputs (Tensor): Tensor of the input feature map
            ch_type (str): "channel_first" or "channel_last" style
        Return:
            new_shape (Tensor): The new shape after h and w merged
        """
        shape_ = paddle.shape(inputs)
        bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
        img_size = hi * wi
        img_size.stop_gradient = True
        if ch_type == "channel_first":
            new_shape = paddle.concat([bs, ch, img_size])
        elif ch_type == "channel_last":
            new_shape = paddle.concat([bs, img_size, ch])
        else:
            raise KeyError("Wrong ch_type %s" % ch_type)
        new_shape.stop_gradient = True
        return new_shape

    def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
                                 scale_factor):
        """
        Postprocess each layer of the output with corresponding locations.
        Args:
            locations (Tensor): anchor points for current layer, [H*W, 2]
            box_cls (Tensor): categories prediction, [N, C, H, W],
                C is the number of classes
            box_reg (Tensor): bounding box prediction, [N, 4, H, W]
            box_ctn (Tensor): centerness prediction, [N, 1, H, W]
            scale_factor (Tensor): [h_scale, w_scale] for input images
        Return:
            box_cls_ch_last (Tensor): score for each category, in [N, C, M],
                C is the number of classes and M is the number of anchor points
            box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4],
                last dimension is [x1, y1, x2, y2]
        """
        act_shape_cls = self._merge_hw(box_cls)
        box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
        box_cls_ch_last = F.sigmoid(box_cls_ch_last)

        act_shape_reg = self._merge_hw(box_reg)
        box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
        box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
        box_reg_decoding = paddle.stack(
            [
                locations[:, 0] - box_reg_ch_last[:, :, 0],
                locations[:, 1] - box_reg_ch_last[:, :, 1],
                locations[:, 0] + box_reg_ch_last[:, :, 2],
                locations[:, 1] + box_reg_ch_last[:, :, 3]
            ],
            axis=1)
        box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])

        act_shape_ctn = self._merge_hw(box_ctn)
        box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
        box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)

        # recover the location to original image
        im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
        im_scale = paddle.expand(im_scale, [box_reg_decoding.shape[0], 4])
        im_scale = paddle.reshape(im_scale,
                                  [box_reg_decoding.shape[0], -1, 4])
        box_reg_decoding = box_reg_decoding / im_scale
        box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
        return box_cls_ch_last, box_reg_decoding

    def __call__(self, locations, cls_logits, bboxes_reg, centerness,
                 scale_factor):
        pred_boxes_ = []
        pred_scores_ = []
        for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
                                      centerness):
            pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
                pts, cls, box, ctn, scale_factor)
            pred_boxes_.append(pred_boxes_lvl)
            pred_scores_.append(pred_scores_lvl)
        pred_boxes = paddle.concat(pred_boxes_, axis=1)
        pred_scores = paddle.concat(pred_scores_, axis=2)
        return pred_boxes, pred_scores


@register
class TTFBox(object):
    __shared__ = ['down_ratio']

    def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.score_thresh = score_thresh
        self.down_ratio = down_ratio

    def _simple_nms(self, heat, kernel=3):
        """
        Use maxpool to filter the max score, get local peaks.
        """
        pad = (kernel - 1) // 2
        hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
        keep = paddle.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores):
        """
        Select top k scores and decode to get xy coordinates.
        """
        k = self.max_per_img
        shape_fm = paddle.shape(scores)
        shape_fm.stop_gradient = True
        cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
        # batch size is 1
        scores_r = paddle.reshape(scores, [cat, -1])
        topk_scores, topk_inds = paddle.topk(scores_r, k)
        topk_ys = topk_inds // width
        topk_xs = topk_inds % width

        topk_score_r = paddle.reshape(topk_scores, [-1])
        topk_score, topk_ind = paddle.topk(topk_score_r, k)
        k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
        topk_clses = paddle.cast(
            paddle.floor_divide(topk_ind, k_t), 'float32')

        topk_inds = paddle.reshape(topk_inds, [-1])
        topk_ys = paddle.reshape(topk_ys, [-1, 1])
        topk_xs = paddle.reshape(topk_xs, [-1, 1])
        topk_inds = paddle.gather(topk_inds, topk_ind)
        topk_ys = paddle.gather(topk_ys, topk_ind)
        topk_xs = paddle.gather(topk_xs, topk_ind)
        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def _decode(self, hm, wh, im_shape, scale_factor):
        heatmap = F.sigmoid(hm)
        heat = self._simple_nms(heatmap)
        scores, inds, clses, ys, xs = self._topk(heat)
        ys = paddle.cast(ys, 'float32') * self.down_ratio
        xs = paddle.cast(xs, 'float32') * self.down_ratio
        scores = paddle.tensor.unsqueeze(scores, [1])
        clses = paddle.tensor.unsqueeze(clses, [1])

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
        wh = paddle.gather(wh, inds)

        x1 = xs - wh[:, 0:1]
        y1 = ys - wh[:, 1:2]
        x2 = xs + wh[:, 2:3]
        y2 = ys + wh[:, 3:4]

        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)

        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = paddle.shape(bboxes)
        boxes_shape.stop_gradient = True
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        results = paddle.concat([clses, scores, bboxes], axis=1)
        # hack: append a dummy result with cls=-1 and score=1. so that the
        # gather below never runs on an empty index when all scores fall
        # below score_thresh.
        fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
        fill_r = paddle.cast(fill_r, results.dtype)
        results = paddle.concat([results, fill_r])
        scores = results[:, 1]
        valid_ind = paddle.nonzero(scores > self.score_thresh)
        results = paddle.gather(results, valid_ind)
        return results, paddle.shape(results)[0:1]

    def __call__(self, hm, wh, im_shape, scale_factor):
        results = []
        results_num = []
        for i in range(scale_factor.shape[0]):
            result, num = self._decode(hm[i:i + 1, ], wh[i:i + 1, ],
                                       im_shape[i:i + 1, ],
                                       scale_factor[i:i + 1, ])
            results.append(result)
            results_num.append(num)
        results = paddle.concat(results, axis=0)
        results_num = paddle.concat(results_num, axis=0)
        return results, results_num
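
# Note on _topk above: a per-class top-k is taken first (scores reshaped to
# [C, H*W]), then a global top-k over those C*k candidates. Floor-dividing
# the flat candidate index by k recovers the class id, while gathering the
# per-class spatial indices by the same index recovers the peak position.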


@register
@serializable
class JDEBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio

    def generate_anchor(self, nGh, nGw, anchor_wh):
        nA = len(anchor_wh)
        yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
        mesh = paddle.stack(
            (xv, yv), axis=0).cast(dtype='float32')  # 2 x nGh x nGw
        meshs = paddle.tile(mesh, [nA, 1, 1, 1])

        anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
            int(nGh), axis=-2).repeat(
                int(nGw), axis=-1)
        anchor_offset_mesh = paddle.to_tensor(
            anchor_offset_mesh.astype(np.float32))
        # nA x 2 x nGh x nGw

        anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
        anchor_mesh = paddle.transpose(anchor_mesh,
                                       [0, 2, 3, 1])  # (nA x nGh x nGw) x 4
        return anchor_mesh

    def decode_delta(self, delta, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
        gx = pw * dx + px
        gy = ph * dy + py
        gw = pw * paddle.exp(dw)
        gh = ph * paddle.exp(dh)
        gx1 = gx - gw * 0.5
        gy1 = gy - gh * 0.5
        gx2 = gx + gw * 0.5
        gy2 = gy + gh * 0.5
        return paddle.stack([gx1, gy1, gx2, gy2], axis=1)

    def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec):
        anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec)
        anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
        pred_list = self.decode_delta(
            paddle.reshape(
                delta_map, shape=[-1, 4]),
            paddle.reshape(
                anchor_mesh, shape=[-1, 4]))
        pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4])
        return pred_map

    def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec):
        boxes_shape = head_out.shape  # [nB, nA*6, nGh, nGw]
        nGh, nGw = boxes_shape[-2], boxes_shape[-1]
        nB = 1  # TODO: only support bs=1 now
        boxes_list, scores_list = [], []
        for idx in range(nB):
            p = paddle.reshape(
                head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw])
            p = paddle.transpose(p, perm=[0, 2, 3, 1])  # [nA, nGh, nGw, 6]
            delta_map = p[:, :, :, :4]
            boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec)
            # [nA * nGh * nGw, 4]
            boxes_list.append(boxes * stride)

            p_conf = paddle.transpose(
                p[:, :, :, 4:6], perm=[3, 0, 1, 2])  # [2, nA, nGh, nGw]
            p_conf = F.softmax(
                p_conf, axis=0)[1, :, :, :].unsqueeze(-1)  # [nA, nGh, nGw, 1]
            scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1])
            scores_list.append(scores)

        boxes_results = paddle.stack(boxes_list)
        scores_results = paddle.stack(scores_list)
        return boxes_results, scores_results

    def __call__(self, yolo_head_out, anchors):
        bbox_pred_list = []
        for i, head_out in enumerate(yolo_head_out):
            stride = self.downsample_ratio // 2**i
            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
            nA = len(anc_w)
            boxes, scores = self._postprocessing_by_level(nA, stride,
                                                          head_out, anchor_vec)
            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))

        yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
        boxes_idx_over_conf_thr = paddle.nonzero(
            yolo_boxes_scores[:, :, -1] > self.conf_thresh)
        boxes_idx_over_conf_thr.stop_gradient = True

        return boxes_idx_over_conf_thr, yolo_boxes_scores
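
# Note: the two channels at p[..., 4:6] are background/foreground logits; the
# softmax over axis 0 keeps the foreground probability (index 1) as the box
# score, and __call__ returns both the raw per-anchor boxes/scores and the
# indices of anchors whose score clears conf_thresh.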


@register
@serializable
class MaskMatrixNMS(object):
    """
    Matrix NMS for multi-class masks.
    Args:
        update_threshold (float): Threshold applied to the decayed category
            scores in the second filtering pass.
        pre_nms_top_n (int): Number of total instances to be kept per image before NMS.
        post_nms_top_n (int): Number of total instances to be kept per image after NMS.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): std in gaussian method.
    Input:
        seg_preds (Variable): shape (n, h, w), segmentation feature maps
        seg_masks (Variable): shape (n, h, w), segmentation feature maps
        cate_labels (Variable): shape (n), mask labels in descending order
        cate_scores (Variable): shape (n), mask scores in descending order
        sum_masks (Variable): a float tensor of the sum of seg_masks
    Returns:
        Variable: cate_scores, tensors of shape (n)
    """

    def __init__(self,
                 update_threshold=0.05,
                 pre_nms_top_n=500,
                 post_nms_top_n=100,
                 kernel='gaussian',
                 sigma=2.0):
        super(MaskMatrixNMS, self).__init__()
        self.update_threshold = update_threshold
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.kernel = kernel
        self.sigma = sigma

    def _sort_score(self, scores, top_num):
        if paddle.shape(scores)[0] > top_num:
            return paddle.topk(scores, top_num)[1]
        else:
            return paddle.argsort(scores, descending=True)

    def __call__(self,
                 seg_preds,
                 seg_masks,
                 cate_labels,
                 cate_scores,
                 sum_masks=None):
        # sort and keep top nms_pre
        sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
        seg_masks = paddle.gather(seg_masks, index=sort_inds)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        sum_masks = paddle.gather(sum_masks, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)

        seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
        # inter.
        inter_matrix = paddle.mm(seg_masks,
                                 paddle.transpose(seg_masks, [1, 0]))
        n_samples = paddle.shape(cate_labels)
        # union.
        sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
        # iou.
        iou_matrix = (inter_matrix / (
            sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) -
            inter_matrix))
        iou_matrix = paddle.triu(iou_matrix, diagonal=1)
        # label_specific matrix.
        cate_labels_x = paddle.expand(
            cate_labels, shape=[n_samples, n_samples])
        label_matrix = paddle.cast(
            (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
            'float32')
        label_matrix = paddle.triu(label_matrix, diagonal=1)

        # IoU compensation
        compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
        compensate_iou = paddle.expand(
            compensate_iou, shape=[n_samples, n_samples])
        compensate_iou = paddle.transpose(compensate_iou, [1, 0])

        # IoU decay
        decay_iou = iou_matrix * label_matrix

        # matrix nms
        if self.kernel == 'gaussian':
            decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
            compensate_matrix = paddle.exp(-1 * self.sigma *
                                           (compensate_iou**2))
            decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
                                           axis=0)
        elif self.kernel == 'linear':
            decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
            decay_coefficient = paddle.min(decay_matrix, axis=0)
        else:
            raise NotImplementedError

        # update the score.
        cate_scores = cate_scores * decay_coefficient
        y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
        keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
                            y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty and increase fake data
        keep = paddle.concat(
            [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])

        seg_preds = paddle.gather(seg_preds, index=keep)
        cate_scores = paddle.gather(cate_scores, index=keep)
        cate_labels = paddle.gather(cate_labels, index=keep)

        # sort and keep top_k
        sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)
        return seg_preds, cate_scores, cate_labels
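
# Note on the decay step above (Matrix NMS, introduced in SOLOv2,
# https://arxiv.org/abs/2003.10152): each score is multiplied by
# min over higher-scoring same-class masks i of f(iou_ij) / f(iou_max_i),
# where f is exp(-sigma * iou^2) (gaussian kernel) or 1 - iou (linear).
# This approximates the suppression that sequential NMS would apply, in a
# single matrix pass with no data-dependent loop.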


def Conv2d(in_channels,
           out_channels,
           kernel_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           bias=True,
           weight_init=Normal(std=0.001),
           bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def ConvTranspose2d(in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding=0,
                    output_padding=0,
                    groups=1,
                    bias=True,
                    dilation=1,
                    weight_init=Normal(std=0.001),
                    bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2DTranspose(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
    if not affine:
        weight_attr = False
        bias_attr = False
    else:
        weight_attr = None
        bias_attr = None
    batchnorm = nn.BatchNorm2D(
        num_features,
        momentum,
        eps,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return batchnorm


def ReLU():
    return nn.ReLU()


def Upsample(scale_factor=None, mode='nearest', align_corners=False):
    return nn.Upsample(None, scale_factor, mode, align_corners)


def MaxPool(kernel_size, stride, padding, ceil_mode=False):
    return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)


class Concat(nn.Layer):
    def __init__(self, dim=0):
        super(Concat, self).__init__()
        self.dim = dim

    def forward(self, inputs):
        return paddle.concat(inputs, axis=self.dim)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)


def _convert_attention_mask(attn_mask, dtype):
    """
    Convert the attention mask to the target dtype we expect.
    Parameters:
        attn_mask (Tensor, optional): A tensor used in multi-head attention
            to prevent attention to some unwanted positions, usually the
            paddings or the subsequent positions. It is a tensor with shape
            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
            When the data type is bool, the unwanted positions have `False`
            values and the others have `True` values. When the data type is
            int, the unwanted positions have 0 values and the others have 1
            values. When the data type is float, the unwanted positions have
            `-INF` values and the others have 0 values. It can be None when
            no positions need to be masked. Default None.
        dtype (VarType): The target type of `attn_mask` we expect.
    Returns:
        Tensor: A Tensor with the same shape as the input `attn_mask`, with data type `dtype`.
    """
    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple attention operations in parallel
    to jointly attend to information from different representation subspaces.
    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.
    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout (float, optional): The dropout probability used on attention
            weights to drop some attention targets. 0 for no dropout. Default 0.
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. Default None.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Default False.
    Examples:
        .. code-block:: python
            import paddle
            # encoder input: [batch_size, sequence_length, d_model]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 kdim=None,
                 vdim=None,
                 need_weights=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim:
            self.in_proj_weight = self.create_parameter(
                shape=[embed_dim, 3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=False)
            self.in_proj_bias = self.create_parameter(
                shape=[3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=True)
        else:
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(self.kdim, embed_dim)
            self.v_proj = nn.Linear(self.vdim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._type_list = ('q_proj', 'k_proj', 'v_proj')

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                constant_(p)

    def compute_qkv(self, tensor, index):
        if self._qkv_same_embed_dim:
            tensor = F.linear(
                x=tensor,
                weight=self.in_proj_weight[:, index * self.embed_dim:(
                    index + 1) * self.embed_dim],
                bias=self.in_proj_bias[index * self.embed_dim:(index + 1) *
                                       self.embed_dim]
                if self.in_proj_bias is not None else None)
        else:
            tensor = getattr(self, self._type_list[index])(tensor)
        tensor = tensor.reshape(
            [0, 0, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
        return tensor

    def forward(self, query, key=None, value=None, attn_mask=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.
        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): The keys for multi-head attention. It is
                a tensor with shape `[batch_size, key_length, kdim]`. The
                data type should be float32 or float64. If None, use `query` as
                `key`. Default None.
            value (Tensor, optional): The values for multi-head attention. It
                is a tensor with shape `[batch_size, value_length, vdim]`.
                The data type should be float32 or float64. If None, use `query` as
                `value`. Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                no positions need to be masked. Default None.
        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output. Or a tuple if \
                `need_weights` is True or `cache` is not None. If `need_weights` \
                is True, except for attention output, the tuple also includes \
                the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \
                If `cache` is not None, the tuple then includes the new cache \
                having the same type as `cache`, and if it is `StaticCache`, it \
                is same as the input `cache`, if it is `Cache`, the new cache \
                reserves tensors concatenating raw tensors with intermediate \
                results of the current query.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        q, k, v = (self.compute_qkv(t, i)
                   for i, t in enumerate([query, key, value]))

        # scaled dot-product attention
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        scaling = float(self.head_dim)**-0.5
        product = product * scaling

        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")
        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        return out if len(outs) == 1 else tuple(outs)


@register
class ConvMixer(nn.Layer):
    def __init__(
            self,
            dim,
            depth,
            kernel_size=3, ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.kernel_size = kernel_size

        self.mixer = self.conv_mixer(dim, depth, kernel_size)

    def forward(self, x):
        return self.mixer(x)

    @staticmethod
    def conv_mixer(
            dim,
            depth,
            kernel_size, ):
        Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
        Residual = type('Residual', (Seq, ),
                        {'forward': lambda self, x: self[0](x) + x})
        return Seq(*[
            Seq(Residual(
                ActBn(
                    nn.Conv2D(
                        dim, dim, kernel_size, groups=dim, padding="same"))),
                ActBn(nn.Conv2D(dim, dim, 1))) for i in range(depth)
        ])
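
# Usage sketch (illustrative assumption): ConvMixer keeps spatial resolution
# and channel width fixed, alternating a residual depthwise conv with a
# pointwise conv, each followed by GELU + BatchNorm.
#
#   mixer = ConvMixer(dim=64, depth=4, kernel_size=3)
#   y = mixer(paddle.rand([2, 64, 32, 32]))  # [2, 64, 32, 32]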