# yolo_v3.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay


  18. class YOLOv3:
  19. def __init__(self,
  20. backbone,
  21. num_classes,
  22. mode='train',
  23. anchors=None,
  24. anchor_masks=None,
  25. ignore_threshold=0.7,
  26. label_smooth=False,
  27. nms_score_threshold=0.01,
  28. nms_topk=1000,
  29. nms_keep_topk=100,
  30. nms_iou_threshold=0.45,
  31. train_random_shapes=[
  32. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  33. ]):
  34. if anchors is None:
  35. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  36. [59, 119], [116, 90], [156, 198], [373, 326]]
  37. if anchor_masks is None:
  38. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  39. self.anchors = anchors
  40. self.anchor_masks = anchor_masks
  41. self._parse_anchors(anchors)
  42. self.mode = mode
  43. self.num_classes = num_classes
  44. self.backbone = backbone
  45. self.ignore_thresh = ignore_threshold
  46. self.label_smooth = label_smooth
  47. self.nms_score_threshold = nms_score_threshold
  48. self.nms_topk = nms_topk
  49. self.nms_keep_topk = nms_keep_topk
  50. self.nms_iou_threshold = nms_iou_threshold
  51. self.norm_decay = 0.0
  52. self.prefix_name = ''
  53. self.train_random_shapes = train_random_shapes
  54. def _head(self, feats):
  55. outputs = []
  56. out_layer_num = len(self.anchor_masks)
  57. blocks = feats[-1:-out_layer_num - 1:-1]
  58. route = None
  59. for i, block in enumerate(blocks):
  60. if i > 0:
  61. block = fluid.layers.concat(input=[route, block], axis=1)
  62. route, tip = self._detection_block(
  63. block,
  64. channel=512 // (2**i),
  65. name=self.prefix_name + 'yolo_block.{}'.format(i))
  66. num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
  67. block_out = fluid.layers.conv2d(
  68. input=tip,
  69. num_filters=num_filters,
  70. filter_size=1,
  71. stride=1,
  72. padding=0,
  73. act=None,
  74. param_attr=ParamAttr(name=self.prefix_name +
  75. 'yolo_output.{}.conv.weights'.format(i)),
  76. bias_attr=ParamAttr(
  77. regularizer=L2Decay(0.0),
  78. name=self.prefix_name +
  79. 'yolo_output.{}.conv.bias'.format(i)))
  80. outputs.append(block_out)
  81. if i < len(blocks) - 1:
  82. route = self._conv_bn(
  83. input=route,
  84. ch_out=256 // (2**i),
  85. filter_size=1,
  86. stride=1,
  87. padding=0,
  88. name=self.prefix_name + 'yolo_transition.{}'.format(i))
  89. route = self._upsample(route)
  90. return outputs
  91. def _parse_anchors(self, anchors):
  92. self.anchors = []
  93. self.mask_anchors = []
  94. assert len(anchors) > 0, "ANCHORS not set."
  95. assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
  96. for anchor in anchors:
  97. assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
  98. self.anchors.extend(anchor)
  99. anchor_num = len(anchors)
  100. for masks in self.anchor_masks:
  101. self.mask_anchors.append([])
  102. for mask in masks:
  103. assert mask < anchor_num, "anchor mask index overflow"
  104. self.mask_anchors[-1].extend(anchors[mask])
  105. def _conv_bn(self,
  106. input,
  107. ch_out,
  108. filter_size,
  109. stride,
  110. padding,
  111. act='leaky',
  112. is_test=False,
  113. name=None):
  114. conv = fluid.layers.conv2d(
  115. input=input,
  116. num_filters=ch_out,
  117. filter_size=filter_size,
  118. stride=stride,
  119. padding=padding,
  120. act=None,
  121. param_attr=ParamAttr(name=name + '.conv.weights'),
  122. bias_attr=False)
  123. bn_name = name + '.bn'
  124. bn_param_attr = ParamAttr(
  125. regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
  126. bn_bias_attr = ParamAttr(
  127. regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
  128. out = fluid.layers.batch_norm(
  129. input=conv,
  130. act=None,
  131. is_test=is_test,
  132. param_attr=bn_param_attr,
  133. bias_attr=bn_bias_attr,
  134. moving_mean_name=bn_name + '.mean',
  135. moving_variance_name=bn_name + '.var')
  136. if act == 'leaky':
  137. out = fluid.layers.leaky_relu(x=out, alpha=0.1)
  138. return out
  139. def _upsample(self, input, scale=2, name=None):
  140. out = fluid.layers.resize_nearest(
  141. input=input, scale=float(scale), name=name)
  142. return out
  143. def _detection_block(self, input, channel, name=None):
  144. assert channel % 2 == 0, "channel({}) cannot be divided by 2 in detection block({})".format(
  145. channel, name)
  146. is_test = False if self.mode == 'train' else True
  147. conv = input
  148. for i in range(2):
  149. conv = self._conv_bn(
  150. conv,
  151. channel,
  152. filter_size=1,
  153. stride=1,
  154. padding=0,
  155. is_test=is_test,
  156. name='{}.{}.0'.format(name, i))
  157. conv = self._conv_bn(
  158. conv,
  159. channel * 2,
  160. filter_size=3,
  161. stride=1,
  162. padding=1,
  163. is_test=is_test,
  164. name='{}.{}.1'.format(name, i))
  165. route = self._conv_bn(
  166. conv,
  167. channel,
  168. filter_size=1,
  169. stride=1,
  170. padding=0,
  171. is_test=is_test,
  172. name='{}.2'.format(name))
  173. tip = self._conv_bn(
  174. route,
  175. channel * 2,
  176. filter_size=3,
  177. stride=1,
  178. padding=1,
  179. is_test=is_test,
  180. name='{}.tip'.format(name))
  181. return route, tip
  182. def _get_loss(self, inputs, gt_box, gt_label, gt_score):
  183. losses = []
  184. downsample = 32
  185. for i, input in enumerate(inputs):
  186. loss = fluid.layers.yolov3_loss(
  187. x=input,
  188. gt_box=gt_box,
  189. gt_label=gt_label,
  190. gt_score=gt_score,
  191. anchors=self.anchors,
  192. anchor_mask=self.anchor_masks[i],
  193. class_num=self.num_classes,
  194. ignore_thresh=self.ignore_thresh,
  195. downsample_ratio=downsample,
  196. use_label_smooth=self.label_smooth,
  197. name=self.prefix_name + 'yolo_loss' + str(i))
  198. losses.append(fluid.layers.reduce_mean(loss))
  199. downsample //= 2
  200. return sum(losses)
  201. def _get_prediction(self, inputs, im_size):
  202. boxes = []
  203. scores = []
  204. downsample = 32
  205. for i, input in enumerate(inputs):
  206. box, score = fluid.layers.yolo_box(
  207. x=input,
  208. img_size=im_size,
  209. anchors=self.mask_anchors[i],
  210. class_num=self.num_classes,
  211. conf_thresh=self.nms_score_threshold,
  212. downsample_ratio=downsample,
  213. name=self.prefix_name + 'yolo_box' + str(i))
  214. boxes.append(box)
  215. scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
  216. downsample //= 2
  217. yolo_boxes = fluid.layers.concat(boxes, axis=1)
  218. yolo_scores = fluid.layers.concat(scores, axis=2)
  219. pred = fluid.layers.multiclass_nms(
  220. bboxes=yolo_boxes,
  221. scores=yolo_scores,
  222. score_threshold=self.nms_score_threshold,
  223. nms_top_k=self.nms_topk,
  224. keep_top_k=self.nms_keep_topk,
  225. nms_threshold=self.nms_iou_threshold,
  226. normalized=False,
  227. nms_eta=1.0,
  228. background_label=-1)
  229. return pred
  230. def generate_inputs(self):
  231. inputs = OrderedDict()
  232. inputs['image'] = fluid.data(
  233. dtype='float32', shape=[None, 3, None, None], name='image')
  234. if self.mode == 'train':
  235. inputs['gt_box'] = fluid.data(
  236. dtype='float32', shape=[None, None, 4], name='gt_box')
  237. inputs['gt_label'] = fluid.data(
  238. dtype='int32', shape=[None, None], name='gt_label')
  239. inputs['gt_score'] = fluid.data(
  240. dtype='float32', shape=[None, None], name='gt_score')
  241. inputs['im_size'] = fluid.data(
  242. dtype='int32', shape=[None, 2], name='im_size')
  243. elif self.mode == 'eval':
  244. inputs['im_size'] = fluid.data(
  245. dtype='int32', shape=[None, 2], name='im_size')
  246. inputs['im_id'] = fluid.data(
  247. dtype='int32', shape=[None, 1], name='im_id')
  248. inputs['gt_box'] = fluid.data(
  249. dtype='float32', shape=[None, None, 4], name='gt_box')
  250. inputs['gt_label'] = fluid.data(
  251. dtype='int32', shape=[None, None], name='gt_label')
  252. inputs['is_difficult'] = fluid.data(
  253. dtype='int32', shape=[None, None], name='is_difficult')
  254. elif self.mode == 'test':
  255. inputs['im_size'] = fluid.data(
  256. dtype='int32', shape=[None, 2], name='im_size')
  257. return inputs
  258. def build_net(self, inputs):
  259. image = inputs['image']
  260. if self.mode == 'train':
  261. if isinstance(self.train_random_shapes,
  262. (list, tuple)) and len(self.train_random_shapes) > 0:
  263. import numpy as np
  264. shapes = np.array(self.train_random_shapes)
  265. shapes = np.stack([shapes, shapes], axis=1).astype('float32')
  266. shapes_tensor = fluid.layers.assign(shapes)
  267. index = fluid.layers.uniform_random(
  268. shape=[1], dtype='float32', min=0.0, max=1)
  269. index = fluid.layers.cast(
  270. index * len(self.train_random_shapes), dtype='int32')
  271. shape = fluid.layers.gather(shapes_tensor, index)
  272. shape = fluid.layers.reshape(shape, [-1])
  273. shape = fluid.layers.cast(shape, dtype='int32')
  274. image = fluid.layers.resize_nearest(
  275. image, out_shape=shape, align_corners=False)
  276. feats = self.backbone(image)
  277. if isinstance(feats, OrderedDict):
  278. feat_names = list(feats.keys())
  279. feats = [feats[name] for name in feat_names]
  280. head_outputs = self._head(feats)
  281. if self.mode == 'train':
  282. gt_box = inputs['gt_box']
  283. gt_label = inputs['gt_label']
  284. gt_score = inputs['gt_score']
  285. im_size = inputs['im_size']
  286. num_boxes = fluid.layers.shape(gt_box)[1]
  287. im_size_wh = fluid.layers.reverse(im_size, axis=1)
  288. whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
  289. whwh = fluid.layers.unsqueeze(whwh, axes=[1])
  290. whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
  291. whwh = fluid.layers.cast(whwh, dtype='float32')
  292. whwh.stop_gradient = True
  293. normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
  294. return self._get_loss(head_outputs, normalized_box, gt_label,
  295. gt_score)
  296. else:
  297. im_size = inputs['im_size']
  298. return self._get_prediction(head_outputs, im_size)