# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from collections import deque
import shutil

import paddle
import paddle.nn.functional as F

from paddlex.paddleseg.utils import TimeAverager, calculate_eta, resume, logger
from paddlex.paddleseg.core.val import evaluate


def check_logits_losses(logits_list, losses):
    len_logits = len(logits_list)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits_list should equal to the types of loss config: {} != {}.'
            .format(len_logits, len_losses))


def loss_computation(logits_list, labels, losses, edges=None):
    check_logits_losses(logits_list, losses)
    loss_list = []
    for i in range(len(logits_list)):
        logits = logits_list[i]
        loss_i = losses['types'][i]
        # Whether to use edges as labels depends on the loss type.
        if loss_i.__class__.__name__ in ('BCELoss', ) and loss_i.edge_label:
            loss_list.append(losses['coef'][i] * loss_i(logits, edges))
        else:
            loss_list.append(losses['coef'][i] * loss_i(logits, labels))
    return loss_list
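
# A hedged illustration (not part of the original file): ``losses`` is expected
# to be a dict with parallel 'types' and 'coef' lists, one entry per element of
# ``logits_list``. Assuming the vendored paddleseg losses are importable, a
# minimal single-head configuration might look like:
#
#     from paddlex.paddleseg.models.losses import CrossEntropyLoss
#
#     losses = {'types': [CrossEntropyLoss()], 'coef': [1.0]}
#
# A model with an auxiliary head would list two loss objects and two
# coefficients, matching the two logits it returns.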


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5):
    """
    Launch training.

    Args:
        model (nn.Layer): A semantic segmentation model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one GPU or CPU. Default: 2.
        resume_model (str, optional): The path of the model to resume training from.
        save_interval (int, optional): How many iters between saving a model snapshot during training. Default: 1000.
        log_iters (int, optional): Display logging information every log_iters. Default: 10.
        num_workers (int, optional): Number of workers for the data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict): A dict including 'types' and 'coef'. The length of coef should equal 1 or len(losses['types']).
            The 'types' item is a list of loss objects from paddleseg.models.losses, while the 'coef' item is a list of the corresponding coefficients.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
            ddp_model = paddle.DataParallel(model)
        else:
            ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    avg_loss = 0.0
    avg_loss_list = []
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
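    # Iteration-based training: keep cycling over the dataloader until ``iters``
    # batches have been processed in total.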
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)
            images = data[0]
            labels = data[1].astype('int64')
            edges = None
            if len(data) == 3:
                edges = data[2].astype('int64')

            if nranks > 1:
                logits_list = ddp_model(images)
            else:
                logits_list = model(images)
            loss_list = loss_computation(
                logits_list=logits_list,
                labels=labels,
                losses=losses,
                edges=edges)
            loss = sum(loss_list)
            loss.backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()
            avg_loss += loss.numpy()[0]
            if not avg_loss_list:
                avg_loss_list = [l.numpy() for l in loss_list]
            else:
                for i in range(len(loss_list)):
                    avg_loss_list[i] += loss_list[i].numpy()
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if iter % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_loss_list = [l[0] / log_iters for l in avg_loss_list]
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, iter)
                    # Record each loss separately if there is more than one loss.
                    if len(avg_loss_list) > 1:
                        avg_loss_dict = {}
                        for i, value in enumerate(avg_loss_list):
                            avg_loss_dict['loss_' + str(i)] = value
                        for key, value in avg_loss_dict.items():
                            log_tag = 'Train/' + key
                            log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)

                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()
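
            # Periodically run evaluation on the validation set, then switch
            # the model back to train mode.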
            if (iter % save_interval == 0 or
                    iter == iters) and (val_dataset is not None):
                num_workers = 1 if num_workers > 0 else 0
                mean_iou, acc, _, _, _ = evaluate(
                    model, val_dataset, num_workers=num_workers)
                model.train()
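
            # Periodically save a checkpoint, keep at most ``keep_checkpoint_max``
            # of them, and track the best model by validation mIoU.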
            if (iter % save_interval == 0 or
                    iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, iter)
            batch_start = time.time()

    # Calculate flops.
    if local_rank == 0:

        def count_syncbn(m, x, y):
            x = x[0]
            nelements = x.numel()
            m.total_ops += int(2 * nelements)

        _, c, h, w = images.shape
        flops = paddle.flops(
            model, [1, c, h, w],
            custom_ops={paddle.nn.SyncBatchNorm: count_syncbn})

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
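
# A hedged usage sketch (illustrative only, not part of the original module):
# assuming a segmentation model, datasets, an optimizer, and a ``losses`` dict
# have been built elsewhere, training could be launched roughly like this:
#
#     losses = {'types': [CrossEntropyLoss()], 'coef': [1.0]}
#     train(
#         model=model,
#         train_dataset=train_dataset,
#         val_dataset=val_dataset,
#         optimizer=optimizer,
#         losses=losses,
#         iters=10000,
#         batch_size=2,
#         save_interval=1000,
#         log_iters=10,
#         use_vdl=True)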