预测解释

! [ -e /content ] && pip install -Uqq fastai  # 在Colab上升级fastai
from __future__ import annotations
from fastai.data.all import *
from fastai.optimizer import *
from fastai.learner import *
from fastai.tabular.core import *
import sklearn.metrics as skm
from fastai.test_utils import *
from nbdev.showdoc import *

创建对象的类以更好地解释模型的预测。

# Shared fixture for the test cells below: a small black-and-white MNIST sample and a resnet18 learner.
from fastai.vision.all import *
# Images are loaded as `PILImageBW` and labelled by the name of their parent folder.
# NOTE(review): `RandomSubsetSplitter(.1,.1, seed=42)` presumably keeps 10% for train and 10% for valid - confirm.
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), 
                  get_items=get_image_files, 
                  splitter=RandomSubsetSplitter(.1,.1, seed=42),
                  get_y=parent_label)
# A batch size of 8 keeps the tests light
test_dls = mnist.dataloaders(untar_data(URLs.MNIST_SAMPLE), bs=8)
# The learner is never fitted in this section; the tests only compare outputs for consistency
test_learner = vision_learner(test_dls, resnet18)
@typedispatch
def plot_top_losses(x, y, *args, **kwargs):
    "Fallback `plot_top_losses`: always raises, naming the types of `x` and `y` it was called with"
    # Build the message first so the `raise` line stays short
    unsupported = f"plot_top_losses is not implemented for {type(x)},{type(y)}"
    raise Exception(unsupported)
_all_ = ["plot_top_losses"]
class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self,
        learn:Learner, # `Learner` used to compute predictions
        dl:DataLoader, # `DataLoader` to run inference over
        losses:TensorBase, # Losses calculated from `dl`
        act=None # Activation function for prediction
    ): 
        # The arguments are read back below as `self.learn`, `self.dl`, `self.losses` and `self.act`
        store_attr()

    def __getitem__(self, idxs):
        "Return inputs, preds, targs, decoded outputs, and losses at `idxs`"
        # Normalise `idxs` (tensor, single index or slice) to a plain list
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        # Use `.iloc` when the items expose it (e.g. a DataFrame), otherwise index through an `L`
        items = getattr(self.dl.items, 'iloc', L(self.dl.items))[idxs]
        # Predictions are recomputed on a temporary labelled `DataLoader` holding only the selected items
        tmp_dl = self.learn.dls.test_dl(items, with_labels=True, process=not isinstance(self.dl, TabDataLoader))
        inps,preds,targs,decoded = self.learn.get_preds(dl=tmp_dl, with_input=True, with_loss=False, 
                                                        with_decoded=True, act=self.act, reorder=False)
        # Losses come from the stored tensor rather than being recomputed
        return inps, preds, targs, decoded, self.losses[idxs]

    @classmethod
    def from_learner(cls,
        learn, # Model used to create interpretation
        ds_idx:int=1, # Index of `learn.dls` when `dl` is None
        dl:DataLoader=None, # `Dataloader` used to make predictions
        act=None # Override default or set prediction activation function
    ):
        "Construct interpretation object from a learner"
        # Unshuffled and without dropping the last batch, so stored losses line up with `dl.items`
        if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
        # Only the losses are kept; predictions and targets are recomputed on demand in `__getitem__`
        _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,
                                     with_preds=False, with_targs=False, act=act)
        return cls(learn, dl, losses, act)

    def top_losses(self,
        k:int|None=None, # Return `k` losses, defaults to all
        largest:bool=True, # Sort losses by largest or smallest
        items:bool=False # Whether to return input items
    ):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses."
        n_losses = len(self.losses)
        # Clamp `k` to the number of stored losses: `topk` raises when asked for more elements
        # than exist, so a request such as `top_losses(9)` on 5 items now returns all 5
        k = n_losses if k is None else min(k, n_losses)
        losses, idx = self.losses.topk(k, largest=largest)
        if items: return losses, idx, getattr(self.dl.items, 'iloc', L(self.dl.items))[idx]
        return losses, idx

    def plot_top_losses(self,
        k:int|MutableSequence, # Number of losses to plot
        largest:bool=True, # Sort losses by largest or smallest
        **kwargs
    ):
        "Show `k` largest(/smallest) preds and losses. Implementation based on type dispatch"
        if is_listy(k) or isinstance(k, range):
            # A list/range selects positions within the fully sorted losses
            losses, idx = (o[k] for o in self.top_losses(None, largest))
        else: 
            losses, idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)
        # Prepare the batch twice for display: once paired with the targets, once with the decoded predictions
        x, y, its = self.dl._pre_show_batch(inps+targs, max_n=len(idx))
        x1, y1, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))
        if its is not None:
            # Only the prediction part of `outs` is passed on (entries after the inputs)
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)
        #TODO: figure out if this is needed
        #its None means that a batch knows how to show itself as a whole, so we pass x, x1
        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)

    def show_results(self,
        idxs:list, # Indices of predictions and targets
        **kwargs
    ):
        "Show predictions and targets of `idxs`"
        # Normalised here as well because `len(idxs)` is needed for `max_n`
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        inps, _, targs, decoded, _ = self[idxs]
        b = tuplify(inps)+tuplify(targs)
        self.dl.show_results(b, tuplify(decoded), max_n=len(idxs), **kwargs)
show_doc(Interpretation, title_level=3)

class Interpretation[source]

Interpretation(learn:Learner, dl:DataLoader, losses:TensorBase, act=None)

Interpretation base class, can be inherited for task specific Interpretation classes

Type Default Details
learn Learner No Content
dl DataLoader DataLoader to run inference over
losses TensorBase Losses calculated from dl
act NoneType None Activation function for prediction

`Interpretation` 是一个用于探索训练模型预测的辅助基类。它可以被继承以用于特定任务的解释类,例如 `ClassificationInterpretation`。`Interpretation` 具有内存高效性,并应能够处理任何大小的数据集,前提是硬件能够训练相同的模型。

Note

Interpretation 通过实时生成每个项目的输入、预测、目标、解码输出和损失,并尽可能使用批处理,因此具有内存高效性。

show_doc(Interpretation.from_learner, title_level=3)

Interpretation.from_learner[source]

Interpretation.from_learner(learn, ds_idx:int=1, dl:DataLoader=None, act=None)

Construct interpretation object from a learner

Type Default Details
learn Model used to create interpretation
ds_idx int 1 Index of learn.dls when dl is None
dl DataLoader None Dataloader used to make predictions
act NoneType None Override default or set prediction activation function
show_doc(Interpretation.top_losses, title_level=3)

Interpretation.top_losses[source]

Interpretation.top_losses(k:(<class 'int'>, None)=None, largest:bool=True, items:bool=False)

k largest(/smallest) losses and indexes, defaulting to all losses.

Type Default Details
k (int, None) None Return k losses, defaults to all
largest bool True Sort losses by largest or smallest
items bool False Whether to return input items

默认情况下(`k=None`),`top_losses` 将返回整个数据集的损失。`top_losses` 还可以选择性地返回每个损失对应的输入项,通常是文件路径或 Pandas `DataFrame`。

show_doc(Interpretation.plot_top_losses, title_level=3)

Interpretation.plot_top_losses[source]

Interpretation.plot_top_losses(k:(<class 'int'>, <class 'list'>), largest:bool=True, **kwargs)

Show k largest(/smallest) preds and losses. Implementation based on type dispatch

Type Default Details
k (int, list) Number of losses to plot
largest bool True Sort losses by largest or smallest
kwargs No Content

要绘制前9个最大的损失:

interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9)

然后绘制第8到第16个最大的损失(即排序后从0开始计数的位置7到15):

interp.plot_top_losses(range(7,16))
show_doc(Interpretation.show_results, title_level=3)

Interpretation.show_results[source]

Interpretation.show_results(idxs:list, **kwargs)

Show predictions and targets of idxs

Type Default Details
idxs list Indices of predictions and targets
kwargs No Content

类似于 `Learner.show_results`,但可以传入一个或多个所需的索引,以指定要显示结果的项目。

# Check that indexing the whole `Interpretation` matches a manual pass over the validation set
interp = Interpretation.from_learner(test_learner)
x, y, out = [], [], []
# Collect inputs, targets and raw model outputs item by item (`+=` extends each list along the batch axis)
for batch in test_learner.dls.valid:
    x += batch[0]
    y += batch[1]
    out += test_learner.model(batch[0])
x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)
# `interp[:]` recomputes inputs, predictions and targets for every item
inps, preds, targs, decoded, losses = interp[:]
test_eq(inps, to_cpu(x))
test_eq(targs, to_cpu(y))
# Per-item losses computed by hand should be close to the ones stored by `from_learner`
loss = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)
test_close(losses, to_cpu(loss))
# verify stored losses equal calculated losses for idx
top_losses, idx = interp.top_losses(9)

# Rebuild the same temporary `DataLoader` that `Interpretation.__getitem__` creates, this time asking for losses
dl = test_learner.dls[1].new(shuffle=False, drop_last=False)
items = getattr(dl.items, 'iloc', L(dl.items))[idx]
tmp_dl = test_learner.dls.test_dl(items, with_labels=True, process=not isinstance(dl, TabDataLoader))
_, _, _, _, losses = test_learner.get_preds(dl=tmp_dl, with_input=True, with_loss=True, 
                                            with_decoded=True, act=None, reorder=False)

# Compared with a tolerance of 1e-2
test_close(top_losses, losses, 1e-2)
#dummy test to ensure we can run on the training set
interp = Interpretation.from_learner(test_learner, ds_idx=0)
x, y, out = [], [], []
# Iterate the training loader unshuffled and without dropping the last batch, as `from_learner` does
for batch in test_learner.dls.train.new(drop_last=False, shuffle=False):
    x += batch[0]
    y += batch[1]
    out += test_learner.model(batch[0])
x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)
inps, preds, targs, decoded, losses = interp[:]
test_eq(inps, to_cpu(x))
test_eq(targs, to_cpu(y))
# Per-item losses computed by hand should be close to the stored ones
loss = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)
test_close(losses, to_cpu(loss))
class ClassificationInterpretation(Interpretation):
    "Interpretation methods for classification models."

    def __init__(self, 
        learn:Learner, # `Learner` used to compute predictions
        dl:DataLoader, # `DataLoader` to run inference over
        losses:TensorBase, # Losses calculated from `dl`
        act=None # Activation function for prediction
    ):
        super().__init__(learn, dl, losses, act)
        self.vocab = self.dl.vocab
        # When `dl.vocab` is list-like, keep only its last entry
        # NOTE(review): presumably the last entry is the target vocab - confirm
        if is_listy(self.vocab): self.vocab = self.vocab[-1]

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`."
        x = torch.arange(0, len(self.vocab))
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, 
                                               with_targs=True, act=self.act)
        d,t = flatten_check(decoded, targs)
        # Broadcast to (class, class, item): entry [i, j, n] is true when item n has target i and
        # decoded prediction j; summing over items gives rows = actual, columns = predicted
        cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)
        return to_np(cm)

    def plot_confusion_matrix(self, 
        normalize:bool=False, # Whether to normalize occurrences
        title:str='Confusion matrix', # Title of plot
        cmap:str="Blues", # Colormap from matplotlib
        norm_dec:int=2, # Decimal places for normalized occurrences
        plot_txt:bool=True, # Display occurrence in matrix
        **kwargs
    ):
        "Plot the confusion matrix, with `title` and using `cmap`."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix()
        if normalize:
            # Divide each row by its total. Rows with no occurrences are left at 0 rather than
            # dividing by zero, which would give NaN cells and a NaN text-colour threshold
            row_totals = cm.sum(axis=1)[:, np.newaxis]
            cm = np.divide(cm.astype('float'), row_totals,
                           out=np.zeros(cm.shape, dtype='float'), where=row_totals!=0)
        # Extra `kwargs` are forwarded to `plt.figure`
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = np.arange(len(self.vocab))
        plt.xticks(tick_marks, self.vocab, rotation=90)
        plt.yticks(tick_marks, self.vocab, rotation=0)

        if plot_txt:
            # Cells above half the maximum get white text, the others black
            thresh = cm.max() / 2.
            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'
                plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center", color="white"
                         if cm[i, j] > thresh else "black")

        ax = fig.gca()
        # Explicit y-limits put the first class at the top
        ax.set_ylim(len(self.vocab)-.5,-.5)

        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')
        plt.grid(False)

    def most_confused(self, min_val=1):
        "Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences)"
        cm = self.confusion_matrix()
        # Zero the diagonal so correct predictions are never reported as confusions
        np.fill_diagonal(cm, 0)
        res = [(self.vocab[i],self.vocab[j],cm[i,j]) for i,j in zip(*np.where(cm>=min_val))]
        return sorted(res, key=itemgetter(2), reverse=True)

    def print_classification_report(self):
        "Print scikit-learn classification report"
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, 
                                               with_targs=True, act=self.act)
        d,t = flatten_check(decoded, targs)
        # Class names are stringified; `labels` are the encoded values held in `vocab.o2i`
        names = [str(v) for v in self.vocab]
        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=names))
show_doc(ClassificationInterpretation.confusion_matrix, title_level=3)

ClassificationInterpretation.confusion_matrix[source]

ClassificationInterpretation.confusion_matrix()

Confusion matrix as an np.ndarray.

show_doc(ClassificationInterpretation.plot_confusion_matrix, title_level=3)

ClassificationInterpretation.plot_confusion_matrix[source]

ClassificationInterpretation.plot_confusion_matrix(normalize=False, title='Confusion matrix', cmap='Blues', norm_dec=2, plot_txt=True, **kwargs)

Plot the confusion matrix, with title and using cmap.

show_doc(ClassificationInterpretation.most_confused, title_level=3)

ClassificationInterpretation.most_confused[source]

ClassificationInterpretation.most_confused(min_val=1)

Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences

# simple test to make sure ClassificationInterpretation works
interp = ClassificationInterpretation.from_learner(test_learner)
# Only checks that the confusion matrix can be computed without error; its values are not asserted
cm = interp.confusion_matrix()
class SegmentationInterpretation(Interpretation):
    "Interpretation methods for segmentation models."
    # Adds nothing of its own: everything is inherited from `Interpretation`


导出 -

from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.