interpretation_algorithms.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. import time
  18. from . import lime_base
  19. from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, gen_user_home
  20. from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
  21. from paddlex.interpret.as_data_reader.readers import read_image
  22. import paddlex.utils.logging as logging
  23. import cv2
  24. class CAM(object):
  25. def __init__(self, predict_fn, label_names):
  26. """
  27. Args:
  28. predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
  29. output: [
  30. logits [N, num_classes],
  31. feature map before global average pooling [N, num_channels, h_, w_]
  32. ]
  33. """
  34. self.predict_fn = predict_fn
  35. self.label_names = label_names
  36. def preparation_cam(self, data_):
  37. image_show = read_image(data_)
  38. result = self.predict_fn(image_show)
  39. logit = result[0][0]
  40. if abs(np.sum(logit) - 1.0) > 1e-4:
  41. # softmax
  42. logit = logit - np.max(logit)
  43. exp_result = np.exp(logit)
  44. probability = exp_result / np.sum(exp_result)
  45. else:
  46. probability = logit
  47. # only interpret top 1
  48. pred_label = np.argsort(probability)
  49. pred_label = pred_label[-1:]
  50. self.predicted_label = pred_label[0]
  51. self.predicted_probability = probability[pred_label[0]]
  52. self.image = image_show[0]
  53. self.labels = pred_label
  54. fc_weights = paddle_get_fc_weights()
  55. feature_maps = result[1]
  56. l = pred_label[0]
  57. ln = l
  58. if self.label_names is not None:
  59. ln = self.label_names[l]
  60. prob_str = "%.3f" % (probability[pred_label[0]])
  61. logging.info("predicted result: {} with probability {}.".format(
  62. ln, prob_str))
  63. return feature_maps, fc_weights
  64. def interpret(self,
  65. data_,
  66. visualization=True,
  67. save_to_disk=True,
  68. save_outdir=None):
  69. feature_maps, fc_weights = self.preparation_cam(data_)
  70. cam = get_cam(self.image, feature_maps, fc_weights,
  71. self.predicted_label)
  72. if visualization or save_to_disk:
  73. import matplotlib.pyplot as plt
  74. from skimage.segmentation import mark_boundaries
  75. l = self.labels[0]
  76. ln = l
  77. if self.label_names is not None:
  78. ln = self.label_names[l]
  79. psize = 5
  80. nrows = 1
  81. ncols = 2
  82. plt.close()
  83. f, axes = plt.subplots(
  84. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  85. for ax in axes.ravel():
  86. ax.axis("off")
  87. axes = axes.ravel()
  88. axes[0].imshow(self.image)
  89. prob_str = "{%.3f}" % (self.predicted_probability)
  90. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  91. axes[1].imshow(cam)
  92. axes[1].set_title("CAM")
  93. if save_to_disk and save_outdir is not None:
  94. os.makedirs(save_outdir, exist_ok=True)
  95. save_fig(data_, save_outdir, 'cam')
  96. if visualization:
  97. plt.show()
  98. return
  99. class LIME(object):
  100. def __init__(self,
  101. predict_fn,
  102. label_names,
  103. num_samples=3000,
  104. batch_size=50):
  105. """
  106. LIME wrapper. See lime_base.py for the detailed LIME implementation.
  107. Args:
  108. predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
  109. num_samples: the number of samples that LIME takes for fitting.
  110. batch_size: batch size for model inference each time.
  111. """
  112. self.num_samples = num_samples
  113. self.batch_size = batch_size
  114. self.predict_fn = predict_fn
  115. self.labels = None
  116. self.image = None
  117. self.lime_interpreter = None
  118. self.label_names = label_names
  119. def preparation_lime(self, data_):
  120. image_show = read_image(data_)
  121. result = self.predict_fn(image_show)
  122. result = result[0] # only one image here.
  123. if abs(np.sum(result) - 1.0) > 1e-4:
  124. # softmax
  125. result = result - np.max(result)
  126. exp_result = np.exp(result)
  127. probability = exp_result / np.sum(exp_result)
  128. else:
  129. probability = result
  130. # only interpret top 1
  131. pred_label = np.argsort(probability)
  132. pred_label = pred_label[-1:]
  133. self.predicted_label = pred_label[0]
  134. self.predicted_probability = probability[pred_label[0]]
  135. self.image = image_show[0]
  136. self.labels = pred_label
  137. l = pred_label[0]
  138. ln = l
  139. if self.label_names is not None:
  140. ln = self.label_names[l]
  141. prob_str = "%.3f" % (probability[pred_label[0]])
  142. logging.info("predicted result: {} with probability {}.".format(
  143. ln, prob_str))
  144. end = time.time()
  145. algo = lime_base.LimeImageInterpreter()
  146. interpreter = algo.interpret_instance(
  147. self.image,
  148. self.predict_fn,
  149. self.labels,
  150. 0,
  151. num_samples=self.num_samples,
  152. batch_size=self.batch_size)
  153. self.lime_interpreter = interpreter
  154. logging.info('lime time: ' + str(time.time() - end) + 's.')
  155. def interpret(self,
  156. data_,
  157. visualization=True,
  158. save_to_disk=True,
  159. save_outdir=None):
  160. if self.lime_interpreter is None:
  161. self.preparation_lime(data_)
  162. if visualization or save_to_disk:
  163. import matplotlib.pyplot as plt
  164. from skimage.segmentation import mark_boundaries
  165. l = self.labels[0]
  166. ln = l
  167. if self.label_names is not None:
  168. ln = self.label_names[l]
  169. psize = 5
  170. nrows = 2
  171. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  172. ncols = len(weights_choices)
  173. plt.close()
  174. f, axes = plt.subplots(
  175. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  176. for ax in axes.ravel():
  177. ax.axis("off")
  178. axes = axes.ravel()
  179. axes[0].imshow(self.image)
  180. prob_str = "{%.3f}" % (self.predicted_probability)
  181. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  182. axes[1].imshow(
  183. mark_boundaries(self.image, self.lime_interpreter.segments))
  184. axes[1].set_title("superpixel segmentation")
  185. # LIME visualization
  186. for i, w in enumerate(weights_choices):
  187. num_to_show = auto_choose_num_features_to_show(
  188. self.lime_interpreter, l, w)
  189. temp, mask = self.lime_interpreter.get_image_and_mask(
  190. l,
  191. positive_only=False,
  192. hide_rest=False,
  193. num_features=num_to_show)
  194. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  195. axes[ncols + i].set_title(
  196. "label {}, first {} superpixels".format(ln, num_to_show))
  197. if save_to_disk and save_outdir is not None:
  198. os.makedirs(save_outdir, exist_ok=True)
  199. save_fig(data_, save_outdir, 'lime', self.num_samples)
  200. if visualization:
  201. plt.show()
  202. return
  203. class NormLIMEStandard(object):
  204. def __init__(self,
  205. predict_fn,
  206. label_names,
  207. num_samples=3000,
  208. batch_size=50,
  209. kmeans_model_for_normlime=None,
  210. normlime_weights=None):
  211. root_path = gen_user_home()
  212. root_path = osp.join(root_path, '.paddlex')
  213. h_pre_models = osp.join(root_path, "pre_models")
  214. if not osp.exists(h_pre_models):
  215. if not osp.exists(root_path):
  216. os.makedirs(root_path)
  217. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  218. pdx.utils.download_and_decompress(url, path=root_path)
  219. h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
  220. if kmeans_model_for_normlime is None:
  221. try:
  222. self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
  223. except:
  224. raise ValueError(
  225. "NormLIME needs the KMeans model, where we provided a default one in "
  226. "pre_models/kmeans_model.pkl.")
  227. else:
  228. logging.debug("Warning: It is *strongly* suggested to use the \
  229. default KMeans model in pre_models/kmeans_model.pkl. \
  230. Use another one will change the final result.")
  231. self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
  232. self.num_samples = num_samples
  233. self.batch_size = batch_size
  234. try:
  235. self.normlime_weights = np.load(
  236. normlime_weights, allow_pickle=True).item()
  237. except:
  238. self.normlime_weights = None
  239. logging.debug(
  240. "Warning: not find the correct precomputed Normlime result.")
  241. self.predict_fn = predict_fn
  242. self.labels = None
  243. self.image = None
  244. self.label_names = label_names
  245. def predict_cluster_labels(self, feature_map, segments):
  246. X = get_feature_for_kmeans(feature_map, segments)
  247. try:
  248. cluster_labels = self.kmeans_model.predict(X)
  249. except AttributeError:
  250. from sklearn.metrics import pairwise_distances_argmin_min
  251. cluster_labels, _ = pairwise_distances_argmin_min(
  252. X, self.kmeans_model.cluster_centers_)
  253. return cluster_labels
  254. def predict_using_normlime_weights(self, pred_labels,
  255. predicted_cluster_labels):
  256. # global weights
  257. g_weights = {y: [] for y in pred_labels}
  258. for y in pred_labels:
  259. cluster_weights_y = self.normlime_weights.get(y, {})
  260. g_weights[y] = [(i, cluster_weights_y.get(k, 0.0))
  261. for i, k in enumerate(predicted_cluster_labels)]
  262. g_weights[y] = sorted(
  263. g_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
  264. return g_weights
  265. def preparation_normlime(self, data_):
  266. self._lime = LIME(self.predict_fn, self.label_names, self.num_samples,
  267. self.batch_size)
  268. self._lime.preparation_lime(data_)
  269. image_show = read_image(data_)
  270. self.predicted_label = self._lime.predicted_label
  271. self.predicted_probability = self._lime.predicted_probability
  272. self.image = image_show[0]
  273. self.labels = self._lime.labels
  274. logging.info('performing NormLIME operations ...')
  275. cluster_labels = self.predict_cluster_labels(
  276. compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
  277. self._lime.lime_interpreter.segments)
  278. g_weights = self.predict_using_normlime_weights(self.labels,
  279. cluster_labels)
  280. return g_weights
  281. def interpret(self,
  282. data_,
  283. visualization=True,
  284. save_to_disk=True,
  285. save_outdir=None):
  286. if self.normlime_weights is None:
  287. raise ValueError(
  288. "Not find the correct precomputed NormLIME result. \n"
  289. "\t Try to call compute_normlime_weights() first or load the correct path."
  290. )
  291. g_weights = self.preparation_normlime(data_)
  292. lime_weights = self._lime.lime_interpreter.local_weights
  293. if visualization or save_to_disk:
  294. import matplotlib.pyplot as plt
  295. from skimage.segmentation import mark_boundaries
  296. l = self.labels[0]
  297. ln = l
  298. if self.label_names is not None:
  299. ln = self.label_names[l]
  300. psize = 5
  301. nrows = 4
  302. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  303. nums_to_show = []
  304. ncols = len(weights_choices)
  305. plt.close()
  306. f, axes = plt.subplots(
  307. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  308. for ax in axes.ravel():
  309. ax.axis("off")
  310. axes = axes.ravel()
  311. axes[0].imshow(self.image)
  312. prob_str = "{%.3f}" % (self.predicted_probability)
  313. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  314. axes[1].imshow(
  315. mark_boundaries(self.image,
  316. self._lime.lime_interpreter.segments))
  317. axes[1].set_title("superpixel segmentation")
  318. # LIME visualization
  319. for i, w in enumerate(weights_choices):
  320. num_to_show = auto_choose_num_features_to_show(
  321. self._lime.lime_interpreter, l, w)
  322. nums_to_show.append(num_to_show)
  323. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  324. l,
  325. positive_only=False,
  326. hide_rest=False,
  327. num_features=num_to_show)
  328. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  329. axes[ncols + i].set_title("LIME: first {} superpixels".format(
  330. num_to_show))
  331. # NormLIME visualization
  332. self._lime.lime_interpreter.local_weights = g_weights
  333. for i, num_to_show in enumerate(nums_to_show):
  334. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  335. l,
  336. positive_only=False,
  337. hide_rest=False,
  338. num_features=num_to_show)
  339. axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
  340. axes[ncols * 2 + i].set_title(
  341. "NormLIME: first {} superpixels".format(num_to_show))
  342. # NormLIME*LIME visualization
  343. combined_weights = combine_normlime_and_lime(lime_weights,
  344. g_weights)
  345. self._lime.lime_interpreter.local_weights = combined_weights
  346. for i, num_to_show in enumerate(nums_to_show):
  347. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  348. l,
  349. positive_only=False,
  350. hide_rest=False,
  351. num_features=num_to_show)
  352. axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
  353. axes[ncols * 3 + i].set_title(
  354. "Combined: first {} superpixels".format(num_to_show))
  355. self._lime.lime_interpreter.local_weights = lime_weights
  356. if save_to_disk and save_outdir is not None:
  357. os.makedirs(save_outdir, exist_ok=True)
  358. save_fig(data_, save_outdir, 'normlime', self.num_samples)
  359. if visualization:
  360. plt.show()
  361. class NormLIME(object):
  362. def __init__(self,
  363. predict_fn,
  364. label_names,
  365. num_samples=3000,
  366. batch_size=50,
  367. kmeans_model_for_normlime=None,
  368. normlime_weights=None):
  369. root_path = gen_user_home()
  370. root_path = osp.join(root_path, '.paddlex')
  371. h_pre_models = osp.join(root_path, "pre_models")
  372. if not osp.exists(h_pre_models):
  373. if not osp.exists(root_path):
  374. os.makedirs(root_path)
  375. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  376. pdx.utils.download_and_decompress(url, path=root_path)
  377. h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
  378. if kmeans_model_for_normlime is None:
  379. try:
  380. self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
  381. except:
  382. raise ValueError(
  383. "NormLIME needs the KMeans model, where we provided a default one in "
  384. "pre_models/kmeans_model.pkl.")
  385. else:
  386. logging.debug("Warning: It is *strongly* suggested to use the \
  387. default KMeans model in pre_models/kmeans_model.pkl. \
  388. Use another one will change the final result.")
  389. self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
  390. self.num_samples = num_samples
  391. self.batch_size = batch_size
  392. try:
  393. self.normlime_weights = np.load(
  394. normlime_weights, allow_pickle=True).item()
  395. except:
  396. self.normlime_weights = None
  397. logging.debug(
  398. "Warning: not find the correct precomputed Normlime result.")
  399. self.predict_fn = predict_fn
  400. self.labels = None
  401. self.image = None
  402. self.label_names = label_names
  403. def predict_cluster_labels(self, feature_map, segments):
  404. X = get_feature_for_kmeans(feature_map, segments)
  405. try:
  406. cluster_labels = self.kmeans_model.predict(X)
  407. except AttributeError:
  408. from sklearn.metrics import pairwise_distances_argmin_min
  409. cluster_labels, _ = pairwise_distances_argmin_min(
  410. X, self.kmeans_model.cluster_centers_)
  411. return cluster_labels
  412. def predict_using_normlime_weights(self, pred_labels,
  413. predicted_cluster_labels):
  414. # global weights
  415. g_weights = {y: [] for y in pred_labels}
  416. for y in pred_labels:
  417. cluster_weights_y = self.normlime_weights.get(y, {})
  418. g_weights[y] = [(i, cluster_weights_y.get(k, 0.0))
  419. for i, k in enumerate(predicted_cluster_labels)]
  420. g_weights[y] = sorted(
  421. g_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
  422. return g_weights
  423. def preparation_normlime(self, data_):
  424. self._lime = LIME(self.predict_fn, self.label_names, self.num_samples,
  425. self.batch_size)
  426. self._lime.preparation_lime(data_)
  427. image_show = read_image(data_)
  428. self.predicted_label = self._lime.predicted_label
  429. self.predicted_probability = self._lime.predicted_probability
  430. self.image = image_show[0]
  431. self.labels = self._lime.labels
  432. logging.info('performing NormLIME operations ...')
  433. cluster_labels = self.predict_cluster_labels(
  434. compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
  435. self._lime.lime_interpreter.segments)
  436. g_weights = self.predict_using_normlime_weights(self.labels,
  437. cluster_labels)
  438. return g_weights
  439. def interpret(self,
  440. data_,
  441. visualization=True,
  442. save_to_disk=True,
  443. save_outdir=None):
  444. if self.normlime_weights is None:
  445. raise ValueError(
  446. "Not find the correct precomputed NormLIME result. \n"
  447. "\t Try to call compute_normlime_weights() first or load the correct path."
  448. )
  449. g_weights = self.preparation_normlime(data_)
  450. lime_weights = self._lime.lime_interpreter.local_weights
  451. if visualization or save_to_disk:
  452. import matplotlib.pyplot as plt
  453. from skimage.segmentation import mark_boundaries
  454. l = self.labels[0]
  455. ln = l
  456. if self.label_names is not None:
  457. ln = self.label_names[l]
  458. psize = 5
  459. nrows = 4
  460. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  461. nums_to_show = []
  462. ncols = len(weights_choices)
  463. plt.close()
  464. f, axes = plt.subplots(
  465. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  466. for ax in axes.ravel():
  467. ax.axis("off")
  468. axes = axes.ravel()
  469. axes[0].imshow(self.image)
  470. prob_str = "{%.3f}" % (self.predicted_probability)
  471. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  472. axes[1].imshow(
  473. mark_boundaries(self.image,
  474. self._lime.lime_interpreter.segments))
  475. axes[1].set_title("superpixel segmentation")
  476. # LIME visualization
  477. for i, w in enumerate(weights_choices):
  478. num_to_show = auto_choose_num_features_to_show(
  479. self._lime.lime_interpreter, l, w)
  480. nums_to_show.append(num_to_show)
  481. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  482. l,
  483. positive_only=True,
  484. hide_rest=False,
  485. num_features=num_to_show)
  486. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  487. axes[ncols + i].set_title("LIME: first {} superpixels".format(
  488. num_to_show))
  489. # NormLIME visualization
  490. self._lime.lime_interpreter.local_weights = g_weights
  491. for i, num_to_show in enumerate(nums_to_show):
  492. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  493. l,
  494. positive_only=True,
  495. hide_rest=False,
  496. num_features=num_to_show)
  497. axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
  498. axes[ncols * 2 + i].set_title(
  499. "NormLIME: first {} superpixels".format(num_to_show))
  500. # NormLIME*LIME visualization
  501. combined_weights = combine_normlime_and_lime(lime_weights,
  502. g_weights)
  503. self._lime.lime_interpreter.local_weights = combined_weights
  504. for i, num_to_show in enumerate(nums_to_show):
  505. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  506. l,
  507. positive_only=True,
  508. hide_rest=False,
  509. num_features=num_to_show)
  510. axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
  511. axes[ncols * 3 + i].set_title(
  512. "Combined: first {} superpixels".format(num_to_show))
  513. self._lime.lime_interpreter.local_weights = lime_weights
  514. if save_to_disk and save_outdir is not None:
  515. os.makedirs(save_outdir, exist_ok=True)
  516. save_fig(data_, save_outdir, 'normlime', self.num_samples)
  517. if visualization:
  518. plt.show()
  519. def auto_choose_num_features_to_show(lime_interpreter, label,
  520. percentage_to_show):
  521. segments = lime_interpreter.segments
  522. lime_weights = lime_interpreter.local_weights[label]
  523. num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[
  524. 1] // len(np.unique(segments)) // 8
  525. # l1 norm with filtered weights.
  526. used_weights = [(tuple_w[0], tuple_w[1])
  527. for i, tuple_w in enumerate(lime_weights)
  528. if tuple_w[1] > 0]
  529. norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
  530. normalized_weights = [(tuple_w[0], tuple_w[1] / norm)
  531. for i, tuple_w in enumerate(lime_weights)]
  532. a = 0.0
  533. n = 0
  534. for i, tuple_w in enumerate(normalized_weights):
  535. if tuple_w[1] < 0:
  536. continue
  537. if len(np.where(segments == tuple_w[0])[
  538. 0]) < num_pixels_threshold_in_a_sp:
  539. continue
  540. a += tuple_w[1]
  541. if a > percentage_to_show:
  542. n = i + 1
  543. break
  544. if percentage_to_show <= 0.0:
  545. return 5
  546. if n == 0:
  547. return auto_choose_num_features_to_show(lime_interpreter, label,
  548. percentage_to_show - 0.1)
  549. return n
  550. def get_cam(image_show,
  551. feature_maps,
  552. fc_weights,
  553. label_index,
  554. cam_min=None,
  555. cam_max=None):
  556. _, nc, h, w = feature_maps.shape
  557. cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
  558. cam = cam.sum((0, 1))
  559. if cam_min is None:
  560. cam_min = np.min(cam)
  561. if cam_max is None:
  562. cam_max = np.max(cam)
  563. cam = cam - cam_min
  564. cam = cam / cam_max
  565. cam = np.uint8(255 * cam)
  566. cam_img = cv2.resize(
  567. cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
  568. heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
  569. heatmap = np.float32(heatmap)
  570. cam = heatmap + np.float32(image_show)
  571. cam = cam / np.max(cam)
  572. return cam
  573. def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
  574. import matplotlib.pyplot as plt
  575. if isinstance(data_, str):
  576. if algorithm_name == 'cam':
  577. f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
  578. else:
  579. f_out = "{}_{}_s{}.png".format(algorithm_name,
  580. data_.split('/')[-1], num_samples)
  581. plt.savefig(os.path.join(save_outdir, f_out))
  582. else:
  583. n = 0
  584. if algorithm_name == 'cam':
  585. f_out = 'cam-{}.png'.format(n)
  586. else:
  587. f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
  588. while os.path.exists(os.path.join(save_outdir, f_out)):
  589. n += 1
  590. if algorithm_name == 'cam':
  591. f_out = 'cam-{}.png'.format(n)
  592. else:
  593. f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
  594. continue
  595. plt.savefig(os.path.join(save_outdir, f_out))
  596. logging.info('The image of intrepretation result save in {}'.format(
  597. os.path.join(save_outdir, f_out)))