detector.py 93 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import numpy as np
  20. import paddle
  21. from paddle.static import InputSpec
  22. import paddlex.ppdet as ppdet
  23. from paddlex.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  24. import paddlex
  25. import paddlex.utils.logging as logging
  26. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  27. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
  28. _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from paddlex.ppdet.optimizer import ModelEMA
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN",
  36. "PicoDet"
  37. ]
  38. class BaseDetector(BaseModel):
  39. def __init__(self, model_name, num_classes=80, **params):
  40. self.init_params.update(locals())
  41. if 'with_net' in self.init_params:
  42. del self.init_params['with_net']
  43. super(BaseDetector, self).__init__('detector')
  44. if not hasattr(ppdet.modeling, model_name):
  45. raise Exception("ERROR: There's no model named {}.".format(
  46. model_name))
  47. self.model_name = model_name
  48. self.num_classes = num_classes
  49. self.labels = None
  50. if params.get('with_net', True):
  51. params.pop('with_net', None)
  52. self.net = self.build_net(**params)
  53. def build_net(self, **params):
  54. with paddle.utils.unique_name.guard():
  55. net = ppdet.modeling.__dict__[self.model_name](**params)
  56. return net
  57. def _fix_transforms_shape(self, image_shape):
  58. raise NotImplementedError("_fix_transforms_shape: not implemented!")
  59. def _define_input_spec(self, image_shape):
  60. input_spec = [{
  61. "image": InputSpec(
  62. shape=image_shape, name='image', dtype='float32'),
  63. "im_shape": InputSpec(
  64. shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
  65. "scale_factor": InputSpec(
  66. shape=[image_shape[0], 2],
  67. name='scale_factor',
  68. dtype='float32')
  69. }]
  70. return input_spec
  71. def _check_image_shape(self, image_shape):
  72. if len(image_shape) == 2:
  73. image_shape = [1, 3] + image_shape
  74. if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
  75. raise Exception(
  76. "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
  77. format(image_shape[-2:]))
  78. return image_shape
  79. def _get_test_inputs(self, image_shape):
  80. if image_shape is not None:
  81. image_shape = self._check_image_shape(image_shape)
  82. self._fix_transforms_shape(image_shape[-2:])
  83. else:
  84. image_shape = [None, 3, -1, -1]
  85. self.fixed_input_shape = image_shape
  86. return self._define_input_spec(image_shape)
  87. def _get_backbone(self, backbone_name, **params):
  88. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  89. return backbone
  90. def run(self, net, inputs, mode):
  91. net_out = net(inputs)
  92. if mode in ['train', 'eval']:
  93. outputs = net_out
  94. else:
  95. outputs = dict()
  96. for key in net_out:
  97. outputs[key] = net_out[key].numpy()
  98. return outputs
  99. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  100. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  101. num_steps_each_epoch):
  102. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  103. values = [(lr_decay_gamma**i) * learning_rate
  104. for i in range(len(lr_decay_epochs) + 1)]
  105. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  106. boundaries=boundaries, values=values)
  107. if warmup_steps > 0:
  108. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  109. logging.error(
  110. "In function train(), parameters should satisfy: "
  111. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  112. exit=False)
  113. logging.error(
  114. "See this doc for more information: "
  115. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
  116. exit=False)
  117. scheduler = paddle.optimizer.lr.LinearWarmup(
  118. learning_rate=scheduler,
  119. warmup_steps=warmup_steps,
  120. start_lr=warmup_start_lr,
  121. end_lr=learning_rate)
  122. optimizer = paddle.optimizer.Momentum(
  123. scheduler,
  124. momentum=.9,
  125. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  126. parameters=parameters)
  127. return optimizer
  128. def train(self,
  129. num_epochs,
  130. train_dataset,
  131. train_batch_size=64,
  132. eval_dataset=None,
  133. optimizer=None,
  134. save_interval_epochs=1,
  135. log_interval_steps=10,
  136. save_dir='output',
  137. pretrain_weights='IMAGENET',
  138. learning_rate=.001,
  139. warmup_steps=0,
  140. warmup_start_lr=0.0,
  141. lr_decay_epochs=(216, 243),
  142. lr_decay_gamma=0.1,
  143. metric=None,
  144. use_ema=False,
  145. early_stop=False,
  146. early_stop_patience=5,
  147. use_vdl=True,
  148. resume_checkpoint=None):
  149. """
  150. Train the model.
  151. Args:
  152. num_epochs(int): The number of epochs.
  153. train_dataset(paddlex.dataset): Training dataset.
  154. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  155. eval_dataset(paddlex.dataset, optional):
  156. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  157. optimizer(paddle.optimizer.Optimizer or None, optional):
  158. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  159. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  160. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  161. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  162. pretrain_weights(str or None, optional):
  163. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  164. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  165. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  166. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  167. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  168. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  169. metric({'VOC', 'COCO', None}, optional):
  170. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  171. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  172. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  173. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  174. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  175. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  176. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  177. `pretrain_weights` can be set simultaneously. Defaults to None.
  178. """
  179. if self.status == 'Infer':
  180. logging.error(
  181. "Exported inference model does not support training.",
  182. exit=True)
  183. if pretrain_weights is not None and resume_checkpoint is not None:
  184. logging.error(
  185. "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
  186. exit=True)
  187. if train_dataset.__class__.__name__ == 'VOCDetection':
  188. train_dataset.data_fields = {
  189. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  190. 'difficult'
  191. }
  192. elif train_dataset.__class__.__name__ == 'CocoDetection':
  193. if self.__class__.__name__ == 'MaskRCNN':
  194. train_dataset.data_fields = {
  195. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  196. 'gt_poly', 'is_crowd'
  197. }
  198. else:
  199. train_dataset.data_fields = {
  200. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  201. 'is_crowd'
  202. }
  203. if metric is None:
  204. if eval_dataset.__class__.__name__ == 'VOCDetection':
  205. self.metric = 'voc'
  206. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  207. self.metric = 'coco'
  208. else:
  209. assert metric.lower() in ['coco', 'voc'], \
  210. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  211. self.metric = metric.lower()
  212. self.labels = train_dataset.labels
  213. self.num_max_boxes = train_dataset.num_max_boxes
  214. train_dataset.batch_transforms = self._compose_batch_transform(
  215. train_dataset.transforms, mode='train')
  216. # build optimizer if not defined
  217. if optimizer is None:
  218. num_steps_each_epoch = len(train_dataset) // train_batch_size
  219. self.optimizer = self.default_optimizer(
  220. parameters=self.net.parameters(),
  221. learning_rate=learning_rate,
  222. warmup_steps=warmup_steps,
  223. warmup_start_lr=warmup_start_lr,
  224. lr_decay_epochs=lr_decay_epochs,
  225. lr_decay_gamma=lr_decay_gamma,
  226. num_steps_each_epoch=num_steps_each_epoch)
  227. else:
  228. self.optimizer = optimizer
  229. # initiate weights
  230. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  231. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  232. [self.model_name, self.backbone_name])]:
  233. logging.warning(
  234. "Path of pretrain_weights('{}') does not exist!".format(
  235. pretrain_weights))
  236. pretrain_weights = det_pretrain_weights_dict['_'.join(
  237. [self.model_name, self.backbone_name])][0]
  238. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  239. "If you don't want to use pretrain weights, "
  240. "set pretrain_weights to be None.".format(
  241. pretrain_weights))
  242. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  243. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  244. logging.error(
  245. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  246. exit=True)
  247. pretrained_dir = osp.join(save_dir, 'pretrain')
  248. self.net_initialize(
  249. pretrain_weights=pretrain_weights,
  250. save_dir=pretrained_dir,
  251. resume_checkpoint=resume_checkpoint,
  252. is_backbone_weights=(pretrain_weights == 'IMAGENET' and
  253. 'ESNet_' in self.backbone_name))
  254. if use_ema:
  255. ema = ModelEMA(model=self.net, decay=.9998, use_thres_step=True)
  256. else:
  257. ema = None
  258. # start train loop
  259. self.train_loop(
  260. num_epochs=num_epochs,
  261. train_dataset=train_dataset,
  262. train_batch_size=train_batch_size,
  263. eval_dataset=eval_dataset,
  264. save_interval_epochs=save_interval_epochs,
  265. log_interval_steps=log_interval_steps,
  266. save_dir=save_dir,
  267. ema=ema,
  268. early_stop=early_stop,
  269. early_stop_patience=early_stop_patience,
  270. use_vdl=use_vdl)
  271. def quant_aware_train(self,
  272. num_epochs,
  273. train_dataset,
  274. train_batch_size=64,
  275. eval_dataset=None,
  276. optimizer=None,
  277. save_interval_epochs=1,
  278. log_interval_steps=10,
  279. save_dir='output',
  280. learning_rate=.00001,
  281. warmup_steps=0,
  282. warmup_start_lr=0.0,
  283. lr_decay_epochs=(216, 243),
  284. lr_decay_gamma=0.1,
  285. metric=None,
  286. use_ema=False,
  287. early_stop=False,
  288. early_stop_patience=5,
  289. use_vdl=True,
  290. resume_checkpoint=None,
  291. quant_config=None):
  292. """
  293. Quantization-aware training.
  294. Args:
  295. num_epochs(int): The number of epochs.
  296. train_dataset(paddlex.dataset): Training dataset.
  297. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  298. eval_dataset(paddlex.dataset, optional):
  299. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  300. optimizer(paddle.optimizer.Optimizer or None, optional):
  301. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  302. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  303. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  304. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  305. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  306. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  307. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  308. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  309. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  310. metric({'VOC', 'COCO', None}, optional):
  311. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  312. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  313. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  314. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  315. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  316. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  317. configuration will be used. Defaults to None.
  318. resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
  319. from. If None, no training checkpoint will be resumed. Defaults to None.
  320. """
  321. self._prepare_qat(quant_config)
  322. self.train(
  323. num_epochs=num_epochs,
  324. train_dataset=train_dataset,
  325. train_batch_size=train_batch_size,
  326. eval_dataset=eval_dataset,
  327. optimizer=optimizer,
  328. save_interval_epochs=save_interval_epochs,
  329. log_interval_steps=log_interval_steps,
  330. save_dir=save_dir,
  331. pretrain_weights=None,
  332. learning_rate=learning_rate,
  333. warmup_steps=warmup_steps,
  334. warmup_start_lr=warmup_start_lr,
  335. lr_decay_epochs=lr_decay_epochs,
  336. lr_decay_gamma=lr_decay_gamma,
  337. metric=metric,
  338. use_ema=use_ema,
  339. early_stop=early_stop,
  340. early_stop_patience=early_stop_patience,
  341. use_vdl=use_vdl,
  342. resume_checkpoint=resume_checkpoint)
  343. def evaluate(self,
  344. eval_dataset,
  345. batch_size=1,
  346. metric=None,
  347. return_details=False):
  348. """
  349. Evaluate the model.
  350. Args:
  351. eval_dataset(paddlex.dataset): Evaluation dataset.
  352. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  353. metric({'VOC', 'COCO', None}, optional):
  354. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  355. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  356. Returns:
  357. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  358. """
  359. if metric is None:
  360. if not hasattr(self, 'metric'):
  361. if eval_dataset.__class__.__name__ == 'VOCDetection':
  362. self.metric = 'voc'
  363. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  364. self.metric = 'coco'
  365. else:
  366. assert metric.lower() in ['coco', 'voc'], \
  367. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  368. self.metric = metric.lower()
  369. if self.metric == 'voc':
  370. eval_dataset.data_fields = {
  371. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  372. 'difficult'
  373. }
  374. elif self.metric == 'coco':
  375. if self.__class__.__name__ == 'MaskRCNN':
  376. eval_dataset.data_fields = {
  377. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  378. 'gt_poly', 'is_crowd'
  379. }
  380. else:
  381. eval_dataset.data_fields = {
  382. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  383. 'is_crowd'
  384. }
  385. eval_dataset.batch_transforms = self._compose_batch_transform(
  386. eval_dataset.transforms, mode='eval')
  387. arrange_transforms(
  388. model_type=self.model_type,
  389. transforms=eval_dataset.transforms,
  390. mode='eval')
  391. self.net.eval()
  392. nranks = paddle.distributed.get_world_size()
  393. local_rank = paddle.distributed.get_rank()
  394. if nranks > 1:
  395. # Initialize parallel environment if not done.
  396. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  397. ):
  398. paddle.distributed.init_parallel_env()
  399. if batch_size > 1:
  400. logging.warning(
  401. "Detector only supports single card evaluation with batch_size=1 "
  402. "during evaluation, so batch_size is forcibly set to 1.")
  403. batch_size = 1
  404. if nranks < 2 or local_rank == 0:
  405. self.eval_data_loader = self.build_data_loader(
  406. eval_dataset, batch_size=batch_size, mode='eval')
  407. is_bbox_normalized = False
  408. if eval_dataset.batch_transforms is not None:
  409. is_bbox_normalized = any(
  410. isinstance(t, _NormalizeBox)
  411. for t in eval_dataset.batch_transforms.batch_transforms)
  412. if self.metric == 'voc':
  413. eval_metric = VOCMetric(
  414. labels=eval_dataset.labels,
  415. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  416. is_bbox_normalized=is_bbox_normalized,
  417. classwise=False)
  418. else:
  419. eval_metric = COCOMetric(
  420. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  421. classwise=False)
  422. scores = collections.OrderedDict()
  423. logging.info(
  424. "Start to evaluate(total_samples={}, total_steps={})...".
  425. format(eval_dataset.num_samples, eval_dataset.num_samples))
  426. with paddle.no_grad():
  427. for step, data in enumerate(self.eval_data_loader):
  428. outputs = self.run(self.net, data, 'eval')
  429. eval_metric.update(data, outputs)
  430. eval_metric.accumulate()
  431. self.eval_details = eval_metric.details
  432. scores.update(eval_metric.get())
  433. eval_metric.reset()
  434. if return_details:
  435. return scores, self.eval_details
  436. return scores
  437. def predict(self, img_file, transforms=None):
  438. """
  439. Do inference.
  440. Args:
  441. img_file(List[np.ndarray or str], str or np.ndarray):
  442. Image path or decoded image data in a BGR format, which also could constitute a list,
  443. meaning all images to be predicted as a mini-batch.
  444. transforms(paddlex.transforms.Compose or None, optional):
  445. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  446. Returns:
  447. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  448. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  449. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  450. category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
  451. category(str): category name
  452. bbox(list): bounding box in [x, y, w, h] format
  453. score(str): confidence
  454. mask(dict): Only for instance segmentation task. Mask of the object in RLE format
  455. """
  456. if transforms is None and not hasattr(self, 'test_transforms'):
  457. raise Exception("transforms need to be defined, now is None.")
  458. if transforms is None:
  459. transforms = self.test_transforms
  460. if isinstance(img_file, (str, np.ndarray)):
  461. images = [img_file]
  462. else:
  463. images = img_file
  464. batch_samples = self._preprocess(images, transforms)
  465. self.net.eval()
  466. outputs = self.run(self.net, batch_samples, 'test')
  467. prediction = self._postprocess(outputs)
  468. if isinstance(img_file, (str, np.ndarray)):
  469. prediction = prediction[0]
  470. return prediction
  471. def _preprocess(self, images, transforms, to_tensor=True):
  472. arrange_transforms(
  473. model_type=self.model_type, transforms=transforms, mode='test')
  474. batch_samples = list()
  475. for im in images:
  476. sample = {'image': im}
  477. batch_samples.append(transforms(sample))
  478. batch_transforms = self._compose_batch_transform(transforms, 'test')
  479. batch_samples = batch_transforms(batch_samples)
  480. if to_tensor:
  481. for k in batch_samples:
  482. batch_samples[k] = paddle.to_tensor(batch_samples[k])
  483. return batch_samples
  484. def _postprocess(self, batch_pred):
  485. infer_result = {}
  486. if 'bbox' in batch_pred:
  487. bboxes = batch_pred['bbox']
  488. bbox_nums = batch_pred['bbox_num']
  489. det_res = []
  490. k = 0
  491. for i in range(len(bbox_nums)):
  492. det_nums = bbox_nums[i]
  493. for j in range(det_nums):
  494. dt = bboxes[k]
  495. k = k + 1
  496. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  497. if int(num_id) < 0:
  498. continue
  499. category = self.labels[int(num_id)]
  500. w = xmax - xmin
  501. h = ymax - ymin
  502. bbox = [xmin, ymin, w, h]
  503. dt_res = {
  504. 'category_id': int(num_id),
  505. 'category': category,
  506. 'bbox': bbox,
  507. 'score': score
  508. }
  509. det_res.append(dt_res)
  510. infer_result['bbox'] = det_res
  511. if 'mask' in batch_pred:
  512. masks = batch_pred['mask']
  513. bboxes = batch_pred['bbox']
  514. mask_nums = batch_pred['bbox_num']
  515. seg_res = []
  516. k = 0
  517. for i in range(len(mask_nums)):
  518. det_nums = mask_nums[i]
  519. for j in range(det_nums):
  520. mask = masks[k].astype(np.uint8)
  521. score = float(bboxes[k][1])
  522. label = int(bboxes[k][0])
  523. k = k + 1
  524. if label == -1:
  525. continue
  526. category = self.labels[int(label)]
  527. sg_res = {
  528. 'category_id': int(label),
  529. 'category': category,
  530. 'mask': mask.astype('uint8'),
  531. 'score': score
  532. }
  533. seg_res.append(sg_res)
  534. infer_result['mask'] = seg_res
  535. bbox_num = batch_pred['bbox_num']
  536. results = []
  537. start = 0
  538. for num in bbox_num:
  539. end = start + num
  540. curr_res = infer_result['bbox'][start:end]
  541. if 'mask' in infer_result:
  542. mask_res = infer_result['mask'][start:end]
  543. for box, mask in zip(curr_res, mask_res):
  544. box.update(mask)
  545. results.append(curr_res)
  546. start = end
  547. return results
  548. class PicoDet(BaseDetector):
  549. def __init__(self,
  550. num_classes=80,
  551. backbone='ESNet_m',
  552. nms_score_threshold=.025,
  553. nms_topk=1000,
  554. nms_keep_topk=100,
  555. nms_iou_threshold=.6,
  556. **params):
  557. self.init_params = locals()
  558. if backbone not in {
  559. 'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3',
  560. 'ResNet18_vd'
  561. }:
  562. raise ValueError(
  563. "backbone: {} is not supported. Please choose one of "
  564. "('ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18_vd')".
  565. format(backbone))
  566. self.backbone_name = backbone
  567. if params.get('with_net', True):
  568. if backbone == 'ESNet_s':
  569. backbone = self._get_backbone(
  570. 'ESNet',
  571. scale=.75,
  572. feature_maps=[4, 11, 14],
  573. act="hard_swish",
  574. channel_ratio=[
  575. 0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5,
  576. 0.5, 0.5, 0.5
  577. ])
  578. neck_out_channels = 96
  579. head_num_convs = 2
  580. elif backbone == 'ESNet_m':
  581. backbone = self._get_backbone(
  582. 'ESNet',
  583. scale=1.0,
  584. feature_maps=[4, 11, 14],
  585. act="hard_swish",
  586. channel_ratio=[
  587. 0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
  588. 0.625, 1.0, 0.625, 0.75
  589. ])
  590. neck_out_channels = 128
  591. head_num_convs = 4
  592. elif backbone == 'ESNet_l':
  593. backbone = self._get_backbone(
  594. 'ESNet',
  595. scale=1.25,
  596. feature_maps=[4, 11, 14],
  597. act="hard_swish",
  598. channel_ratio=[
  599. 0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
  600. 0.625, 1.0, 0.625, 0.75
  601. ])
  602. neck_out_channels = 160
  603. head_num_convs = 4
  604. elif backbone == 'LCNet':
  605. backbone = self._get_backbone(
  606. 'LCNet', scale=1.5, feature_maps=[3, 4, 5])
  607. neck_out_channels = 128
  608. head_num_convs = 4
  609. elif backbone == 'MobileNetV3':
  610. backbone = self._get_backbone(
  611. 'MobileNetV3',
  612. scale=1.0,
  613. with_extra_blocks=False,
  614. extra_block_filters=[],
  615. feature_maps=[7, 13, 16])
  616. neck_out_channels = 128
  617. head_num_convs = 4
  618. else:
  619. backbone = self._get_backbone(
  620. 'ResNet',
  621. depth=18,
  622. variant='d',
  623. return_idx=[1, 2, 3],
  624. freeze_at=-1,
  625. freeze_norm=False,
  626. norm_decay=0.)
  627. neck_out_channels = 128
  628. head_num_convs = 4
  629. neck = ppdet.modeling.CSPPAN(
  630. in_channels=[i.channels for i in backbone.out_shape],
  631. out_channels=neck_out_channels,
  632. num_features=4,
  633. num_csp_blocks=1,
  634. use_depthwise=True)
  635. head_conv_feat = ppdet.modeling.PicoFeat(
  636. feat_in=neck_out_channels,
  637. feat_out=neck_out_channels,
  638. num_fpn_stride=4,
  639. num_convs=head_num_convs,
  640. norm_type='bn',
  641. share_cls_reg=True, )
  642. loss_class = ppdet.modeling.VarifocalLoss(
  643. use_sigmoid=True, iou_weighted=True, loss_weight=1.0)
  644. loss_dfl = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
  645. loss_bbox = ppdet.modeling.GIoULoss(loss_weight=2.0)
  646. assigner = ppdet.modeling.SimOTAAssigner(
  647. candidate_topk=10, iou_weight=6, num_classes=num_classes)
  648. nms = ppdet.modeling.MultiClassNMS(
  649. nms_top_k=nms_topk,
  650. keep_top_k=nms_keep_topk,
  651. score_threshold=nms_score_threshold,
  652. nms_threshold=nms_iou_threshold)
  653. head = ppdet.modeling.PicoHead(
  654. conv_feat=head_conv_feat,
  655. num_classes=num_classes,
  656. fpn_stride=[8, 16, 32, 64],
  657. prior_prob=0.01,
  658. reg_max=7,
  659. cell_offset=.5,
  660. loss_class=loss_class,
  661. loss_dfl=loss_dfl,
  662. loss_bbox=loss_bbox,
  663. assigner=assigner,
  664. feat_in_chan=neck_out_channels,
  665. nms=nms)
  666. params.update({
  667. 'backbone': backbone,
  668. 'neck': neck,
  669. 'head': head,
  670. })
  671. super(PicoDet, self).__init__(
  672. model_name='PicoDet', num_classes=num_classes, **params)
  673. def _compose_batch_transform(self, transforms, mode='train'):
  674. default_batch_transforms = [_BatchPadding(pad_to_stride=32)]
  675. if mode == 'eval':
  676. collate_batch = True
  677. else:
  678. collate_batch = False
  679. custom_batch_transforms = []
  680. for i, op in enumerate(transforms.transforms):
  681. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  682. if mode != 'train':
  683. raise Exception(
  684. "{} cannot be present in the {} transforms. ".format(
  685. op.__class__.__name__, mode) +
  686. "Please check the {} transforms.".format(mode))
  687. custom_batch_transforms.insert(0, copy.deepcopy(op))
  688. batch_transforms = BatchCompose(
  689. custom_batch_transforms + default_batch_transforms,
  690. collate_batch=collate_batch)
  691. return batch_transforms
  692. def _fix_transforms_shape(self, image_shape):
  693. if getattr(self, 'test_transforms', None):
  694. has_resize_op = False
  695. resize_op_idx = -1
  696. normalize_op_idx = len(self.test_transforms.transforms)
  697. for idx, op in enumerate(self.test_transforms.transforms):
  698. name = op.__class__.__name__
  699. if name == 'Resize':
  700. has_resize_op = True
  701. resize_op_idx = idx
  702. if name == 'Normalize':
  703. normalize_op_idx = idx
  704. if not has_resize_op:
  705. self.test_transforms.transforms.insert(
  706. normalize_op_idx,
  707. Resize(
  708. target_size=image_shape, interp='CUBIC'))
  709. else:
  710. self.test_transforms.transforms[
  711. resize_op_idx].target_size = image_shape
  712. def _get_test_inputs(self, image_shape):
  713. if image_shape is not None:
  714. image_shape = self._check_image_shape(image_shape)
  715. self._fix_transforms_shape(image_shape[-2:])
  716. else:
  717. image_shape = [None, 3, 320, 320]
  718. if getattr(self, 'test_transforms', None):
  719. for idx, op in enumerate(self.test_transforms.transforms):
  720. name = op.__class__.__name__
  721. if name == 'Resize':
  722. image_shape = [None, 3] + list(
  723. self.test_transforms.transforms[idx].target_size)
  724. logging.warning(
  725. '[Important!!!] When exporting inference model for {}, '
  726. 'if fixed_input_shape is not set, it will be forcibly set to {}. '
  727. 'Please ensure image shape after transforms is {}, if not, '
  728. 'fixed_input_shape should be specified manually.'
  729. .format(self.__class__.__name__, image_shape, image_shape[1:]))
  730. self.fixed_input_shape = image_shape
  731. return self._define_input_spec(image_shape)
  732. class YOLOv3(BaseDetector):
  733. def __init__(self,
  734. num_classes=80,
  735. backbone='MobileNetV1',
  736. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  737. [59, 119], [116, 90], [156, 198], [373, 326]],
  738. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  739. ignore_threshold=0.7,
  740. nms_score_threshold=0.01,
  741. nms_topk=1000,
  742. nms_keep_topk=100,
  743. nms_iou_threshold=0.45,
  744. label_smooth=False,
  745. **params):
  746. self.init_params = locals()
  747. if backbone not in {
  748. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  749. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  750. }:
  751. raise ValueError(
  752. "backbone: {} is not supported. Please choose one of "
  753. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
  754. "'ResNet50_vd_dcn', 'ResNet34')".format(backbone))
  755. self.backbone_name = backbone
  756. if params.get('with_net', True):
  757. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  758. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  759. norm_type = 'sync_bn'
  760. else:
  761. norm_type = 'bn'
  762. if 'MobileNetV1' in backbone:
  763. norm_type = 'bn'
  764. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  765. elif 'MobileNetV3' in backbone:
  766. backbone = self._get_backbone(
  767. 'MobileNetV3',
  768. norm_type=norm_type,
  769. feature_maps=[7, 13, 16])
  770. elif backbone == 'ResNet50_vd_dcn':
  771. backbone = self._get_backbone(
  772. 'ResNet',
  773. norm_type=norm_type,
  774. variant='d',
  775. return_idx=[1, 2, 3],
  776. dcn_v2_stages=[3],
  777. freeze_at=-1,
  778. freeze_norm=False)
  779. elif backbone == 'ResNet34':
  780. backbone = self._get_backbone(
  781. 'ResNet',
  782. depth=34,
  783. norm_type=norm_type,
  784. return_idx=[1, 2, 3],
  785. freeze_at=-1,
  786. freeze_norm=False,
  787. norm_decay=0.)
  788. else:
  789. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  790. neck = ppdet.modeling.YOLOv3FPN(
  791. norm_type=norm_type,
  792. in_channels=[i.channels for i in backbone.out_shape])
  793. loss = ppdet.modeling.YOLOv3Loss(
  794. num_classes=num_classes,
  795. ignore_thresh=ignore_threshold,
  796. label_smooth=label_smooth)
  797. yolo_head = ppdet.modeling.YOLOv3Head(
  798. in_channels=[i.channels for i in neck.out_shape],
  799. anchors=anchors,
  800. anchor_masks=anchor_masks,
  801. num_classes=num_classes,
  802. loss=loss)
  803. post_process = ppdet.modeling.BBoxPostProcess(
  804. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  805. nms=ppdet.modeling.MultiClassNMS(
  806. score_threshold=nms_score_threshold,
  807. nms_top_k=nms_topk,
  808. keep_top_k=nms_keep_topk,
  809. nms_threshold=nms_iou_threshold))
  810. params.update({
  811. 'backbone': backbone,
  812. 'neck': neck,
  813. 'yolo_head': yolo_head,
  814. 'post_process': post_process
  815. })
  816. super(YOLOv3, self).__init__(
  817. model_name='YOLOv3', num_classes=num_classes, **params)
  818. self.anchors = anchors
  819. self.anchor_masks = anchor_masks
  820. def _compose_batch_transform(self, transforms, mode='train'):
  821. if mode == 'train':
  822. default_batch_transforms = [
  823. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  824. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  825. _Gt2YoloTarget(
  826. anchor_masks=self.anchor_masks,
  827. anchors=self.anchors,
  828. downsample_ratios=getattr(self, 'downsample_ratios',
  829. [32, 16, 8]),
  830. num_classes=self.num_classes)
  831. ]
  832. else:
  833. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  834. if mode == 'eval' and self.metric == 'voc':
  835. collate_batch = False
  836. else:
  837. collate_batch = True
  838. custom_batch_transforms = []
  839. for i, op in enumerate(transforms.transforms):
  840. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  841. if mode != 'train':
  842. raise Exception(
  843. "{} cannot be present in the {} transforms. ".format(
  844. op.__class__.__name__, mode) +
  845. "Please check the {} transforms.".format(mode))
  846. custom_batch_transforms.insert(0, copy.deepcopy(op))
  847. batch_transforms = BatchCompose(
  848. custom_batch_transforms + default_batch_transforms,
  849. collate_batch=collate_batch)
  850. return batch_transforms
  851. def _fix_transforms_shape(self, image_shape):
  852. if getattr(self, 'test_transforms', None):
  853. has_resize_op = False
  854. resize_op_idx = -1
  855. normalize_op_idx = len(self.test_transforms.transforms)
  856. for idx, op in enumerate(self.test_transforms.transforms):
  857. name = op.__class__.__name__
  858. if name == 'Resize':
  859. has_resize_op = True
  860. resize_op_idx = idx
  861. if name == 'Normalize':
  862. normalize_op_idx = idx
  863. if not has_resize_op:
  864. self.test_transforms.transforms.insert(
  865. normalize_op_idx,
  866. Resize(
  867. target_size=image_shape, interp='CUBIC'))
  868. else:
  869. self.test_transforms.transforms[
  870. resize_op_idx].target_size = image_shape
  871. class FasterRCNN(BaseDetector):
  872. def __init__(self,
  873. num_classes=80,
  874. backbone='ResNet50',
  875. with_fpn=True,
  876. with_dcn=False,
  877. aspect_ratios=[0.5, 1.0, 2.0],
  878. anchor_sizes=[[32], [64], [128], [256], [512]],
  879. keep_top_k=100,
  880. nms_threshold=0.5,
  881. score_threshold=0.05,
  882. fpn_num_channels=256,
  883. rpn_batch_size_per_im=256,
  884. rpn_fg_fraction=0.5,
  885. test_pre_nms_top_n=None,
  886. test_post_nms_top_n=1000,
  887. **params):
  888. self.init_params = locals()
  889. if backbone not in {
  890. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  891. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  892. }:
  893. raise ValueError(
  894. "backbone: {} is not supported. Please choose one of "
  895. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  896. "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
  897. self.backbone_name = backbone
  898. if params.get('with_net', True):
  899. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  900. if backbone == 'HRNet_W18':
  901. if not with_fpn:
  902. logging.warning(
  903. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  904. format(backbone))
  905. with_fpn = True
  906. if with_dcn:
  907. logging.warning(
  908. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  909. format(backbone))
  910. backbone = self._get_backbone(
  911. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  912. elif backbone == 'ResNet50_vd_ssld':
  913. if not with_fpn:
  914. logging.warning(
  915. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  916. format(backbone))
  917. with_fpn = True
  918. backbone = self._get_backbone(
  919. 'ResNet',
  920. variant='d',
  921. norm_type='bn',
  922. freeze_at=0,
  923. return_idx=[0, 1, 2, 3],
  924. num_stages=4,
  925. lr_mult_list=[0.05, 0.05, 0.1, 0.15],
  926. dcn_v2_stages=dcn_v2_stages)
  927. elif 'ResNet50' in backbone:
  928. if with_fpn:
  929. backbone = self._get_backbone(
  930. 'ResNet',
  931. variant='d' if '_vd' in backbone else 'b',
  932. norm_type='bn',
  933. freeze_at=0,
  934. return_idx=[0, 1, 2, 3],
  935. num_stages=4,
  936. dcn_v2_stages=dcn_v2_stages)
  937. else:
  938. if with_dcn:
  939. logging.warning(
  940. "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  941. format(backbone))
  942. backbone = self._get_backbone(
  943. 'ResNet',
  944. variant='d' if '_vd' in backbone else 'b',
  945. norm_type='bn',
  946. freeze_at=0,
  947. return_idx=[2],
  948. num_stages=3)
  949. elif 'ResNet34' in backbone:
  950. if not with_fpn:
  951. logging.warning(
  952. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  953. format(backbone))
  954. with_fpn = True
  955. backbone = self._get_backbone(
  956. 'ResNet',
  957. depth=34,
  958. variant='d' if 'vd' in backbone else 'b',
  959. norm_type='bn',
  960. freeze_at=0,
  961. return_idx=[0, 1, 2, 3],
  962. num_stages=4,
  963. dcn_v2_stages=dcn_v2_stages)
  964. else:
  965. if not with_fpn:
  966. logging.warning(
  967. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  968. format(backbone))
  969. with_fpn = True
  970. backbone = self._get_backbone(
  971. 'ResNet',
  972. depth=101,
  973. variant='d' if 'vd' in backbone else 'b',
  974. norm_type='bn',
  975. freeze_at=0,
  976. return_idx=[0, 1, 2, 3],
  977. num_stages=4,
  978. dcn_v2_stages=dcn_v2_stages)
  979. rpn_in_channel = backbone.out_shape[0].channels
  980. if with_fpn:
  981. self.backbone_name = self.backbone_name + '_fpn'
  982. if 'HRNet' in self.backbone_name:
  983. neck = ppdet.modeling.HRFPN(
  984. in_channels=[i.channels for i in backbone.out_shape],
  985. out_channel=fpn_num_channels,
  986. spatial_scales=[
  987. 1.0 / i.stride for i in backbone.out_shape
  988. ],
  989. share_conv=False)
  990. else:
  991. neck = ppdet.modeling.FPN(
  992. in_channels=[i.channels for i in backbone.out_shape],
  993. out_channel=fpn_num_channels,
  994. spatial_scales=[
  995. 1.0 / i.stride for i in backbone.out_shape
  996. ])
  997. rpn_in_channel = neck.out_shape[0].channels
  998. anchor_generator_cfg = {
  999. 'aspect_ratios': aspect_ratios,
  1000. 'anchor_sizes': anchor_sizes,
  1001. 'strides': [4, 8, 16, 32, 64]
  1002. }
  1003. train_proposal_cfg = {
  1004. 'min_size': 0.0,
  1005. 'nms_thresh': .7,
  1006. 'pre_nms_top_n': 2000,
  1007. 'post_nms_top_n': 1000,
  1008. 'topk_after_collect': True
  1009. }
  1010. test_proposal_cfg = {
  1011. 'min_size': 0.0,
  1012. 'nms_thresh': .7,
  1013. 'pre_nms_top_n': 1000
  1014. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1015. 'post_nms_top_n': test_post_nms_top_n
  1016. }
  1017. head = ppdet.modeling.TwoFCHead(
  1018. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1019. roi_extractor_cfg = {
  1020. 'resolution': 7,
  1021. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1022. 'sampling_ratio': 0,
  1023. 'aligned': True
  1024. }
  1025. with_pool = False
  1026. else:
  1027. neck = None
  1028. anchor_generator_cfg = {
  1029. 'aspect_ratios': aspect_ratios,
  1030. 'anchor_sizes': anchor_sizes,
  1031. 'strides': [16]
  1032. }
  1033. train_proposal_cfg = {
  1034. 'min_size': 0.0,
  1035. 'nms_thresh': .7,
  1036. 'pre_nms_top_n': 12000,
  1037. 'post_nms_top_n': 2000,
  1038. 'topk_after_collect': False
  1039. }
  1040. test_proposal_cfg = {
  1041. 'min_size': 0.0,
  1042. 'nms_thresh': .7,
  1043. 'pre_nms_top_n': 6000
  1044. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1045. 'post_nms_top_n': test_post_nms_top_n
  1046. }
  1047. head = ppdet.modeling.Res5Head()
  1048. roi_extractor_cfg = {
  1049. 'resolution': 14,
  1050. 'spatial_scale':
  1051. [1. / i.stride for i in backbone.out_shape],
  1052. 'sampling_ratio': 0,
  1053. 'aligned': True
  1054. }
  1055. with_pool = True
  1056. rpn_target_assign_cfg = {
  1057. 'batch_size_per_im': rpn_batch_size_per_im,
  1058. 'fg_fraction': rpn_fg_fraction,
  1059. 'negative_overlap': .3,
  1060. 'positive_overlap': .7,
  1061. 'use_random': True
  1062. }
  1063. rpn_head = ppdet.modeling.RPNHead(
  1064. anchor_generator=anchor_generator_cfg,
  1065. rpn_target_assign=rpn_target_assign_cfg,
  1066. train_proposal=train_proposal_cfg,
  1067. test_proposal=test_proposal_cfg,
  1068. in_channel=rpn_in_channel)
  1069. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1070. bbox_head = ppdet.modeling.BBoxHead(
  1071. head=head,
  1072. in_channel=head.out_shape[0].channels,
  1073. roi_extractor=roi_extractor_cfg,
  1074. with_pool=with_pool,
  1075. bbox_assigner=bbox_assigner,
  1076. num_classes=num_classes)
  1077. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1078. num_classes=num_classes,
  1079. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1080. nms=ppdet.modeling.MultiClassNMS(
  1081. score_threshold=score_threshold,
  1082. keep_top_k=keep_top_k,
  1083. nms_threshold=nms_threshold))
  1084. params.update({
  1085. 'backbone': backbone,
  1086. 'neck': neck,
  1087. 'rpn_head': rpn_head,
  1088. 'bbox_head': bbox_head,
  1089. 'bbox_post_process': bbox_post_process
  1090. })
  1091. else:
  1092. if backbone not in {'ResNet50', 'ResNet50_vd'}:
  1093. with_fpn = True
  1094. self.with_fpn = with_fpn
  1095. super(FasterRCNN, self).__init__(
  1096. model_name='FasterRCNN', num_classes=num_classes, **params)
  1097. def train(self,
  1098. num_epochs,
  1099. train_dataset,
  1100. train_batch_size=64,
  1101. eval_dataset=None,
  1102. optimizer=None,
  1103. save_interval_epochs=1,
  1104. log_interval_steps=10,
  1105. save_dir='output',
  1106. pretrain_weights='IMAGENET',
  1107. learning_rate=.001,
  1108. warmup_steps=0,
  1109. warmup_start_lr=0.0,
  1110. lr_decay_epochs=(216, 243),
  1111. lr_decay_gamma=0.1,
  1112. metric=None,
  1113. use_ema=False,
  1114. early_stop=False,
  1115. early_stop_patience=5,
  1116. use_vdl=True,
  1117. resume_checkpoint=None):
  1118. """
  1119. Train the model.
  1120. Args:
  1121. num_epochs(int): The number of epochs.
  1122. train_dataset(paddlex.dataset): Training dataset.
  1123. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  1124. eval_dataset(paddlex.dataset, optional):
  1125. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  1126. optimizer(paddle.optimizer.Optimizer or None, optional):
  1127. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  1128. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  1129. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  1130. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  1131. pretrain_weights(str or None, optional):
  1132. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  1133. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  1134. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  1135. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  1136. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  1137. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  1138. metric({'VOC', 'COCO', None}, optional):
  1139. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  1140. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  1141. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  1142. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  1143. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  1144. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  1145. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  1146. `pretrain_weights` can be set simultaneously. Defaults to None.
  1147. """
  1148. if train_dataset.pos_num < len(train_dataset.file_list):
  1149. train_dataset.num_workers = 0
  1150. if train_batch_size != 1:
  1151. train_batch_size = 1
  1152. logging.warning(
  1153. "Training RCNN models with negative samples only support batch size equals to 1 "
  1154. "on a single gpu/cpu card, `train_batch_size` is forcibly set to 1."
  1155. )
  1156. nranks = paddle.distributed.get_world_size()
  1157. local_rank = paddle.distributed.get_rank()
  1158. # single card training
  1159. if nranks < 2 or local_rank == 0:
  1160. super(FasterRCNN, self).train(
  1161. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1162. optimizer, save_interval_epochs, log_interval_steps,
  1163. save_dir, pretrain_weights, learning_rate, warmup_steps,
  1164. warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric,
  1165. use_ema, early_stop, early_stop_patience, use_vdl,
  1166. resume_checkpoint)
  1167. else:
  1168. super(FasterRCNN, self).train(
  1169. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1170. optimizer, save_interval_epochs, log_interval_steps, save_dir,
  1171. pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
  1172. lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
  1173. early_stop_patience, use_vdl, resume_checkpoint)
  1174. def _compose_batch_transform(self, transforms, mode='train'):
  1175. if mode == 'train':
  1176. default_batch_transforms = [
  1177. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1178. ]
  1179. collate_batch = False
  1180. else:
  1181. default_batch_transforms = [
  1182. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1183. ]
  1184. collate_batch = True
  1185. custom_batch_transforms = []
  1186. for i, op in enumerate(transforms.transforms):
  1187. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1188. if mode != 'train':
  1189. raise Exception(
  1190. "{} cannot be present in the {} transforms. ".format(
  1191. op.__class__.__name__, mode) +
  1192. "Please check the {} transforms.".format(mode))
  1193. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1194. batch_transforms = BatchCompose(
  1195. custom_batch_transforms + default_batch_transforms,
  1196. collate_batch=collate_batch)
  1197. return batch_transforms
  1198. def _fix_transforms_shape(self, image_shape):
  1199. if getattr(self, 'test_transforms', None):
  1200. has_resize_op = False
  1201. resize_op_idx = -1
  1202. normalize_op_idx = len(self.test_transforms.transforms)
  1203. for idx, op in enumerate(self.test_transforms.transforms):
  1204. name = op.__class__.__name__
  1205. if name == 'ResizeByShort':
  1206. has_resize_op = True
  1207. resize_op_idx = idx
  1208. if name == 'Normalize':
  1209. normalize_op_idx = idx
  1210. if not has_resize_op:
  1211. self.test_transforms.transforms.insert(
  1212. normalize_op_idx,
  1213. Resize(
  1214. target_size=image_shape,
  1215. keep_ratio=True,
  1216. interp='CUBIC'))
  1217. else:
  1218. self.test_transforms.transforms[resize_op_idx] = Resize(
  1219. target_size=image_shape, keep_ratio=True, interp='CUBIC')
  1220. self.test_transforms.transforms.append(
  1221. Padding(im_padding_value=[0., 0., 0.]))
  1222. def _get_test_inputs(self, image_shape):
  1223. if image_shape is not None:
  1224. image_shape = self._check_image_shape(image_shape)
  1225. self._fix_transforms_shape(image_shape[-2:])
  1226. else:
  1227. image_shape = [None, 3, -1, -1]
  1228. if self.with_fpn:
  1229. self.test_transforms.transforms.append(
  1230. Padding(im_padding_value=[0., 0., 0.]))
  1231. self.fixed_input_shape = image_shape
  1232. return self._define_input_spec(image_shape)
  1233. class PPYOLO(YOLOv3):
  1234. def __init__(self,
  1235. num_classes=80,
  1236. backbone='ResNet50_vd_dcn',
  1237. anchors=None,
  1238. anchor_masks=None,
  1239. use_coord_conv=True,
  1240. use_iou_aware=True,
  1241. use_spp=True,
  1242. use_drop_block=True,
  1243. scale_x_y=1.05,
  1244. ignore_threshold=0.7,
  1245. label_smooth=False,
  1246. use_iou_loss=True,
  1247. use_matrix_nms=True,
  1248. nms_score_threshold=0.01,
  1249. nms_topk=-1,
  1250. nms_keep_topk=100,
  1251. nms_iou_threshold=0.45,
  1252. **params):
  1253. self.init_params = locals()
  1254. if backbone not in {
  1255. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  1256. 'MobileNetV3_small'
  1257. }:
  1258. raise ValueError(
  1259. "backbone: {} is not supported. Please choose one of "
  1260. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  1261. format(backbone))
  1262. self.backbone_name = backbone
  1263. self.downsample_ratios = [
  1264. 32, 16, 8
  1265. ] if backbone == 'ResNet50_vd_dcn' else [32, 16]
  1266. if params.get('with_net', True):
  1267. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1268. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1269. norm_type = 'sync_bn'
  1270. else:
  1271. norm_type = 'bn'
  1272. if anchors is None and anchor_masks is None:
  1273. if 'MobileNetV3' in backbone:
  1274. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  1275. [120, 195], [254, 235]]
  1276. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1277. elif backbone == 'ResNet50_vd_dcn':
  1278. anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
  1279. [62, 45], [59, 119], [116, 90], [156, 198],
  1280. [373, 326]]
  1281. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  1282. else:
  1283. anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
  1284. [135, 169], [344, 319]]
  1285. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1286. elif anchors is None or anchor_masks is None:
  1287. raise ValueError(
  1288. "Please define both anchors and anchor_masks.")
  1289. if backbone == 'ResNet50_vd_dcn':
  1290. backbone = self._get_backbone(
  1291. 'ResNet',
  1292. variant='d',
  1293. norm_type=norm_type,
  1294. return_idx=[1, 2, 3],
  1295. dcn_v2_stages=[3],
  1296. freeze_at=-1,
  1297. freeze_norm=False,
  1298. norm_decay=0.)
  1299. elif backbone == 'ResNet18_vd':
  1300. backbone = self._get_backbone(
  1301. 'ResNet',
  1302. depth=18,
  1303. variant='d',
  1304. norm_type=norm_type,
  1305. return_idx=[2, 3],
  1306. freeze_at=-1,
  1307. freeze_norm=False,
  1308. norm_decay=0.)
  1309. elif backbone == 'MobileNetV3_large':
  1310. backbone = self._get_backbone(
  1311. 'MobileNetV3',
  1312. model_name='large',
  1313. norm_type=norm_type,
  1314. scale=1,
  1315. with_extra_blocks=False,
  1316. extra_block_filters=[],
  1317. feature_maps=[13, 16])
  1318. elif backbone == 'MobileNetV3_small':
  1319. backbone = self._get_backbone(
  1320. 'MobileNetV3',
  1321. model_name='small',
  1322. norm_type=norm_type,
  1323. scale=1,
  1324. with_extra_blocks=False,
  1325. extra_block_filters=[],
  1326. feature_maps=[9, 12])
  1327. neck = ppdet.modeling.PPYOLOFPN(
  1328. norm_type=norm_type,
  1329. in_channels=[i.channels for i in backbone.out_shape],
  1330. coord_conv=use_coord_conv,
  1331. drop_block=use_drop_block,
  1332. spp=use_spp,
  1333. conv_block_num=0
  1334. if ('MobileNetV3' in self.backbone_name or
  1335. self.backbone_name == 'ResNet18_vd') else 2)
  1336. loss = ppdet.modeling.YOLOv3Loss(
  1337. num_classes=num_classes,
  1338. ignore_thresh=ignore_threshold,
  1339. downsample=self.downsample_ratios,
  1340. label_smooth=label_smooth,
  1341. scale_x_y=scale_x_y,
  1342. iou_loss=ppdet.modeling.IouLoss(
  1343. loss_weight=2.5, loss_square=True)
  1344. if use_iou_loss else None,
  1345. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1346. if use_iou_aware else None)
  1347. yolo_head = ppdet.modeling.YOLOv3Head(
  1348. in_channels=[i.channels for i in neck.out_shape],
  1349. anchors=anchors,
  1350. anchor_masks=anchor_masks,
  1351. num_classes=num_classes,
  1352. loss=loss,
  1353. iou_aware=use_iou_aware)
  1354. if use_matrix_nms:
  1355. nms = ppdet.modeling.MatrixNMS(
  1356. keep_top_k=nms_keep_topk,
  1357. score_threshold=nms_score_threshold,
  1358. post_threshold=.05
  1359. if 'MobileNetV3' in self.backbone_name else .01,
  1360. nms_top_k=nms_topk,
  1361. background_label=-1)
  1362. else:
  1363. nms = ppdet.modeling.MultiClassNMS(
  1364. score_threshold=nms_score_threshold,
  1365. nms_top_k=nms_topk,
  1366. keep_top_k=nms_keep_topk,
  1367. nms_threshold=nms_iou_threshold)
  1368. post_process = ppdet.modeling.BBoxPostProcess(
  1369. decode=ppdet.modeling.YOLOBox(
  1370. num_classes=num_classes,
  1371. conf_thresh=.005
  1372. if 'MobileNetV3' in self.backbone_name else .01,
  1373. scale_x_y=scale_x_y),
  1374. nms=nms)
  1375. params.update({
  1376. 'backbone': backbone,
  1377. 'neck': neck,
  1378. 'yolo_head': yolo_head,
  1379. 'post_process': post_process
  1380. })
  1381. super(YOLOv3, self).__init__(
  1382. model_name='YOLOv3', num_classes=num_classes, **params)
  1383. self.anchors = anchors
  1384. self.anchor_masks = anchor_masks
  1385. self.model_name = 'PPYOLO'
  1386. def _get_test_inputs(self, image_shape):
  1387. if image_shape is not None:
  1388. image_shape = self._check_image_shape(image_shape)
  1389. self._fix_transforms_shape(image_shape[-2:])
  1390. else:
  1391. image_shape = [None, 3, 608, 608]
  1392. if getattr(self, 'test_transforms', None):
  1393. for idx, op in enumerate(self.test_transforms.transforms):
  1394. name = op.__class__.__name__
  1395. if name == 'Resize':
  1396. image_shape = [None, 3] + list(
  1397. self.test_transforms.transforms[idx].target_size)
  1398. logging.warning(
  1399. '[Important!!!] When exporting inference model for {}, '
  1400. 'if fixed_input_shape is not set, it will be forcibly set to {}. '
  1401. 'Please ensure image shape after transforms is {}, if not, '
  1402. 'fixed_input_shape should be specified manually.'
  1403. .format(self.__class__.__name__, image_shape, image_shape[1:]))
  1404. self.fixed_input_shape = image_shape
  1405. return self._define_input_spec(image_shape)
  1406. class PPYOLOTiny(YOLOv3):
  1407. def __init__(self,
  1408. num_classes=80,
  1409. backbone='MobileNetV3',
  1410. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1411. [60, 170], [220, 125], [128, 222], [264, 266]],
  1412. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1413. use_iou_aware=False,
  1414. use_spp=True,
  1415. use_drop_block=True,
  1416. scale_x_y=1.05,
  1417. ignore_threshold=0.5,
  1418. label_smooth=False,
  1419. use_iou_loss=True,
  1420. use_matrix_nms=False,
  1421. nms_score_threshold=0.005,
  1422. nms_topk=1000,
  1423. nms_keep_topk=100,
  1424. nms_iou_threshold=0.45,
  1425. **params):
  1426. self.init_params = locals()
  1427. if backbone != 'MobileNetV3':
  1428. logging.warning(
  1429. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1430. "Backbone is forcibly set to MobileNetV3.")
  1431. self.backbone_name = 'MobileNetV3'
  1432. self.downsample_ratios = [32, 16, 8]
  1433. if params.get('with_net', True):
  1434. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1435. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1436. norm_type = 'sync_bn'
  1437. else:
  1438. norm_type = 'bn'
  1439. backbone = self._get_backbone(
  1440. 'MobileNetV3',
  1441. model_name='large',
  1442. norm_type=norm_type,
  1443. scale=.5,
  1444. with_extra_blocks=False,
  1445. extra_block_filters=[],
  1446. feature_maps=[7, 13, 16])
  1447. neck = ppdet.modeling.PPYOLOTinyFPN(
  1448. detection_block_channels=[160, 128, 96],
  1449. in_channels=[i.channels for i in backbone.out_shape],
  1450. spp=use_spp,
  1451. drop_block=use_drop_block)
  1452. loss = ppdet.modeling.YOLOv3Loss(
  1453. num_classes=num_classes,
  1454. ignore_thresh=ignore_threshold,
  1455. downsample=self.downsample_ratios,
  1456. label_smooth=label_smooth,
  1457. scale_x_y=scale_x_y,
  1458. iou_loss=ppdet.modeling.IouLoss(
  1459. loss_weight=2.5, loss_square=True)
  1460. if use_iou_loss else None,
  1461. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1462. if use_iou_aware else None)
  1463. yolo_head = ppdet.modeling.YOLOv3Head(
  1464. in_channels=[i.channels for i in neck.out_shape],
  1465. anchors=anchors,
  1466. anchor_masks=anchor_masks,
  1467. num_classes=num_classes,
  1468. loss=loss,
  1469. iou_aware=use_iou_aware)
  1470. if use_matrix_nms:
  1471. nms = ppdet.modeling.MatrixNMS(
  1472. keep_top_k=nms_keep_topk,
  1473. score_threshold=nms_score_threshold,
  1474. post_threshold=.05,
  1475. nms_top_k=nms_topk,
  1476. background_label=-1)
  1477. else:
  1478. nms = ppdet.modeling.MultiClassNMS(
  1479. score_threshold=nms_score_threshold,
  1480. nms_top_k=nms_topk,
  1481. keep_top_k=nms_keep_topk,
  1482. nms_threshold=nms_iou_threshold)
  1483. post_process = ppdet.modeling.BBoxPostProcess(
  1484. decode=ppdet.modeling.YOLOBox(
  1485. num_classes=num_classes,
  1486. conf_thresh=.005,
  1487. downsample_ratio=32,
  1488. clip_bbox=True,
  1489. scale_x_y=scale_x_y),
  1490. nms=nms)
  1491. params.update({
  1492. 'backbone': backbone,
  1493. 'neck': neck,
  1494. 'yolo_head': yolo_head,
  1495. 'post_process': post_process
  1496. })
  1497. super(YOLOv3, self).__init__(
  1498. model_name='YOLOv3', num_classes=num_classes, **params)
  1499. self.anchors = anchors
  1500. self.anchor_masks = anchor_masks
  1501. self.model_name = 'PPYOLOTiny'
  1502. def _get_test_inputs(self, image_shape):
  1503. if image_shape is not None:
  1504. image_shape = self._check_image_shape(image_shape)
  1505. self._fix_transforms_shape(image_shape[-2:])
  1506. else:
  1507. image_shape = [None, 3, 320, 320]
  1508. if getattr(self, 'test_transforms', None):
  1509. for idx, op in enumerate(self.test_transforms.transforms):
  1510. name = op.__class__.__name__
  1511. if name == 'Resize':
  1512. image_shape = [None, 3] + list(
  1513. self.test_transforms.transforms[idx].target_size)
  1514. logging.warning(
  1515. '[Important!!!] When exporting inference model for {},'.format(
  1516. self.__class__.__name__) +
  1517. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1518. format(image_shape) +
  1519. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1520. format(image_shape[1:]) + 'should be specified manually.')
  1521. self.fixed_input_shape = image_shape
  1522. return self._define_input_spec(image_shape)
  1523. class PPYOLOv2(YOLOv3):
  1524. def __init__(self,
  1525. num_classes=80,
  1526. backbone='ResNet50_vd_dcn',
  1527. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1528. [59, 119], [116, 90], [156, 198], [373, 326]],
  1529. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1530. use_iou_aware=True,
  1531. use_spp=True,
  1532. use_drop_block=True,
  1533. scale_x_y=1.05,
  1534. ignore_threshold=0.7,
  1535. label_smooth=False,
  1536. use_iou_loss=True,
  1537. use_matrix_nms=True,
  1538. nms_score_threshold=0.01,
  1539. nms_topk=-1,
  1540. nms_keep_topk=100,
  1541. nms_iou_threshold=0.45,
  1542. **params):
  1543. self.init_params = locals()
  1544. if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
  1545. raise ValueError(
  1546. "backbone: {} is not supported. Please choose one of "
  1547. "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
  1548. self.backbone_name = backbone
  1549. self.downsample_ratios = [32, 16, 8]
  1550. if params.get('with_net', True):
  1551. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1552. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1553. norm_type = 'sync_bn'
  1554. else:
  1555. norm_type = 'bn'
  1556. if backbone == 'ResNet50_vd_dcn':
  1557. backbone = self._get_backbone(
  1558. 'ResNet',
  1559. variant='d',
  1560. norm_type=norm_type,
  1561. return_idx=[1, 2, 3],
  1562. dcn_v2_stages=[3],
  1563. freeze_at=-1,
  1564. freeze_norm=False,
  1565. norm_decay=0.)
  1566. elif backbone == 'ResNet101_vd_dcn':
  1567. backbone = self._get_backbone(
  1568. 'ResNet',
  1569. depth=101,
  1570. variant='d',
  1571. norm_type=norm_type,
  1572. return_idx=[1, 2, 3],
  1573. dcn_v2_stages=[3],
  1574. freeze_at=-1,
  1575. freeze_norm=False,
  1576. norm_decay=0.)
  1577. neck = ppdet.modeling.PPYOLOPAN(
  1578. norm_type=norm_type,
  1579. in_channels=[i.channels for i in backbone.out_shape],
  1580. drop_block=use_drop_block,
  1581. block_size=3,
  1582. keep_prob=.9,
  1583. spp=use_spp)
  1584. loss = ppdet.modeling.YOLOv3Loss(
  1585. num_classes=num_classes,
  1586. ignore_thresh=ignore_threshold,
  1587. downsample=self.downsample_ratios,
  1588. label_smooth=label_smooth,
  1589. scale_x_y=scale_x_y,
  1590. iou_loss=ppdet.modeling.IouLoss(
  1591. loss_weight=2.5, loss_square=True)
  1592. if use_iou_loss else None,
  1593. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1594. if use_iou_aware else None)
  1595. yolo_head = ppdet.modeling.YOLOv3Head(
  1596. in_channels=[i.channels for i in neck.out_shape],
  1597. anchors=anchors,
  1598. anchor_masks=anchor_masks,
  1599. num_classes=num_classes,
  1600. loss=loss,
  1601. iou_aware=use_iou_aware,
  1602. iou_aware_factor=.5)
  1603. if use_matrix_nms:
  1604. nms = ppdet.modeling.MatrixNMS(
  1605. keep_top_k=nms_keep_topk,
  1606. score_threshold=nms_score_threshold,
  1607. post_threshold=.01,
  1608. nms_top_k=nms_topk,
  1609. background_label=-1)
  1610. else:
  1611. nms = ppdet.modeling.MultiClassNMS(
  1612. score_threshold=nms_score_threshold,
  1613. nms_top_k=nms_topk,
  1614. keep_top_k=nms_keep_topk,
  1615. nms_threshold=nms_iou_threshold)
  1616. post_process = ppdet.modeling.BBoxPostProcess(
  1617. decode=ppdet.modeling.YOLOBox(
  1618. num_classes=num_classes,
  1619. conf_thresh=.01,
  1620. downsample_ratio=32,
  1621. clip_bbox=True,
  1622. scale_x_y=scale_x_y),
  1623. nms=nms)
  1624. params.update({
  1625. 'backbone': backbone,
  1626. 'neck': neck,
  1627. 'yolo_head': yolo_head,
  1628. 'post_process': post_process
  1629. })
  1630. super(YOLOv3, self).__init__(
  1631. model_name='YOLOv3', num_classes=num_classes, **params)
  1632. self.anchors = anchors
  1633. self.anchor_masks = anchor_masks
  1634. self.model_name = 'PPYOLOv2'
  1635. def _get_test_inputs(self, image_shape):
  1636. if image_shape is not None:
  1637. image_shape = self._check_image_shape(image_shape)
  1638. self._fix_transforms_shape(image_shape[-2:])
  1639. else:
  1640. image_shape = [None, 3, 640, 640]
  1641. if getattr(self, 'test_transforms', None):
  1642. for idx, op in enumerate(self.test_transforms.transforms):
  1643. name = op.__class__.__name__
  1644. if name == 'Resize':
  1645. image_shape = [None, 3] + list(
  1646. self.test_transforms.transforms[idx].target_size)
  1647. logging.warning(
  1648. '[Important!!!] When exporting inference model for {},'.format(
  1649. self.__class__.__name__) +
  1650. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1651. format(image_shape) +
  1652. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1653. format(image_shape[1:]) + 'should be specified manually.')
  1654. self.fixed_input_shape = image_shape
  1655. return self._define_input_spec(image_shape)
  1656. class MaskRCNN(BaseDetector):
  1657. def __init__(self,
  1658. num_classes=80,
  1659. backbone='ResNet50_vd',
  1660. with_fpn=True,
  1661. with_dcn=False,
  1662. aspect_ratios=[0.5, 1.0, 2.0],
  1663. anchor_sizes=[[32], [64], [128], [256], [512]],
  1664. keep_top_k=100,
  1665. nms_threshold=0.5,
  1666. score_threshold=0.05,
  1667. fpn_num_channels=256,
  1668. rpn_batch_size_per_im=256,
  1669. rpn_fg_fraction=0.5,
  1670. test_pre_nms_top_n=None,
  1671. test_post_nms_top_n=1000,
  1672. **params):
  1673. self.init_params = locals()
  1674. if backbone not in {
  1675. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1676. 'ResNet101_vd'
  1677. }:
  1678. raise ValueError(
  1679. "backbone: {} is not supported. Please choose one of "
  1680. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1681. format(backbone))
  1682. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1683. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  1684. if params.get('with_net', True):
  1685. if backbone == 'ResNet50':
  1686. if with_fpn:
  1687. backbone = self._get_backbone(
  1688. 'ResNet',
  1689. norm_type='bn',
  1690. freeze_at=0,
  1691. return_idx=[0, 1, 2, 3],
  1692. num_stages=4,
  1693. dcn_v2_stages=dcn_v2_stages)
  1694. else:
  1695. if with_dcn:
  1696. logging.warning(
  1697. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1698. format(backbone))
  1699. backbone = self._get_backbone(
  1700. 'ResNet',
  1701. norm_type='bn',
  1702. freeze_at=0,
  1703. return_idx=[2],
  1704. num_stages=3)
  1705. elif 'ResNet50_vd' in backbone:
  1706. if not with_fpn:
  1707. logging.warning(
  1708. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1709. format(backbone))
  1710. with_fpn = True
  1711. backbone = self._get_backbone(
  1712. 'ResNet',
  1713. variant='d',
  1714. norm_type='bn',
  1715. freeze_at=0,
  1716. return_idx=[0, 1, 2, 3],
  1717. num_stages=4,
  1718. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1719. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
  1720. dcn_v2_stages=dcn_v2_stages)
  1721. else:
  1722. if not with_fpn:
  1723. logging.warning(
  1724. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1725. format(backbone))
  1726. with_fpn = True
  1727. backbone = self._get_backbone(
  1728. 'ResNet',
  1729. variant='d' if '_vd' in backbone else 'b',
  1730. depth=101,
  1731. norm_type='bn',
  1732. freeze_at=0,
  1733. return_idx=[0, 1, 2, 3],
  1734. num_stages=4,
  1735. dcn_v2_stages=dcn_v2_stages)
  1736. rpn_in_channel = backbone.out_shape[0].channels
  1737. if with_fpn:
  1738. neck = ppdet.modeling.FPN(
  1739. in_channels=[i.channels for i in backbone.out_shape],
  1740. out_channel=fpn_num_channels,
  1741. spatial_scales=[
  1742. 1.0 / i.stride for i in backbone.out_shape
  1743. ])
  1744. rpn_in_channel = neck.out_shape[0].channels
  1745. anchor_generator_cfg = {
  1746. 'aspect_ratios': aspect_ratios,
  1747. 'anchor_sizes': anchor_sizes,
  1748. 'strides': [4, 8, 16, 32, 64]
  1749. }
  1750. train_proposal_cfg = {
  1751. 'min_size': 0.0,
  1752. 'nms_thresh': .7,
  1753. 'pre_nms_top_n': 2000,
  1754. 'post_nms_top_n': 1000,
  1755. 'topk_after_collect': True
  1756. }
  1757. test_proposal_cfg = {
  1758. 'min_size': 0.0,
  1759. 'nms_thresh': .7,
  1760. 'pre_nms_top_n': 1000
  1761. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1762. 'post_nms_top_n': test_post_nms_top_n
  1763. }
  1764. bb_head = ppdet.modeling.TwoFCHead(
  1765. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1766. bb_roi_extractor_cfg = {
  1767. 'resolution': 7,
  1768. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1769. 'sampling_ratio': 0,
  1770. 'aligned': True
  1771. }
  1772. with_pool = False
  1773. m_head = ppdet.modeling.MaskFeat(
  1774. in_channel=neck.out_shape[0].channels,
  1775. out_channel=256,
  1776. num_convs=4)
  1777. m_roi_extractor_cfg = {
  1778. 'resolution': 14,
  1779. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1780. 'sampling_ratio': 0,
  1781. 'aligned': True
  1782. }
  1783. mask_assigner = MaskAssigner(
  1784. num_classes=num_classes, mask_resolution=28)
  1785. share_bbox_feat = False
  1786. else:
  1787. neck = None
  1788. anchor_generator_cfg = {
  1789. 'aspect_ratios': aspect_ratios,
  1790. 'anchor_sizes': anchor_sizes,
  1791. 'strides': [16]
  1792. }
  1793. train_proposal_cfg = {
  1794. 'min_size': 0.0,
  1795. 'nms_thresh': .7,
  1796. 'pre_nms_top_n': 12000,
  1797. 'post_nms_top_n': 2000,
  1798. 'topk_after_collect': False
  1799. }
  1800. test_proposal_cfg = {
  1801. 'min_size': 0.0,
  1802. 'nms_thresh': .7,
  1803. 'pre_nms_top_n': 6000
  1804. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1805. 'post_nms_top_n': test_post_nms_top_n
  1806. }
  1807. bb_head = ppdet.modeling.Res5Head()
  1808. bb_roi_extractor_cfg = {
  1809. 'resolution': 14,
  1810. 'spatial_scale':
  1811. [1. / i.stride for i in backbone.out_shape],
  1812. 'sampling_ratio': 0,
  1813. 'aligned': True
  1814. }
  1815. with_pool = True
  1816. m_head = ppdet.modeling.MaskFeat(
  1817. in_channel=bb_head.out_shape[0].channels,
  1818. out_channel=256,
  1819. num_convs=0)
  1820. m_roi_extractor_cfg = {
  1821. 'resolution': 14,
  1822. 'spatial_scale':
  1823. [1. / i.stride for i in backbone.out_shape],
  1824. 'sampling_ratio': 0,
  1825. 'aligned': True
  1826. }
  1827. mask_assigner = MaskAssigner(
  1828. num_classes=num_classes, mask_resolution=14)
  1829. share_bbox_feat = True
  1830. rpn_target_assign_cfg = {
  1831. 'batch_size_per_im': rpn_batch_size_per_im,
  1832. 'fg_fraction': rpn_fg_fraction,
  1833. 'negative_overlap': .3,
  1834. 'positive_overlap': .7,
  1835. 'use_random': True
  1836. }
  1837. rpn_head = ppdet.modeling.RPNHead(
  1838. anchor_generator=anchor_generator_cfg,
  1839. rpn_target_assign=rpn_target_assign_cfg,
  1840. train_proposal=train_proposal_cfg,
  1841. test_proposal=test_proposal_cfg,
  1842. in_channel=rpn_in_channel)
  1843. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1844. bbox_head = ppdet.modeling.BBoxHead(
  1845. head=bb_head,
  1846. in_channel=bb_head.out_shape[0].channels,
  1847. roi_extractor=bb_roi_extractor_cfg,
  1848. with_pool=with_pool,
  1849. bbox_assigner=bbox_assigner,
  1850. num_classes=num_classes)
  1851. mask_head = ppdet.modeling.MaskHead(
  1852. head=m_head,
  1853. roi_extractor=m_roi_extractor_cfg,
  1854. mask_assigner=mask_assigner,
  1855. share_bbox_feat=share_bbox_feat,
  1856. num_classes=num_classes)
  1857. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1858. num_classes=num_classes,
  1859. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1860. nms=ppdet.modeling.MultiClassNMS(
  1861. score_threshold=score_threshold,
  1862. keep_top_k=keep_top_k,
  1863. nms_threshold=nms_threshold))
  1864. mask_post_process = ppdet.modeling.MaskPostProcess(
  1865. binary_thresh=.5)
  1866. params.update({
  1867. 'backbone': backbone,
  1868. 'neck': neck,
  1869. 'rpn_head': rpn_head,
  1870. 'bbox_head': bbox_head,
  1871. 'mask_head': mask_head,
  1872. 'bbox_post_process': bbox_post_process,
  1873. 'mask_post_process': mask_post_process
  1874. })
  1875. self.with_fpn = with_fpn
  1876. super(MaskRCNN, self).__init__(
  1877. model_name='MaskRCNN', num_classes=num_classes, **params)
  1878. def train(self,
  1879. num_epochs,
  1880. train_dataset,
  1881. train_batch_size=64,
  1882. eval_dataset=None,
  1883. optimizer=None,
  1884. save_interval_epochs=1,
  1885. log_interval_steps=10,
  1886. save_dir='output',
  1887. pretrain_weights='IMAGENET',
  1888. learning_rate=.001,
  1889. warmup_steps=0,
  1890. warmup_start_lr=0.0,
  1891. lr_decay_epochs=(216, 243),
  1892. lr_decay_gamma=0.1,
  1893. metric=None,
  1894. use_ema=False,
  1895. early_stop=False,
  1896. early_stop_patience=5,
  1897. use_vdl=True,
  1898. resume_checkpoint=None):
  1899. """
  1900. Train the model.
  1901. Args:
  1902. num_epochs(int): The number of epochs.
  1903. train_dataset(paddlex.dataset): Training dataset.
  1904. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  1905. eval_dataset(paddlex.dataset, optional):
  1906. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  1907. optimizer(paddle.optimizer.Optimizer or None, optional):
  1908. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  1909. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  1910. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  1911. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  1912. pretrain_weights(str or None, optional):
  1913. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  1914. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  1915. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  1916. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  1917. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  1918. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  1919. metric({'VOC', 'COCO', None}, optional):
  1920. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  1921. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  1922. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  1923. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  1924. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  1925. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  1926. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  1927. `pretrain_weights` can be set simultaneously. Defaults to None.
  1928. """
  1929. if train_dataset.pos_num < len(train_dataset.file_list):
  1930. train_dataset.num_workers = 0
  1931. if train_batch_size != 1:
  1932. train_batch_size = 1
  1933. logging.warning(
  1934. "Training RCNN models with negative samples only support batch size equals to 1 "
  1935. "on a single gpu/cpu card, `train_batch_size` is forcibly set to 1."
  1936. )
  1937. nranks = paddle.distributed.get_world_size()
  1938. local_rank = paddle.distributed.get_rank()
  1939. # single card training
  1940. if nranks < 2 or local_rank == 0:
  1941. super(MaskRCNN, self).train(
  1942. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1943. optimizer, save_interval_epochs, log_interval_steps,
  1944. save_dir, pretrain_weights, learning_rate, warmup_steps,
  1945. warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric,
  1946. use_ema, early_stop, early_stop_patience, use_vdl,
  1947. resume_checkpoint)
  1948. else:
  1949. super(MaskRCNN, self).train(
  1950. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1951. optimizer, save_interval_epochs, log_interval_steps, save_dir,
  1952. pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
  1953. lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
  1954. early_stop_patience, use_vdl, resume_checkpoint)
  1955. def _compose_batch_transform(self, transforms, mode='train'):
  1956. if mode == 'train':
  1957. default_batch_transforms = [
  1958. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1959. ]
  1960. collate_batch = False
  1961. else:
  1962. default_batch_transforms = [
  1963. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1964. ]
  1965. collate_batch = True
  1966. custom_batch_transforms = []
  1967. for i, op in enumerate(transforms.transforms):
  1968. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1969. if mode != 'train':
  1970. raise Exception(
  1971. "{} cannot be present in the {} transforms. ".format(
  1972. op.__class__.__name__, mode) +
  1973. "Please check the {} transforms.".format(mode))
  1974. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1975. batch_transforms = BatchCompose(
  1976. custom_batch_transforms + default_batch_transforms,
  1977. collate_batch=collate_batch)
  1978. return batch_transforms
  1979. def _fix_transforms_shape(self, image_shape):
  1980. if getattr(self, 'test_transforms', None):
  1981. has_resize_op = False
  1982. resize_op_idx = -1
  1983. normalize_op_idx = len(self.test_transforms.transforms)
  1984. for idx, op in enumerate(self.test_transforms.transforms):
  1985. name = op.__class__.__name__
  1986. if name == 'ResizeByShort':
  1987. has_resize_op = True
  1988. resize_op_idx = idx
  1989. if name == 'Normalize':
  1990. normalize_op_idx = idx
  1991. if not has_resize_op:
  1992. self.test_transforms.transforms.insert(
  1993. normalize_op_idx,
  1994. Resize(
  1995. target_size=image_shape,
  1996. keep_ratio=True,
  1997. interp='CUBIC'))
  1998. else:
  1999. self.test_transforms.transforms[resize_op_idx] = Resize(
  2000. target_size=image_shape, keep_ratio=True, interp='CUBIC')
  2001. self.test_transforms.transforms.append(
  2002. Padding(im_padding_value=[0., 0., 0.]))
  2003. def _get_test_inputs(self, image_shape):
  2004. if image_shape is not None:
  2005. image_shape = self._check_image_shape(image_shape)
  2006. self._fix_transforms_shape(image_shape[-2:])
  2007. else:
  2008. image_shape = [None, 3, -1, -1]
  2009. if self.with_fpn:
  2010. self.test_transforms.transforms.append(
  2011. Padding(im_padding_value=[0., 0., 0.]))
  2012. self.fixed_input_shape = image_shape
  2013. return self._define_input_spec(image_shape)