! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on Colab
Captum
from __future__ import annotations
import tempfile
from fastai.basics import *
from nbdev.showdoc import *
Captum is a model interpretation library from PyTorch; details are available here.
To use it, we need to install the package with:
conda install captum -c pytorch
or
pip install captum
This is the Callback for using Captum.
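Once installed, a quick sanity check that the import resolves (the printed version will vary with your environment):

import captum
print(captum.__version__)  # e.g. '0.6.0'; exact output depends on your install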
from ipykernel import jsonutil
# Dirty workaround: json_clean doesn't support the CategoryMap type
_json_clean = jsonutil.json_clean
def json_clean(o):
    o = list(o.items) if isinstance(o, CategoryMap) else o
    return _json_clean(o)

jsonutil.json_clean = json_clean
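As a minimal sanity check of the patch (a hypothetical CategoryMap built from a small list, not part of the original notebook):

cm = CategoryMap(['cat','dog','cat'])  # vocab collapses to the unique, sorted categories
print(json_clean(cm))                  # now cleans to a plain list: ['cat', 'dog']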
from captum.attr import IntegratedGradients,NoiseTunnel,GradientShap,Occlusion
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature
Throughout this notebook we'll use the following data:
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)
def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, fnames, valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(128))
from random import randint
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
Captum Interpretation
The Distill article here provides a good overview of how to choose a baseline image. We can try them one by one.
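As a rough sketch of the three baseline types implemented in get_baseline_img below (img is a stand-in for an input tensor at the 128-pixel size used in this notebook):

img = torch.rand(3,128,128)                       # stand-in for an input image tensor
zeros_baseline   = img*0                          # 'zeros': all-black reference
uniform_baseline = torch.rand(img.shape)          # 'uniform': pure-noise reference
gauss_baseline   = (torch.rand(img.shape)+img)/2  # 'gauss': blend of noise and the image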
class CaptumInterpretation():
    "Captum Interpretation for Resnet"
    def __init__(self, learn, cmap_name='custom blue', colors=None, N=256,
                 methods=('original_image','heat_map'), signs=("all","positive"), outlier_perc=1):
        if colors is None: colors = [(0, '#ffffff'),(0.25, '#000000'),(1, '#000000')]
        store_attr()
        self.dls,self.model = learn.dls,learn.model
        self.supported_metrics = ['IG','NT','Occl']
    def get_baseline_img(self, img_tensor, baseline_type):
        "Build the reference image the attributions are computed against"
        baseline_img = None
        if baseline_type=='zeros':   baseline_img = img_tensor*0
        if baseline_type=='uniform': baseline_img = torch.rand(img_tensor.shape)
        if baseline_type=='gauss':
            baseline_img = (torch.rand(img_tensor.shape).to(self.dls.device)+img_tensor)/2
        return baseline_img.to(self.dls.device)
    def visualize(self, inp, metric='IG', n_steps=1000, baseline_type='zeros', nt_type='smoothgrad',
                  strides=(3,4,4), sliding_window_shapes=(3,15,15)):
        if metric not in self.supported_metrics:
            raise Exception(f"Metric {metric} is not supported. Currently only {self.supported_metrics} are supported.")
        tls = L([TfmdLists(inp, t) for t in L(ifnone(self.dls.tfms,[None]))])
        inp_data = list(zip(*(tls[0],tls[1])))[0]
        enc_data,dec_data = self._get_enc_dec_data(inp_data)
        attributions = self._get_attributions(enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes)
        self._viz(attributions,dec_data,metric)
    def _viz(self, attributions, dec_data, metric):
        default_cmap = LinearSegmentedColormap.from_list(self.cmap_name, self.colors, N=self.N)
        _ = viz.visualize_image_attr_multiple(np.transpose(attributions.squeeze().cpu().detach().numpy(), (1,2,0)),
                                              np.transpose(dec_data[0].numpy(), (1,2,0)),
                                              methods=self.methods,
                                              cmap=default_cmap,
                                              show_colorbar=True,
                                              signs=self.signs,
                                              outlier_perc=self.outlier_perc,
                                              titles=[f'Original Image - ({dec_data[1]})', metric])
    def _get_enc_dec_data(self, inp_data):
        dec_data = self.dls.after_item(inp_data)
        enc_data = self.dls.after_batch(to_device(self.dls.before_batch(dec_data), self.dls.device))
        return (enc_data, dec_data)
    def _get_attributions(self, enc_data, metric, n_steps, nt_type, baseline_type, strides, sliding_window_shapes):
        # Build the baseline image for the requested baseline type
        baseline = self.get_baseline_img(enc_data[0], baseline_type)
        if metric == 'IG':
            self._int_grads = self._int_grads if hasattr(self,'_int_grads') else IntegratedGradients(self.model)
            return self._int_grads.attribute(enc_data[0], baseline, target=enc_data[1], n_steps=n_steps)
        elif metric == 'NT':
            self._int_grads = self._int_grads if hasattr(self,'_int_grads') else IntegratedGradients(self.model)
            self._noise_tunnel = self._noise_tunnel if hasattr(self,'_noise_tunnel') else NoiseTunnel(self._int_grads)
            return self._noise_tunnel.attribute(enc_data[0].to(self.dls.device), nt_samples=1, nt_type=nt_type, target=enc_data[1])
        elif metric == 'Occl':
            self._occlusion = self._occlusion if hasattr(self,'_occlusion') else Occlusion(self.model)
            return self._occlusion.attribute(enc_data[0].to(self.dls.device),
                                             strides=strides,
                                             target=enc_data[1],
                                             sliding_window_shapes=sliding_window_shapes,
                                             baselines=baseline)
show_doc(CaptumInterpretation)
CaptumInterpretation
CaptumInterpretation (learn, cmap_name='custom blue', colors=None, N=256,
                      methods=('original_image','heat_map'),
                      signs=('all','positive'), outlier_perc=1)
Captum Interpretation for Resnet
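The cmap_name, colors and N arguments are forwarded to matplotlib's LinearSegmentedColormap when the attribution map is drawn, so a custom color scheme is just a different list of (position, hex) stops. For example (a hypothetical white-to-red scheme, not from the original notebook):

captum_red = CaptumInterpretation(learn, cmap_name='custom red',
                                  colors=[(0,'#ffffff'),(0.5,'#ff8080'),(1,'#ff0000')])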
Interpretation
captum = CaptumInterpretation(learn)
idx = randint(0, len(fnames)-1)
captum.visualize(fnames[idx])
captum.visualize(fnames[idx], baseline_type='uniform')
captum.visualize(fnames[idx], baseline_type='gauss')
captum.visualize(fnames[idx], metric='NT', baseline_type='uniform')
captum.visualize(fnames[idx], metric='Occl', baseline_type='gauss')
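To compare attribution methods and baselines side by side on the same image, one simple (if slow) option is to sweep every supported combination:

# Sweep each metric/baseline pairing; every call renders its own figure.
for metric in ('IG','NT','Occl'):
    for baseline_type in ('zeros','uniform','gauss'):
        captum.visualize(fnames[idx], metric=metric, baseline_type=baseline_type)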
Captum 见解回调
@patch
def _formatted_data_iter(x: CaptumInterpretation, dl, normalize_func):
    "Yield batches in the format Captum Insights expects"
    dl_iter = iter(dl)
    while True:
        images,labels = next(dl_iter)
        images = normalize_func.decode(images).to(dl.device)  # denormalize for display
        yield Batch(inputs=images, labels=labels)
@patch
def insights(x: CaptumInterpretation, inp_data, debug=True):
    _baseline_func = lambda o: o*0
    _get_vocab = lambda vocab: list(map(str,vocab)) if isinstance(vocab[0],bool) else vocab
    dl = x.dls.test_dl(L(inp_data), with_labels=True, bs=4)
    normalize_func = next((func for func in dl.after_batch if type(func)==Normalize), noop)

    # captum v0.3 expects input tensors without a batch dimension
    if nested_attr(normalize_func, 'mean.ndim', 4)==4: normalize_func.mean.squeeze_(0)
    if nested_attr(normalize_func, 'std.ndim', 4)==4:  normalize_func.std.squeeze_(0)

    visualizer = AttributionVisualizer(
        models=[x.model],
        score_func=lambda o: torch.nn.functional.softmax(o, 1),
        classes=_get_vocab(dl.vocab),
        features=[ImageFeature("Image", baseline_transforms=[_baseline_func], input_transforms=[normalize_func])],
        dataset=x._formatted_data_iter(dl, normalize_func))
    visualizer.render(debug=debug)
captum = CaptumInterpretation(learn)
captum.insights(fnames)