prune.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import yaml
  16. import time
  17. import pickle
  18. import os
  19. import os.path as osp
  20. from functools import reduce
  21. import paddle.fluid as fluid
  22. from multiprocessing import Process, Queue
  23. import paddleslim
  24. from paddleslim.prune import Pruner, load_sensitivities
  25. from paddleslim.core import GraphWrapper
  26. from .prune_config import get_prune_params
  27. import paddlex.utils.logging as logging
  28. from paddlex.utils import seconds_to_hms
  29. def sensitivity(program,
  30. place,
  31. param_names,
  32. eval_func,
  33. sensitivities_file=None,
  34. pruned_ratios=None):
  35. scope = fluid.global_scope()
  36. graph = GraphWrapper(program)
  37. sensitivities = load_sensitivities(sensitivities_file)
  38. if pruned_ratios is None:
  39. pruned_ratios = np.arange(0.1, 1, step=0.1)
  40. total_evaluate_iters = 0
  41. for name in param_names:
  42. if name not in sensitivities:
  43. sensitivities[name] = {}
  44. total_evaluate_iters += len(list(pruned_ratios))
  45. else:
  46. total_evaluate_iters += (
  47. len(list(pruned_ratios)) - len(sensitivities[name]))
  48. eta = '-'
  49. start_time = time.time()
  50. baseline = eval_func(graph.program)
  51. cost = time.time() - start_time
  52. eta = cost * (total_evaluate_iters - 1)
  53. current_iter = 1
  54. for name in sensitivities:
  55. for ratio in pruned_ratios:
  56. if ratio in sensitivities[name]:
  57. logging.debug('{}, {} has computed.'.format(name, ratio))
  58. continue
  59. progress = float(current_iter) / total_evaluate_iters
  60. progress = "%.2f%%" % (progress * 100)
  61. logging.info(
  62. "Total evaluate iters={}, current={}, progress={}, eta={}".
  63. format(total_evaluate_iters, current_iter, progress,
  64. seconds_to_hms(
  65. int(cost * (total_evaluate_iters - current_iter)))),
  66. use_color=True)
  67. current_iter += 1
  68. pruner = Pruner()
  69. logging.info("sensitive - param: {}; ratios: {}".format(name,
  70. ratio))
  71. pruned_program, param_backup, _ = pruner.prune(
  72. program=graph.program,
  73. scope=scope,
  74. params=[name],
  75. ratios=[ratio],
  76. place=place,
  77. lazy=True,
  78. only_graph=False,
  79. param_backup=True)
  80. pruned_metric = eval_func(pruned_program)
  81. loss = (baseline - pruned_metric) / baseline
  82. logging.info("pruned param: {}; {}; loss={}".format(name, ratio,
  83. loss))
  84. sensitivities[name][ratio] = loss
  85. with open(sensitivities_file, 'wb') as f:
  86. pickle.dump(sensitivities, f)
  87. for param_name in param_backup.keys():
  88. param_t = scope.find_var(param_name).get_tensor()
  89. param_t.set(param_backup[param_name], place)
  90. return sensitivities
  91. def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
  92. """通道裁剪。
  93. Args:
  94. program (paddle.fluid.Program): 需要裁剪的Program,Program的具体介绍可参见
  95. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  96. prune_names (list): 由裁剪参数名组成的参数列表。
  97. prune_ratios (list): 由裁剪率组成的参数列表,与prune_names中的参数列表意义对应。
  98. place (paddle.fluid.CUDAPlace/paddle.fluid.CPUPlace): 运行设备。
  99. only_graph (bool): 是否只修改网络图,当为False时代表同时修改网络图和
  100. scope(全局作用域)中的参数。默认为False。
  101. Returns:
  102. paddle.fluid.Program: 裁剪后的Program。
  103. """
  104. prog_var_shape_dict = {}
  105. for var in program.list_vars():
  106. try:
  107. prog_var_shape_dict[var.name] = var.shape
  108. except Exception:
  109. pass
  110. index = 0
  111. for param, ratio in zip(prune_names, prune_ratios):
  112. origin_num = prog_var_shape_dict[param][0]
  113. pruned_num = int(round(origin_num * ratio))
  114. while origin_num == pruned_num:
  115. ratio -= 0.1
  116. pruned_num = int(round(origin_num * (ratio)))
  117. prune_ratios[index] = ratio
  118. index += 1
  119. scope = fluid.global_scope()
  120. pruner = Pruner()
  121. program, _, _ = pruner.prune(
  122. program,
  123. scope,
  124. params=prune_names,
  125. ratios=prune_ratios,
  126. place=place,
  127. lazy=False,
  128. only_graph=only_graph,
  129. param_backup=False,
  130. param_shape_backup=False)
  131. return program
  132. def prune_program(model, prune_params_ratios=None):
  133. """根据裁剪参数和裁剪率裁剪Program。
  134. 1. 裁剪训练Program和测试Program。
  135. 2. 使用裁剪后的Program更新模型中的train_prog和test_prog。
  136. 【注意】Program的具体介绍可参见
  137. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  138. Args:
  139. model (paddlex.cv.models): paddlex中的模型。
  140. prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
  141. 使用默认裁剪参数名和裁剪率。默认为None。
  142. """
  143. place = model.places[0]
  144. train_prog = model.train_prog
  145. eval_prog = model.test_prog
  146. valid_prune_names = get_prune_params(model)
  147. assert set(list(prune_params_ratios.keys())) & set(valid_prune_names), \
  148. "All params in 'prune_params_ratios' can't be pruned!"
  149. prune_names = list(
  150. set(list(prune_params_ratios.keys())) & set(valid_prune_names))
  151. prune_ratios = [
  152. prune_params_ratios[prune_name] for prune_name in prune_names
  153. ]
  154. model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
  155. place)
  156. model.test_prog = channel_prune(
  157. eval_prog, prune_names, prune_ratios, place, only_graph=True)
  158. def update_program(program, model_dir, place):
  159. """根据裁剪信息更新Program和参数。
  160. Args:
  161. program (paddle.fluid.Program): 需要更新的Program,Program的具体介绍可参见
  162. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  163. model_dir (str): 模型存储路径。
  164. place (paddle.fluid.CUDAPlace/paddle.fluid.CPUPlace): 运行设备。
  165. Returns:
  166. paddle.fluid.Program: 更新后的Program。
  167. """
  168. graph = GraphWrapper(program)
  169. with open(osp.join(model_dir, "prune.yml")) as f:
  170. shapes = yaml.load(f.read(), Loader=yaml.Loader)
  171. for param, shape in shapes.items():
  172. graph.var(param).set_shape(shape)
  173. for block in program.blocks:
  174. for param in block.all_parameters():
  175. if param.name in shapes:
  176. param_tensor = fluid.global_scope().find_var(
  177. param.name).get_tensor()
  178. param_tensor.set(
  179. np.zeros(list(shapes[param.name])).astype('float32'),
  180. place)
  181. graph.update_groups_of_conv()
  182. graph.infer_shape()
  183. return program
  184. def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
  185. """计算模型中可裁剪卷积Kernel的敏感度。
  186. 1. 获取模型中可裁剪卷积Kernel的名称。
  187. 2. 计算每个可裁剪卷积Kernel不同裁剪率下的敏感度。
  188. 【注意】卷积的敏感度是指在不同裁剪率下评估数据集预测精度的损失,
  189. 通过得到的敏感度,可以决定最终模型需要裁剪的参数列表和各裁剪参数对应的裁剪率。
  190. Args:
  191. model (paddlex.cv.models): paddlex中的模型。
  192. save_file (str): 计算的得到的sensetives文件存储路径。
  193. eval_dataset (paddlex.datasets): 验证数据读取器。
  194. batch_size (int): 验证数据批大小。默认为8。
  195. Returns:
  196. dict: 由参数名和不同裁剪率下敏感度组成的字典。存储的信息如下:
  197. .. code-block:: python
  198. {"weight_0":
  199. {0.1: 0.22,
  200. 0.2: 0.33
  201. },
  202. "weight_1":
  203. {0.1: 0.21,
  204. 0.2: 0.4
  205. }
  206. }
  207. 其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
  208. """
  209. if os.path.exists(save_file):
  210. os.remove(save_file)
  211. prune_names = get_prune_params(model)
  212. def eval_for_prune(program):
  213. eval_metrics = model.evaluate(
  214. eval_dataset=eval_dataset,
  215. batch_size=batch_size,
  216. return_details=False)
  217. primary_key = list(eval_metrics.keys())[0]
  218. return eval_metrics[primary_key]
  219. sensitivitives = sensitivity(
  220. model.test_prog,
  221. model.places[0],
  222. prune_names,
  223. eval_for_prune,
  224. sensitivities_file=save_file,
  225. pruned_ratios=list(np.arange(0.1, 1, 0.1)))
  226. return sensitivitives
  227. def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
  228. """根据设定的精度损失容忍度metric_loss_thresh和计算保存的模型参数敏感度信息文件sensetive_file,
  229. 获取裁剪的参数配置。
  230. 【注意】metric_loss_thresh并不确保最终裁剪后的模型在fine-tune后的模型效果,仅为预估值。
  231. Args:
  232. sensitivities_file (str): 敏感度文件存储路径。
  233. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  234. Returns:
  235. dict: 由参数名和裁剪率组成的字典。存储的信息如下:
  236. .. code-block:: python
  237. {"weight_0": 0.1,
  238. "weight_1": 0.2
  239. }
  240. 其中key是卷积Kernel名;value是裁剪率。
  241. """
  242. if not osp.exists(sensitivities_file):
  243. raise Exception('The sensitivities file is not exists!')
  244. sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
  245. <<<<<<< HEAD
  246. params_ratios = paddleslim.prune.get_ratios_by_loss(
  247. sensitivitives, eval_metric_loss)
  248. =======
  249. params_ratios = paddleslim.prune.get_ratios_by_loss(sensitivitives,
  250. eval_metric_loss)
  251. >>>>>>> 7df89bb4e3e8cca1c9c57bf5f316fc9eb873a149
  252. return params_ratios
  253. def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
  254. """在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
  255. Args:
  256. program (paddle.fluid.Program): 需要裁剪的Program,Program的具体介绍可参见
  257. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  258. place (paddle.fluid.CUDAPlace/paddle.fluid.CPUPlace): 运行设备。
  259. sensitivities_file (str): 敏感度文件存储路径。
  260. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  261. Returns:
  262. float: 裁剪后模型大小相对于当前模型大小的比例。
  263. """
  264. prune_params_ratios = get_params_ratios(sensitivities_file,
  265. eval_metric_loss)
  266. prog_var_shape_dict = {}
  267. for var in program.list_vars():
  268. try:
  269. prog_var_shape_dict[var.name] = var.shape
  270. except Exception:
  271. pass
  272. for param, ratio in prune_params_ratios.items():
  273. origin_num = prog_var_shape_dict[param][0]
  274. pruned_num = int(round(origin_num * ratio))
  275. while origin_num == pruned_num:
  276. ratio -= 0.1
  277. pruned_num = int(round(origin_num * (ratio)))
  278. prune_params_ratios[param] = ratio
  279. prune_program = channel_prune(
  280. program,
  281. list(prune_params_ratios.keys()),
  282. list(prune_params_ratios.values()),
  283. place,
  284. only_graph=True)
  285. origin_size = 0
  286. new_size = 0
  287. for var in program.list_vars():
  288. name = var.name
  289. shape = var.shape
  290. for prune_block in prune_program.blocks:
  291. if prune_block.has_var(name):
  292. prune_var = prune_block.var(name)
  293. prune_shape = prune_var.shape
  294. break
  295. origin_size += reduce(lambda x, y: x * y, shape)
  296. new_size += reduce(lambda x, y: x * y, prune_shape)
  297. return (new_size * 1.0) / origin_size