可解释性/可解释性分析
PyTorch Tabular 中的可解释性功能允许用户解释和理解表格深度学习模型所做出的预测。这些功能提供了对模型决策过程的洞察,并帮助识别最具影响力的特征。一些可解释性功能是模型内置的,而许多其他功能则基于 Captum 库。
原生特征重要性¶
GBDT 模型中每个人都喜欢的一个功能是特征重要性。它帮助我们理解哪些特征对模型最为重要。PyTorch Tabular 为一些模型(如 GANDALF、GATE 和 FTTransformers)提供了类似的功能,这些模型原生支持特征重要性的提取。
局部特征归因/解释¶
局部特征归因/解释帮助我们理解每个特征对特定样本预测的贡献。PyTorch Tabular 为除 TabTransformer、Tabnet 和 Mixed Density Networks 之外的所有模型提供了此功能。它基于 Captum 库。该库提供了许多用于计算特征归因的算法。PyTorch Tabular 提供了一个围绕该库的包装器,使其易于使用。支持以下算法:
- GradientShap: https://captum.ai/api/gradient_shap.html
- IntegratedGradients: https://captum.ai/api/integrated_gradients.html
- DeepLift: https://captum.ai/api/deep_lift.html
- DeepLiftShap: https://captum.ai/api/deep_lift_shap.html
- InputXGradient: https://captum.ai/api/input_x_gradient.html
- FeaturePermutation: https://captum.ai/api/feature_permutation.html
- FeatureAblation: https://captum.ai/api/feature_ablation.html
- KernelShap: https://captum.ai/api/kernel_shap.html
PyTorch Tabular
还支持解释单个实例以及一批实例。但是,较大的数据集将需要更长的时间来解释。例外的是 FeaturePermutation
和 FeatureAblation
方法,它们仅对大批实例有意义。
大多数这些可解释性方法需要一个基线。这用于将输入的归因与基线的归因进行比较。基线可以是标量值、与输入形状相同的张量,或者是特殊字符串,如 "b|10000",表示从训练数据中抽取 10000 个样本。如果未提供基线,则使用默认基线(零)。
# tabular_model 是支持模型的训练模型
# 使用 GradientShap 方法和基线为 10000 个训练数据样本解释单个实例
tabular_model.explain(test.head(1), method="GradientShap", baselines="b|10000")
# 使用 IntegratedGradients 方法和基线为 0 解释一批实例
tabular_model.explain(test.head(10), method="IntegratedGradients", baselines=0)
查看 Captum 文档 以获取有关算法的更多详细信息,并查看 可解释性教程 以获取示例用法。
API 参考¶
pytorch_tabular.TabularModel.explain(data, method='GradientShap', method_args={}, baselines=None, **kwargs)
¶
返回模型的特征归因/解释,以pandas DataFrame的形式呈现.返回的数据框形状为(样本数量, 特征数量)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data
|
DataFrame
|
需要解释的数据框 |
required |
method
|
str
|
用于解释模型的方法. 应为以下默认值之一:"GradientShap". 更多详情,请参考 https://captum.ai/api/attribution.html |
'GradientShap'
|
method_args
|
Optional[Dict]
|
传递给Captum方法初始化的参数. |
{}
|
baselines
|
Union[float, tensor, str]
|
用于解释的基线.
如果提供标量,将使用该值作为所有特征的基线.
如果提供张量,将使用该张量作为所有特征的基线.
如果提供类似 |
None
|
**kwargs
|
传递给Captum方法 |
{}
|
Returns:
Name | Type | Description |
---|---|---|
DataFrame |
DataFrame
|
包含特征重要性的数据框 |
Source code in src/pytorch_tabular/tabular_model.py
1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 |
|