layers.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import six
  16. import numpy as np
  17. from numbers import Integral
  18. import paddle
  19. import paddle.nn as nn
  20. from paddle import ParamAttr
  21. from paddle import to_tensor
  22. from paddle.nn import Conv2D, BatchNorm2D, GroupNorm
  23. import paddle.nn.functional as F
  24. from paddle.nn.initializer import Normal, Constant, XavierUniform
  25. from paddle.regularizer import L2Decay
  26. from paddlex.ppdet.core.workspace import register, serializable
  27. from paddlex.ppdet.modeling.bbox_utils import delta2bbox
  28. from . import ops
  29. from paddle.vision.ops import DeformConv2D
  30. def _to_list(l):
  31. if isinstance(l, (list, tuple)):
  32. return list(l)
  33. return [l]
  34. class DeformableConvV2(nn.Layer):
  35. def __init__(self,
  36. in_channels,
  37. out_channels,
  38. kernel_size,
  39. stride=1,
  40. padding=0,
  41. dilation=1,
  42. groups=1,
  43. weight_attr=None,
  44. bias_attr=None,
  45. lr_scale=1,
  46. regularizer=None,
  47. skip_quant=False,
  48. dcn_bias_regularizer=L2Decay(0.),
  49. dcn_bias_lr_scale=2.):
  50. super(DeformableConvV2, self).__init__()
  51. self.offset_channel = 2 * kernel_size**2
  52. self.mask_channel = kernel_size**2
  53. if lr_scale == 1 and regularizer is None:
  54. offset_bias_attr = ParamAttr(initializer=Constant(0.))
  55. else:
  56. offset_bias_attr = ParamAttr(
  57. initializer=Constant(0.),
  58. learning_rate=lr_scale,
  59. regularizer=regularizer)
  60. self.conv_offset = nn.Conv2D(
  61. in_channels,
  62. 3 * kernel_size**2,
  63. kernel_size,
  64. stride=stride,
  65. padding=(kernel_size - 1) // 2,
  66. weight_attr=ParamAttr(initializer=Constant(0.0)),
  67. bias_attr=offset_bias_attr)
  68. if skip_quant:
  69. self.conv_offset.skip_quant = True
  70. if bias_attr:
  71. # in FCOS-DCN head, specifically need learning_rate and regularizer
  72. dcn_bias_attr = ParamAttr(
  73. initializer=Constant(value=0),
  74. regularizer=dcn_bias_regularizer,
  75. learning_rate=dcn_bias_lr_scale)
  76. else:
  77. # in ResNet backbone, do not need bias
  78. dcn_bias_attr = False
  79. self.conv_dcn = DeformConv2D(
  80. in_channels,
  81. out_channels,
  82. kernel_size,
  83. stride=stride,
  84. padding=(kernel_size - 1) // 2 * dilation,
  85. dilation=dilation,
  86. groups=groups,
  87. weight_attr=weight_attr,
  88. bias_attr=dcn_bias_attr)
  89. def forward(self, x):
  90. offset_mask = self.conv_offset(x)
  91. offset, mask = paddle.split(
  92. offset_mask,
  93. num_or_sections=[self.offset_channel, self.mask_channel],
  94. axis=1)
  95. mask = F.sigmoid(mask)
  96. y = self.conv_dcn(x, offset, mask=mask)
  97. return y
  98. class ConvNormLayer(nn.Layer):
  99. def __init__(self,
  100. ch_in,
  101. ch_out,
  102. filter_size,
  103. stride,
  104. groups=1,
  105. norm_type='bn',
  106. norm_decay=0.,
  107. norm_groups=32,
  108. use_dcn=False,
  109. bias_on=False,
  110. lr_scale=1.,
  111. freeze_norm=False,
  112. initializer=Normal(
  113. mean=0., std=0.01),
  114. skip_quant=False,
  115. dcn_lr_scale=2.,
  116. dcn_regularizer=L2Decay(0.)):
  117. super(ConvNormLayer, self).__init__()
  118. assert norm_type in ['bn', 'sync_bn', 'gn']
  119. if bias_on:
  120. bias_attr = ParamAttr(
  121. initializer=Constant(value=0.), learning_rate=lr_scale)
  122. else:
  123. bias_attr = False
  124. if not use_dcn:
  125. self.conv = nn.Conv2D(
  126. in_channels=ch_in,
  127. out_channels=ch_out,
  128. kernel_size=filter_size,
  129. stride=stride,
  130. padding=(filter_size - 1) // 2,
  131. groups=groups,
  132. weight_attr=ParamAttr(
  133. initializer=initializer, learning_rate=1.),
  134. bias_attr=bias_attr)
  135. if skip_quant:
  136. self.conv.skip_quant = True
  137. else:
  138. # in FCOS-DCN head, specifically need learning_rate and regularizer
  139. self.conv = DeformableConvV2(
  140. in_channels=ch_in,
  141. out_channels=ch_out,
  142. kernel_size=filter_size,
  143. stride=stride,
  144. padding=(filter_size - 1) // 2,
  145. groups=groups,
  146. weight_attr=ParamAttr(
  147. initializer=initializer, learning_rate=1.),
  148. bias_attr=True,
  149. lr_scale=dcn_lr_scale,
  150. regularizer=dcn_regularizer,
  151. skip_quant=skip_quant)
  152. norm_lr = 0. if freeze_norm else 1.
  153. param_attr = ParamAttr(
  154. learning_rate=norm_lr,
  155. regularizer=L2Decay(norm_decay)
  156. if norm_decay is not None else None)
  157. bias_attr = ParamAttr(
  158. learning_rate=norm_lr,
  159. regularizer=L2Decay(norm_decay)
  160. if norm_decay is not None else None)
  161. if norm_type == 'bn':
  162. self.norm = nn.BatchNorm2D(
  163. ch_out, weight_attr=param_attr, bias_attr=bias_attr)
  164. elif norm_type == 'sync_bn':
  165. self.norm = nn.SyncBatchNorm(
  166. ch_out, weight_attr=param_attr, bias_attr=bias_attr)
  167. elif norm_type == 'gn':
  168. self.norm = nn.GroupNorm(
  169. num_groups=norm_groups,
  170. num_channels=ch_out,
  171. weight_attr=param_attr,
  172. bias_attr=bias_attr)
  173. def forward(self, inputs):
  174. out = self.conv(inputs)
  175. out = self.norm(out)
  176. return out
  177. class LiteConv(nn.Layer):
  178. def __init__(self,
  179. in_channels,
  180. out_channels,
  181. stride=1,
  182. with_act=True,
  183. norm_type='sync_bn',
  184. name=None):
  185. super(LiteConv, self).__init__()
  186. self.lite_conv = nn.Sequential()
  187. conv1 = ConvNormLayer(
  188. in_channels,
  189. in_channels,
  190. filter_size=5,
  191. stride=stride,
  192. groups=in_channels,
  193. norm_type=norm_type,
  194. initializer=XavierUniform())
  195. conv2 = ConvNormLayer(
  196. in_channels,
  197. out_channels,
  198. filter_size=1,
  199. stride=stride,
  200. norm_type=norm_type,
  201. initializer=XavierUniform())
  202. conv3 = ConvNormLayer(
  203. out_channels,
  204. out_channels,
  205. filter_size=1,
  206. stride=stride,
  207. norm_type=norm_type,
  208. initializer=XavierUniform())
  209. conv4 = ConvNormLayer(
  210. out_channels,
  211. out_channels,
  212. filter_size=5,
  213. stride=stride,
  214. groups=out_channels,
  215. norm_type=norm_type,
  216. initializer=XavierUniform())
  217. conv_list = [conv1, conv2, conv3, conv4]
  218. self.lite_conv.add_sublayer('conv1', conv1)
  219. self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
  220. self.lite_conv.add_sublayer('conv2', conv2)
  221. if with_act:
  222. self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
  223. self.lite_conv.add_sublayer('conv3', conv3)
  224. self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
  225. self.lite_conv.add_sublayer('conv4', conv4)
  226. if with_act:
  227. self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())
  228. def forward(self, inputs):
  229. out = self.lite_conv(inputs)
  230. return out
  231. @register
  232. @serializable
  233. class AnchorGeneratorSSD(object):
  234. def __init__(
  235. self,
  236. steps=[8, 16, 32, 64, 100, 300],
  237. aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
  238. min_ratio=15,
  239. max_ratio=90,
  240. base_size=300,
  241. min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
  242. max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
  243. offset=0.5,
  244. flip=True,
  245. clip=False,
  246. min_max_aspect_ratios_order=False):
  247. self.steps = steps
  248. self.aspect_ratios = aspect_ratios
  249. self.min_ratio = min_ratio
  250. self.max_ratio = max_ratio
  251. self.base_size = base_size
  252. self.min_sizes = min_sizes
  253. self.max_sizes = max_sizes
  254. self.offset = offset
  255. self.flip = flip
  256. self.clip = clip
  257. self.min_max_aspect_ratios_order = min_max_aspect_ratios_order
  258. if self.min_sizes == [] and self.max_sizes == []:
  259. num_layer = len(aspect_ratios)
  260. step = int(
  261. math.floor(((self.max_ratio - self.min_ratio)) / (num_layer - 2
  262. )))
  263. for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
  264. step):
  265. self.min_sizes.append(self.base_size * ratio / 100.)
  266. self.max_sizes.append(self.base_size * (ratio + step) / 100.)
  267. self.min_sizes = [self.base_size * .10] + self.min_sizes
  268. self.max_sizes = [self.base_size * .20] + self.max_sizes
  269. self.num_priors = []
  270. for aspect_ratio, min_size, max_size in zip(
  271. aspect_ratios, self.min_sizes, self.max_sizes):
  272. if isinstance(min_size, (list, tuple)):
  273. self.num_priors.append(
  274. len(_to_list(min_size)) + len(_to_list(max_size)))
  275. else:
  276. self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
  277. _to_list(min_size)) + len(_to_list(max_size)))
  278. def __call__(self, inputs, image):
  279. boxes = []
  280. for input, min_size, max_size, aspect_ratio, step in zip(
  281. inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
  282. self.steps):
  283. box, _ = ops.prior_box(
  284. input=input,
  285. image=image,
  286. min_sizes=_to_list(min_size),
  287. max_sizes=_to_list(max_size),
  288. aspect_ratios=aspect_ratio,
  289. flip=self.flip,
  290. clip=self.clip,
  291. steps=[step, step],
  292. offset=self.offset,
  293. min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
  294. boxes.append(paddle.reshape(box, [-1, 4]))
  295. return boxes
  296. @register
  297. @serializable
  298. class RCNNBox(object):
  299. __shared__ = ['num_classes']
  300. def __init__(self,
  301. prior_box_var=[10., 10., 5., 5.],
  302. code_type="decode_center_size",
  303. box_normalized=False,
  304. num_classes=80):
  305. super(RCNNBox, self).__init__()
  306. self.prior_box_var = prior_box_var
  307. self.code_type = code_type
  308. self.box_normalized = box_normalized
  309. self.num_classes = num_classes
  310. def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
  311. bbox_pred = bbox_head_out[0]
  312. cls_prob = bbox_head_out[1]
  313. roi = rois[0]
  314. rois_num = rois[1]
  315. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  316. scale_list = []
  317. origin_shape_list = []
  318. for idx, roi_per_im in enumerate(roi):
  319. rois_num_per_im = rois_num[idx]
  320. expand_im_shape = paddle.expand(im_shape[idx, :],
  321. [rois_num_per_im, 2])
  322. origin_shape_list.append(expand_im_shape)
  323. origin_shape = paddle.concat(origin_shape_list)
  324. # bbox_pred.shape: [N, C*4]
  325. # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
  326. bbox = paddle.concat(roi)
  327. if bbox.shape[0] == 0:
  328. bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
  329. else:
  330. bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
  331. scores = cls_prob[:, :-1]
  332. # bbox.shape: [N, C, 4]
  333. # bbox.shape[1] must be equal to scores.shape[1]
  334. bbox_num_class = bbox.shape[1]
  335. if bbox_num_class == 1:
  336. bbox = paddle.tile(bbox, [1, self.num_classes, 1])
  337. origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
  338. origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
  339. zeros = paddle.zeros_like(origin_h)
  340. x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
  341. y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
  342. x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
  343. y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
  344. bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
  345. bboxes = (bbox, rois_num)
  346. return bboxes, scores
  347. @register
  348. @serializable
  349. class MultiClassNMS(object):
  350. def __init__(self,
  351. score_threshold=.05,
  352. nms_top_k=-1,
  353. keep_top_k=100,
  354. nms_threshold=.5,
  355. normalized=True,
  356. nms_eta=1.0,
  357. return_index=False,
  358. return_rois_num=True):
  359. super(MultiClassNMS, self).__init__()
  360. self.score_threshold = score_threshold
  361. self.nms_top_k = nms_top_k
  362. self.keep_top_k = keep_top_k
  363. self.nms_threshold = nms_threshold
  364. self.normalized = normalized
  365. self.nms_eta = nms_eta
  366. self.return_index = return_index
  367. self.return_rois_num = return_rois_num
  368. def __call__(self, bboxes, score, background_label=-1):
  369. """
  370. bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
  371. [N, M, 4], N is the batch size and M
  372. is the number of bboxes
  373. 2. (List[Tensor]) bboxes and bbox_num,
  374. bboxes have shape of [M, C, 4], C
  375. is the class number and bbox_num means
  376. the number of bboxes of each batch with
  377. shape [N,]
  378. score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
  379. background_label (int): Ignore the background label; For example, RCNN
  380. is num_classes and YOLO is -1.
  381. """
  382. kwargs = self.__dict__.copy()
  383. if isinstance(bboxes, tuple):
  384. bboxes, bbox_num = bboxes
  385. kwargs.update({'rois_num': bbox_num})
  386. if background_label > -1:
  387. kwargs.update({'background_label': background_label})
  388. return ops.multiclass_nms(bboxes, score, **kwargs)
  389. @register
  390. @serializable
  391. class MatrixNMS(object):
  392. __append_doc__ = True
  393. def __init__(self,
  394. score_threshold=.05,
  395. post_threshold=.05,
  396. nms_top_k=-1,
  397. keep_top_k=100,
  398. use_gaussian=False,
  399. gaussian_sigma=2.,
  400. normalized=False,
  401. background_label=0):
  402. super(MatrixNMS, self).__init__()
  403. self.score_threshold = score_threshold
  404. self.post_threshold = post_threshold
  405. self.nms_top_k = nms_top_k
  406. self.keep_top_k = keep_top_k
  407. self.normalized = normalized
  408. self.use_gaussian = use_gaussian
  409. self.gaussian_sigma = gaussian_sigma
  410. self.background_label = background_label
  411. def __call__(self, bbox, score, *args):
  412. return ops.matrix_nms(
  413. bboxes=bbox,
  414. scores=score,
  415. score_threshold=self.score_threshold,
  416. post_threshold=self.post_threshold,
  417. nms_top_k=self.nms_top_k,
  418. keep_top_k=self.keep_top_k,
  419. use_gaussian=self.use_gaussian,
  420. gaussian_sigma=self.gaussian_sigma,
  421. background_label=self.background_label,
  422. normalized=self.normalized)
  423. @register
  424. @serializable
  425. class YOLOBox(object):
  426. __shared__ = ['num_classes']
  427. def __init__(self,
  428. num_classes=80,
  429. conf_thresh=0.005,
  430. downsample_ratio=32,
  431. clip_bbox=True,
  432. scale_x_y=1.):
  433. self.num_classes = num_classes
  434. self.conf_thresh = conf_thresh
  435. self.downsample_ratio = downsample_ratio
  436. self.clip_bbox = clip_bbox
  437. self.scale_x_y = scale_x_y
  438. def __call__(self,
  439. yolo_head_out,
  440. anchors,
  441. im_shape,
  442. scale_factor,
  443. var_weight=None):
  444. boxes_list = []
  445. scores_list = []
  446. origin_shape = im_shape / scale_factor
  447. origin_shape = paddle.cast(origin_shape, 'int32')
  448. for i, head_out in enumerate(yolo_head_out):
  449. boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
  450. self.num_classes, self.conf_thresh,
  451. self.downsample_ratio // 2**i,
  452. self.clip_bbox, self.scale_x_y)
  453. boxes_list.append(boxes)
  454. scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
  455. yolo_boxes = paddle.concat(boxes_list, axis=1)
  456. yolo_scores = paddle.concat(scores_list, axis=2)
  457. return yolo_boxes, yolo_scores
  458. @register
  459. @serializable
  460. class SSDBox(object):
  461. def __init__(self, is_normalized=True):
  462. self.is_normalized = is_normalized
  463. self.norm_delta = float(not self.is_normalized)
  464. def __call__(self,
  465. preds,
  466. prior_boxes,
  467. im_shape,
  468. scale_factor,
  469. var_weight=None):
  470. boxes, scores = preds
  471. outputs = []
  472. for box, score, prior_box in zip(boxes, scores, prior_boxes):
  473. pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
  474. pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta
  475. pb_x = prior_box[:, 0] + pb_w * 0.5
  476. pb_y = prior_box[:, 1] + pb_h * 0.5
  477. out_x = pb_x + box[:, :, 0] * pb_w * 0.1
  478. out_y = pb_y + box[:, :, 1] * pb_h * 0.1
  479. out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w
  480. out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h
  481. if self.is_normalized:
  482. h = paddle.unsqueeze(
  483. im_shape[:, 0] / scale_factor[:, 0], axis=-1)
  484. w = paddle.unsqueeze(
  485. im_shape[:, 1] / scale_factor[:, 1], axis=-1)
  486. output = paddle.stack(
  487. [(out_x - out_w / 2.) * w, (out_y - out_h / 2.) * h,
  488. (out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h],
  489. axis=-1)
  490. else:
  491. output = paddle.stack(
  492. [
  493. out_x - out_w / 2., out_y - out_h / 2.,
  494. out_x + out_w / 2. - 1., out_y + out_h / 2. - 1.
  495. ],
  496. axis=-1)
  497. outputs.append(output)
  498. boxes = paddle.concat(outputs, axis=1)
  499. scores = F.softmax(paddle.concat(scores, axis=1))
  500. scores = paddle.transpose(scores, [0, 2, 1])
  501. return boxes, scores
  502. @register
  503. @serializable
  504. class AnchorGrid(object):
  505. """Generate anchor grid
  506. Args:
  507. image_size (int or list): input image size, may be a single integer or
  508. list of [h, w]. Default: 512
  509. min_level (int): min level of the feature pyramid. Default: 3
  510. max_level (int): max level of the feature pyramid. Default: 7
  511. anchor_base_scale: base anchor scale. Default: 4
  512. num_scales: number of anchor scales. Default: 3
  513. aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
  514. """
  515. def __init__(self,
  516. image_size=512,
  517. min_level=3,
  518. max_level=7,
  519. anchor_base_scale=4,
  520. num_scales=3,
  521. aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
  522. super(AnchorGrid, self).__init__()
  523. if isinstance(image_size, Integral):
  524. self.image_size = [image_size, image_size]
  525. else:
  526. self.image_size = image_size
  527. for dim in self.image_size:
  528. assert dim % 2 ** max_level == 0, \
  529. "image size should be multiple of the max level stride"
  530. self.min_level = min_level
  531. self.max_level = max_level
  532. self.anchor_base_scale = anchor_base_scale
  533. self.num_scales = num_scales
  534. self.aspect_ratios = aspect_ratios
  535. @property
  536. def base_cell(self):
  537. if not hasattr(self, '_base_cell'):
  538. self._base_cell = self.make_cell()
  539. return self._base_cell
  540. def make_cell(self):
  541. scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
  542. scales = np.array(scales)
  543. ratios = np.array(self.aspect_ratios)
  544. ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
  545. hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
  546. anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
  547. return anchors
  548. def make_grid(self, stride):
  549. cell = self.base_cell * stride * self.anchor_base_scale
  550. x_steps = np.arange(stride // 2, self.image_size[1], stride)
  551. y_steps = np.arange(stride // 2, self.image_size[0], stride)
  552. offset_x, offset_y = np.meshgrid(x_steps, y_steps)
  553. offset_x = offset_x.flatten()
  554. offset_y = offset_y.flatten()
  555. offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
  556. offsets = offsets[:, np.newaxis, :]
  557. return (cell + offsets).reshape(-1, 4)
  558. def generate(self):
  559. return [
  560. self.make_grid(2**l)
  561. for l in range(self.min_level, self.max_level + 1)
  562. ]
  563. def __call__(self):
  564. if not hasattr(self, '_anchor_vars'):
  565. anchor_vars = []
  566. helper = LayerHelper('anchor_grid')
  567. for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
  568. stride = 2**l
  569. anchors = self.make_grid(stride)
  570. var = helper.create_parameter(
  571. attr=ParamAttr(name='anchors_{}'.format(idx)),
  572. shape=anchors.shape,
  573. dtype='float32',
  574. stop_gradient=True,
  575. default_initializer=NumpyArrayInitializer(anchors))
  576. anchor_vars.append(var)
  577. var.persistable = True
  578. self._anchor_vars = anchor_vars
  579. return self._anchor_vars
  580. @register
  581. @serializable
  582. class FCOSBox(object):
  583. __shared__ = ['num_classes']
  584. def __init__(self, num_classes=80):
  585. super(FCOSBox, self).__init__()
  586. self.num_classes = num_classes
  587. def _merge_hw(self, inputs, ch_type="channel_first"):
  588. """
  589. Merge h and w of the feature map into one dimension.
  590. Args:
  591. inputs (Tensor): Tensor of the input feature map
  592. ch_type (str): "channel_first" or "channel_last" style
  593. Return:
  594. new_shape (Tensor): The new shape after h and w merged
  595. """
  596. shape_ = paddle.shape(inputs)
  597. bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
  598. img_size = hi * wi
  599. img_size.stop_gradient = True
  600. if ch_type == "channel_first":
  601. new_shape = paddle.concat([bs, ch, img_size])
  602. elif ch_type == "channel_last":
  603. new_shape = paddle.concat([bs, img_size, ch])
  604. else:
  605. raise KeyError("Wrong ch_type %s" % ch_type)
  606. new_shape.stop_gradient = True
  607. return new_shape
  608. def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
  609. scale_factor):
  610. """
  611. Postprocess each layer of the output with corresponding locations.
  612. Args:
  613. locations (Tensor): anchor points for current layer, [H*W, 2]
  614. box_cls (Tensor): categories prediction, [N, C, H, W],
  615. C is the number of classes
  616. box_reg (Tensor): bounding box prediction, [N, 4, H, W]
  617. box_ctn (Tensor): centerness prediction, [N, 1, H, W]
  618. scale_factor (Tensor): [h_scale, w_scale] for input images
  619. Return:
  620. box_cls_ch_last (Tensor): score for each category, in [N, C, M]
  621. C is the number of classes and M is the number of anchor points
  622. box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4]
  623. last dimension is [x1, y1, x2, y2]
  624. """
  625. act_shape_cls = self._merge_hw(box_cls)
  626. box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
  627. box_cls_ch_last = F.sigmoid(box_cls_ch_last)
  628. act_shape_reg = self._merge_hw(box_reg)
  629. box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
  630. box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
  631. box_reg_decoding = paddle.stack(
  632. [
  633. locations[:, 0] - box_reg_ch_last[:, :, 0],
  634. locations[:, 1] - box_reg_ch_last[:, :, 1],
  635. locations[:, 0] + box_reg_ch_last[:, :, 2],
  636. locations[:, 1] + box_reg_ch_last[:, :, 3]
  637. ],
  638. axis=1)
  639. box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])
  640. act_shape_ctn = self._merge_hw(box_ctn)
  641. box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
  642. box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)
  643. # recover the location to original image
  644. im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
  645. box_reg_decoding = box_reg_decoding / im_scale
  646. box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
  647. return box_cls_ch_last, box_reg_decoding
  648. def __call__(self, locations, cls_logits, bboxes_reg, centerness,
  649. scale_factor):
  650. pred_boxes_ = []
  651. pred_scores_ = []
  652. for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
  653. centerness):
  654. pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
  655. pts, cls, box, ctn, scale_factor)
  656. pred_boxes_.append(pred_boxes_lvl)
  657. pred_scores_.append(pred_scores_lvl)
  658. pred_boxes = paddle.concat(pred_boxes_, axis=1)
  659. pred_scores = paddle.concat(pred_scores_, axis=2)
  660. return pred_boxes, pred_scores
  661. @register
  662. class TTFBox(object):
  663. __shared__ = ['down_ratio']
  664. def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
  665. super(TTFBox, self).__init__()
  666. self.max_per_img = max_per_img
  667. self.score_thresh = score_thresh
  668. self.down_ratio = down_ratio
  669. def _simple_nms(self, heat, kernel=3):
  670. """
  671. Use maxpool to filter the max score, get local peaks.
  672. """
  673. pad = (kernel - 1) // 2
  674. hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
  675. keep = paddle.cast(hmax == heat, 'float32')
  676. return heat * keep
  677. def _topk(self, scores):
  678. """
  679. Select top k scores and decode to get xy coordinates.
  680. """
  681. k = self.max_per_img
  682. shape_fm = paddle.shape(scores)
  683. shape_fm.stop_gradient = True
  684. cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
  685. # batch size is 1
  686. scores_r = paddle.reshape(scores, [cat, -1])
  687. topk_scores, topk_inds = paddle.topk(scores_r, k)
  688. topk_scores, topk_inds = paddle.topk(scores_r, k)
  689. topk_ys = topk_inds // width
  690. topk_xs = topk_inds % width
  691. topk_score_r = paddle.reshape(topk_scores, [-1])
  692. topk_score, topk_ind = paddle.topk(topk_score_r, k)
  693. k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
  694. topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')
  695. topk_inds = paddle.reshape(topk_inds, [-1])
  696. topk_ys = paddle.reshape(topk_ys, [-1, 1])
  697. topk_xs = paddle.reshape(topk_xs, [-1, 1])
  698. topk_inds = paddle.gather(topk_inds, topk_ind)
  699. topk_ys = paddle.gather(topk_ys, topk_ind)
  700. topk_xs = paddle.gather(topk_xs, topk_ind)
  701. return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
  702. def __call__(self, hm, wh, im_shape, scale_factor):
  703. heatmap = F.sigmoid(hm)
  704. heat = self._simple_nms(heatmap)
  705. scores, inds, clses, ys, xs = self._topk(heat)
  706. ys = paddle.cast(ys, 'float32') * self.down_ratio
  707. xs = paddle.cast(xs, 'float32') * self.down_ratio
  708. scores = paddle.tensor.unsqueeze(scores, [1])
  709. clses = paddle.tensor.unsqueeze(clses, [1])
  710. wh_t = paddle.transpose(wh, [0, 2, 3, 1])
  711. wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
  712. wh = paddle.gather(wh, inds)
  713. x1 = xs - wh[:, 0:1]
  714. y1 = ys - wh[:, 1:2]
  715. x2 = xs + wh[:, 2:3]
  716. y2 = ys + wh[:, 3:4]
  717. bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
  718. scale_y = scale_factor[:, 0:1]
  719. scale_x = scale_factor[:, 1:2]
  720. scale_expand = paddle.concat(
  721. [scale_x, scale_y, scale_x, scale_y], axis=1)
  722. boxes_shape = paddle.shape(bboxes)
  723. boxes_shape.stop_gradient = True
  724. scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
  725. bboxes = paddle.divide(bboxes, scale_expand)
  726. results = paddle.concat([clses, scores, bboxes], axis=1)
  727. # hack: append result with cls=-1 and score=1. to avoid all scores
  728. # are less than score_thresh which may cause error in gather.
  729. fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
  730. fill_r = paddle.cast(fill_r, results.dtype)
  731. results = paddle.concat([results, fill_r])
  732. scores = results[:, 1]
  733. valid_ind = paddle.nonzero(scores > self.score_thresh)
  734. results = paddle.gather(results, valid_ind)
  735. return results, paddle.shape(results)[0:1]
  736. @register
  737. @serializable
  738. class JDEBox(object):
  739. __shared__ = ['num_classes']
  740. def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
  741. self.num_classes = num_classes
  742. self.conf_thresh = conf_thresh
  743. self.downsample_ratio = downsample_ratio
  744. def generate_anchor(self, nGh, nGw, anchor_wh):
  745. nA = len(anchor_wh)
  746. yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
  747. mesh = paddle.stack(
  748. (xv, yv), axis=0).cast(dtype='float32') # 2 x nGh x nGw
  749. meshs = paddle.tile(mesh, [nA, 1, 1, 1])
  750. anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
  751. int(nGh), axis=-2).repeat(
  752. int(nGw), axis=-1)
  753. anchor_offset_mesh = paddle.to_tensor(
  754. anchor_offset_mesh.astype(np.float32))
  755. # nA x 2 x nGh x nGw
  756. anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
  757. anchor_mesh = paddle.transpose(anchor_mesh,
  758. [0, 2, 3, 1]) # (nA x nGh x nGw) x 4
  759. return anchor_mesh
  760. def decode_delta(self, delta, fg_anchor_list):
  761. px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:,1], \
  762. fg_anchor_list[:, 2], fg_anchor_list[:,3]
  763. dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
  764. gx = pw * dx + px
  765. gy = ph * dy + py
  766. gw = pw * paddle.exp(dw)
  767. gh = ph * paddle.exp(dh)
  768. gx1 = gx - gw * 0.5
  769. gy1 = gy - gh * 0.5
  770. gx2 = gx + gw * 0.5
  771. gy2 = gy + gh * 0.5
  772. return paddle.stack([gx1, gy1, gx2, gy2], axis=1)
  773. def decode_delta_map(self, delta_map, anchors):
  774. delta_map_shape = paddle.shape(delta_map)
  775. delta_map_shape.stop_gradient = True
  776. nB, nA, nGh, nGw, _ = delta_map_shape[:]
  777. anchor_mesh = self.generate_anchor(nGh, nGw, anchors)
  778. # only support bs=1
  779. anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
  780. pred_list = self.decode_delta(
  781. paddle.reshape(
  782. delta_map, shape=[-1, 4]),
  783. paddle.reshape(
  784. anchor_mesh, shape=[-1, 4]))
  785. pred_map = paddle.reshape(pred_list, shape=[nB, -1, 4])
  786. return pred_map
  787. def __call__(self, yolo_head_out, anchors):
  788. bbox_pred_list = []
  789. for i, head_out in enumerate(yolo_head_out):
  790. stride = self.downsample_ratio // 2**i
  791. anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
  792. anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
  793. nA = len(anc_w)
  794. boxes_shape = paddle.shape(head_out)
  795. boxes_shape.stop_gradient = True
  796. nB, nGh, nGw = boxes_shape[0], boxes_shape[-2], boxes_shape[-1]
  797. p = head_out.reshape((nB, nA, self.num_classes + 5, nGh, nGw))
  798. p = paddle.transpose(p, perm=[0, 1, 3, 4, 2]) # [nB, 4, 19, 34, 6]
  799. p_box = p[:, :, :, :, :4] # [nB, 4, 19, 34, 4]
  800. boxes = self.decode_delta_map(p_box,
  801. anchor_vec) # [nB, 4*19*34, 4]
  802. boxes = boxes * stride
  803. p_conf = paddle.transpose(
  804. p[:, :, :, :, 4:6], perm=[0, 4, 1, 2, 3]) # [nB, 2, 4, 19, 34]
  805. p_conf = F.softmax(
  806. p_conf,
  807. axis=1)[:, 1, :, :, :].unsqueeze(-1) # [nB, 4, 19, 34, 1]
  808. scores = paddle.reshape(p_conf, shape=[nB, -1, 1])
  809. bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))
  810. yolo_boxes_pred = paddle.concat(bbox_pred_list, axis=1)
  811. boxes_idx = paddle.nonzero(
  812. yolo_boxes_pred[:, :, -1] > self.conf_thresh)
  813. boxes_idx.stop_gradient = True
  814. if boxes_idx.shape[0] == 0: # TODO: deploy
  815. boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))
  816. yolo_boxes_out = paddle.to_tensor(
  817. np.array(
  818. [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
  819. yolo_scores_out = paddle.to_tensor(
  820. np.array(
  821. [[[0.0]]], dtype='float32'))
  822. return boxes_idx, yolo_boxes_out, yolo_scores_out
  823. yolo_boxes = paddle.gather_nd(yolo_boxes_pred, boxes_idx)
  824. yolo_boxes_out = paddle.reshape(yolo_boxes[:, :4], shape=[nB, -1, 4])
  825. yolo_scores_out = paddle.reshape(yolo_boxes[:, 4:5], shape=[nB, 1, -1])
  826. boxes_idx = boxes_idx[:, 1:]
  827. return boxes_idx, yolo_boxes_out, yolo_scores_out # [163], [1, 163, 4], [1, 1, 163]
  828. @register
  829. @serializable
  830. class MaskMatrixNMS(object):
  831. """
  832. Matrix NMS for multi-class masks.
  833. Args:
  834. update_threshold (float): Updated threshold of categroy score in second time.
  835. pre_nms_top_n (int): Number of total instance to be kept per image before NMS
  836. post_nms_top_n (int): Number of total instance to be kept per image after NMS.
  837. kernel (str): 'linear' or 'gaussian'.
  838. sigma (float): std in gaussian method.
  839. Input:
  840. seg_preds (Variable): shape (n, h, w), segmentation feature maps
  841. seg_masks (Variable): shape (n, h, w), segmentation feature maps
  842. cate_labels (Variable): shape (n), mask labels in descending order
  843. cate_scores (Variable): shape (n), mask scores in descending order
  844. sum_masks (Variable): a float tensor of the sum of seg_masks
  845. Returns:
  846. Variable: cate_scores, tensors of shape (n)
  847. """
  848. def __init__(self,
  849. update_threshold=0.05,
  850. pre_nms_top_n=500,
  851. post_nms_top_n=100,
  852. kernel='gaussian',
  853. sigma=2.0):
  854. super(MaskMatrixNMS, self).__init__()
  855. self.update_threshold = update_threshold
  856. self.pre_nms_top_n = pre_nms_top_n
  857. self.post_nms_top_n = post_nms_top_n
  858. self.kernel = kernel
  859. self.sigma = sigma
  860. def _sort_score(self, scores, top_num):
  861. if paddle.shape(scores)[0] > top_num:
  862. return paddle.topk(scores, top_num)[1]
  863. else:
  864. return paddle.argsort(scores, descending=True)
  865. def __call__(self,
  866. seg_preds,
  867. seg_masks,
  868. cate_labels,
  869. cate_scores,
  870. sum_masks=None):
  871. # sort and keep top nms_pre
  872. sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
  873. seg_masks = paddle.gather(seg_masks, index=sort_inds)
  874. seg_preds = paddle.gather(seg_preds, index=sort_inds)
  875. sum_masks = paddle.gather(sum_masks, index=sort_inds)
  876. cate_scores = paddle.gather(cate_scores, index=sort_inds)
  877. cate_labels = paddle.gather(cate_labels, index=sort_inds)
  878. seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
  879. # inter.
  880. inter_matrix = paddle.mm(seg_masks,
  881. paddle.transpose(seg_masks, [1, 0]))
  882. n_samples = paddle.shape(cate_labels)
  883. # union.
  884. sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
  885. # iou.
  886. iou_matrix = (inter_matrix / (
  887. sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix)
  888. )
  889. iou_matrix = paddle.triu(iou_matrix, diagonal=1)
  890. # label_specific matrix.
  891. cate_labels_x = paddle.expand(
  892. cate_labels, shape=[n_samples, n_samples])
  893. label_matrix = paddle.cast(
  894. (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
  895. 'float32')
  896. label_matrix = paddle.triu(label_matrix, diagonal=1)
  897. # IoU compensation
  898. compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
  899. compensate_iou = paddle.expand(
  900. compensate_iou, shape=[n_samples, n_samples])
  901. compensate_iou = paddle.transpose(compensate_iou, [1, 0])
  902. # IoU decay
  903. decay_iou = iou_matrix * label_matrix
  904. # matrix nms
  905. if self.kernel == 'gaussian':
  906. decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
  907. compensate_matrix = paddle.exp(-1 * self.sigma *
  908. (compensate_iou**2))
  909. decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
  910. axis=0)
  911. elif self.kernel == 'linear':
  912. decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
  913. decay_coefficient = paddle.min(decay_matrix, axis=0)
  914. else:
  915. raise NotImplementedError
  916. # update the score.
  917. cate_scores = cate_scores * decay_coefficient
  918. y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
  919. keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
  920. y)
  921. keep = paddle.nonzero(keep)
  922. keep = paddle.squeeze(keep, axis=[1])
  923. # Prevent empty and increase fake data
  924. keep = paddle.concat(
  925. [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])
  926. seg_preds = paddle.gather(seg_preds, index=keep)
  927. cate_scores = paddle.gather(cate_scores, index=keep)
  928. cate_labels = paddle.gather(cate_labels, index=keep)
  929. # sort and keep top_k
  930. sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
  931. seg_preds = paddle.gather(seg_preds, index=sort_inds)
  932. cate_scores = paddle.gather(cate_scores, index=sort_inds)
  933. cate_labels = paddle.gather(cate_labels, index=sort_inds)
  934. return seg_preds, cate_scores, cate_labels
  935. def Conv2d(in_channels,
  936. out_channels,
  937. kernel_size,
  938. stride=1,
  939. padding=0,
  940. dilation=1,
  941. groups=1,
  942. bias=True,
  943. weight_init=Normal(std=0.001),
  944. bias_init=Constant(0.)):
  945. weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
  946. if bias:
  947. bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
  948. else:
  949. bias_attr = False
  950. conv = nn.Conv2D(
  951. in_channels,
  952. out_channels,
  953. kernel_size,
  954. stride,
  955. padding,
  956. dilation,
  957. groups,
  958. weight_attr=weight_attr,
  959. bias_attr=bias_attr)
  960. return conv
  961. def ConvTranspose2d(in_channels,
  962. out_channels,
  963. kernel_size,
  964. stride=1,
  965. padding=0,
  966. output_padding=0,
  967. groups=1,
  968. bias=True,
  969. dilation=1,
  970. weight_init=Normal(std=0.001),
  971. bias_init=Constant(0.)):
  972. weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
  973. if bias:
  974. bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
  975. else:
  976. bias_attr = False
  977. conv = nn.Conv2DTranspose(
  978. in_channels,
  979. out_channels,
  980. kernel_size,
  981. stride,
  982. padding,
  983. output_padding,
  984. dilation,
  985. groups,
  986. weight_attr=weight_attr,
  987. bias_attr=bias_attr)
  988. return conv
  989. def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
  990. if not affine:
  991. weight_attr = False
  992. bias_attr = False
  993. else:
  994. weight_attr = None
  995. bias_attr = None
  996. batchnorm = nn.BatchNorm2D(
  997. num_features,
  998. momentum,
  999. eps,
  1000. weight_attr=weight_attr,
  1001. bias_attr=bias_attr)
  1002. return batchnorm
  1003. def ReLU():
  1004. return nn.ReLU()
  1005. def Upsample(scale_factor=None, mode='nearest', align_corners=False):
  1006. return nn.Upsample(None, scale_factor, mode, align_corners)
  1007. def MaxPool(kernel_size, stride, padding, ceil_mode=False):
  1008. return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)
  1009. class Concat(nn.Layer):
  1010. def __init__(self, dim=0):
  1011. super(Concat, self).__init__()
  1012. self.dim = dim
  1013. def forward(self, inputs):
  1014. return paddle.concat(inputs, axis=self.dim)
  1015. def extra_repr(self):
  1016. return 'dim={}'.format(self.dim)