The v0.8.0 release of Captum offers new influence functions for data attribution, improvements to feature attribution methods (including LLM prompt attribution), enhanced type annotations for modern Python type checking, and a variety of other small changes. Note that support for Python < 3.8 and PyTorch < 1.10 has been dropped, and Captum Insights will be deprecated in the next major release.
Data Attribution: New Influence Functions
This version offers two different implementations that both calculate the "infinitesimal" influence score as defined in the paper "Understanding Black-box Predictions via Influence Functions".
- NaiveInfluenceFunction: a computationally slow but exact implementation that is useful for obtaining "ground truth" (though note that influence scores are themselves only an approximation of the effect of removing a training example and retraining). Several papers use this approach, e.g. "Learning Augmentation Network via Influence Functions", "Quantifying and Mitigating the Impact of Label Errors on Model Disparity Metrics", and "Achieving Fairness at No Utility Cost via Data Reweighting with Influence". (PR #1214)
- ArnoldiInfluenceFunction: a computationally efficient implementation described in the paper "Scaling Up Influence Functions" by Schioppa et al. (PR #1187)
Example:
```python
from torch import nn
from torch.utils.data import DataLoader

from captum.influence._core.influence_function import NaiveInfluenceFunction

batch_size = 8
train_dl = DataLoader(your_dataset, batch_size=batch_size)  # your training dataloader
criterion = nn.MSELoss(reduction="none")  # per-example losses

influence = NaiveInfluenceFunction(
    net,              # your trained model (an nn.Module)
    train_dl,
    checkpoint_path,  # path to your model checkpoint
    loss_fn=criterion,
    batch_size=batch_size,
)

# compute pairwise influences of training examples on test examples
influence_train_test_influences = influence.influence(
    (test_samples, test_labels)  # your test data (Tensors)
)
```
What is the "infinitesimal" influence score?
This "infinitesimal" influence score approximately answers the question: if a given training example were infinitesimally down-weighted and the model re-trained to optimality, how much would the loss on a given test example change? Mathematically, the influence score of a training example $z$ on a test example $x$ is given by $\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)$, where $\nabla_\theta L(\cdot)$ denotes the gradient of the loss with respect to the model parameters at the trained checkpoint, and $H$ is the Hessian of the training loss.
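Written out in full, that is (this follows Koh & Liang's formulation, up to sign conventions for up- vs. down-weighting; the training-set notation $z_1, \dots, z_n$ is ours):

$$
\mathcal{I}(z, x) \;=\; \nabla_\theta L(x; \hat{\theta})^\top \, H^{-1} \, \nabla_\theta L(z; \hat{\theta}),
\qquad
H \;=\; \frac{1}{n}\sum_{i=1}^{n} \nabla^2_\theta L(z_i; \hat{\theta}),
$$

where $\hat{\theta}$ denotes the trained model parameters.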
What the two implementations have in common
Both implementations compute a low-rank approximation of the inverse Hessian, i.e. a tall and skinny matrix $R$ (with width $k$) such that $H^{-1} \approx R R'$.
This approximation is useful for several reasons:
- It avoids numerical issues associated with inverting small eigenvalues.
- Since the influence score $\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)$ is approximated by $(\nabla_\theta L(x)' R)(\nabla_\theta L(z)' R)'$, we can compute an "influence embedding" for a given example $x$, namely $\nabla_\theta L(x)' R$, such that the influence score of one example on another is approximately the dot product of their respective embeddings. Because $k$ is small, e.g. 50, these influence embeddings are low-dimensional.
- Even for large models, we can store $R$ in memory, provided $k$ is small. This means influence embeddings (and thus influence scores) can be computed efficiently by doing a backwards pass to compute $\nabla_\theta L(x)$ and then multiplying by $R'$. This is orders of magnitude faster than the earlier LiSSA approach of Koh et al., which, to compute the influence score involving a given example, needs to compute Hessian-vector products involving on the order of $10^4$ examples.
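To make the embedding trick concrete, here is a minimal sketch in plain PyTorch, assuming flattened gradient vectors and an already-computed factor $R$ (all names here are illustrative, not Captum API):

```python
import torch

p, k = 10_000, 50
R = torch.randn(p, k)    # stand-in for the low-rank factor, H^{-1} ≈ R @ R.T
grad_x = torch.randn(p)  # flattened loss gradient at test example x
grad_z = torch.randn(p)  # flattened loss gradient at training example z

emb_x = grad_x @ R       # influence embedding of x, shape [k]
emb_z = grad_z @ R       # influence embedding of z, shape [k]

score = emb_x @ emb_z    # approximates grad_x' H^{-1} grad_z
```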
The implementations differ in how they compute the top-k eigenvalues / eigenvectors.
How NaiveInfluenceFunction computes the top-k eigenvalues / eigenvectors
It is "naive" in that it computes the top-k eigenvalues / eigenvectors by explicitly forming the Hessian, converting it to a 2D tensor, computing its eigenvectors / eigenvalues, and then sorting. See documentation of the _set_projections_naive_influence_function
method for more details.
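In spirit, the computation looks like the following sketch (an illustration of the general recipe, not Captum's internals):

```python
import torch

def top_k_eigen(loss_fn, params_flat, k):
    # Explicitly form the (p x p) Hessian of a scalar loss ...
    H = torch.autograd.functional.hessian(loss_fn, params_flat)
    # ... then eigendecompose; eigh returns eigenvalues in ascending order.
    eigvals, eigvecs = torch.linalg.eigh(H)
    return eigvals[-k:], eigvecs[:, -k:]  # top-k eigenpairs

# Toy usage: quadratic loss 0.5 * w' A w, whose Hessian is exactly A.
A = torch.diag(torch.arange(1.0, 6.0))
vals, vecs = top_k_eigen(lambda w: 0.5 * w @ A @ w, torch.zeros(5), k=2)
print(vals)  # tensor([4., 5.])
```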
How ArnoldiInfluenceFunction computes the top-k eigenvalues / eigenvectors
The key novelty of the approach by Schioppa et al. is that it uses the Arnoldi iteration to find the top-k eigenvalues / eigenvectors of the Hessian without explicitly forming the Hessian. In more detail, the approach first runs the Arnoldi iteration, which only requires the ability to compute Hessian-vector products, to find a Krylov subspace of moderate dimension, e.g. 200. It then finds the top-k eigenvalues / eigenvectors of the restriction of the Hessian to that subspace, where k is small, e.g. 50. Finally, it expresses the eigenvectors in the original basis. This procedure is justified by a property of the Arnoldi iteration: the Krylov subspace it returns tends to contain the top eigenvectors.
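The following is a minimal Arnoldi iteration showing the mechanics (an illustrative sketch, not Captum's implementation; breakdown handling omitted for brevity):

```python
import torch

def arnoldi(hvp, b, n_iter):
    # Build an orthonormal basis Q of the Krylov subspace
    # span{b, Hb, H^2 b, ...} using only Hessian-vector products,
    # along with T = Q' H Q, the restriction of H to that subspace.
    qs = [b / b.norm()]
    T = torch.zeros(n_iter + 1, n_iter)
    for j in range(n_iter):
        w = hvp(qs[j])              # one HVP per iteration
        for i in range(j + 1):      # Gram-Schmidt against the basis so far
            T[i, j] = w @ qs[i]
            w = w - T[i, j] * qs[i]
        T[j + 1, j] = w.norm()
        qs.append(w / T[j + 1, j])
    Q = torch.stack(qs[:n_iter], dim=1)  # shape [p, n_iter]
    return Q, T[:n_iter, :n_iter]

# Toy usage with an explicit symmetric matrix standing in for the Hessian:
# eigenpairs of the small matrix T approximate H's top eigenpairs, and the
# eigenvectors are mapped back to the original space via Q.
p, n_iter, k = 100, 30, 5
A = torch.randn(p, p)
H = A @ A.T
Q, T = arnoldi(lambda v: H @ v, torch.randn(p), n_iter)
eigvals, small_vecs = torch.linalg.eigh(T)  # T is (numerically) symmetric here
top_eigvecs = Q @ small_vecs[:, -k:]        # expressed in the original basis
```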
This implementation does incur some one-time overhead in __init__, where it runs the Arnoldi iteration to calculate $R$.
Unlike NaiveInfluenceFunction, this implementation does not flatten any parameters, as the 2D Hessian is never formed, and PyTorch's Hessian-vector product implementation (torch.autograd.functional.hvp) allows the input and output vectors to be tuples of tensors. Avoiding flattening / unflattening of parameters brings scalability gains.
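For instance, hvp can be applied directly to a tuple of parameter tensors of different shapes (the two-parameter "model" below is hypothetical, purely for illustration):

```python
import torch
from torch.autograd.functional import hvp

# A scalar loss over two parameter tensors of different shapes.
def loss(weight, bias):
    x = torch.ones(4, 3)
    return ((x @ weight + bias) ** 2).sum()

params = (torch.randn(3, 2), torch.randn(2))
v = (torch.randn(3, 2), torch.randn(2))  # direction vectors, same shapes

# hvp accepts and returns tuples of tensors: no flattening required.
loss_value, hv = hvp(loss, params, v)    # hv is a tuple shaped like params
```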
Feature Attribution Improvements
- Added initial support for asynchronous attribution (PyTorch futures) for the following methods (PRs #1295, #1316, #1317, #1314, #1320, #1326, #1335, #1487):
  - FeatureAblation
  - FeaturePermutation
  - ShapleyValueSampling
  - ShapleyValues
- Added support for additional gradient-based LLM attribution methods (PRs #1337, #1420):
  - LayerGradientXActivation
  - LayerGradientShap
- Added support for “key and value” caching in perturbation-based LLM attribution (PRs #1224, #1341, #1343, #1353)
- Added support for passing gradient keyword arguments to the following captum.attr methods through grad_kwargs (PRs #1286, #1294, #1435); a sketch follows after this list:
  - LayerGradCam
  - InternalInfluence
  - LayerConductance
  - LayerDeepLift
  - LayerGradientShap
  - NeuronConductance
  - LayerGradientXActivation
  - LayerIntegratedGradients
- Added a tutorial for perturbation- and gradient-based LLM attribution (tutorials/Llama2_LLM_Attribution.ipynb); a condensed sketch of the perturbation-based flow also follows after this list (PRs #1228, #1333, #1445)
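First, a sketch of the grad_kwargs feature: the snippet below forwards allow_unused to torch.autograd.grad through a layer attribution method. The toy model and layer are hypothetical, this assumes grad_kwargs is accepted by attribute as added in the PRs above, and which keyword arguments are useful depends on the method:

```python
import torch
from torch import nn
from captum.attr import LayerIntegratedGradients

# Hypothetical toy model with an embedding layer (names are illustrative).
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 8)
        self.linear = nn.Linear(8, 2)

    def forward(self, idx):
        return self.linear(self.embedding(idx).mean(dim=1))

model = ToyModel()
lig = LayerIntegratedGradients(model, model.embedding)
inputs = torch.randint(0, 100, (4, 5))

# grad_kwargs is forwarded to torch.autograd.grad under the hood.
attributions = lig.attribute(
    inputs,
    target=0,
    grad_kwargs={"allow_unused": True},
)
```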
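Second, a condensed sketch of the perturbation-based LLM attribution flow along the lines of the tutorial, assuming model and tokenizer are an already-loaded Hugging Face causal LM and its tokenizer (the prompt and target strings are placeholders):

```python
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

# Wrap a perturbation-based method for LLM prompt attribution.
llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)

# Attribute the target text back to the individual tokens of the prompt.
inp = TextTokenInput("Dave lives in Palm Coast, FL and is a lawyer.", tokenizer)
result = llm_attr.attribute(inp, target="He works at a law firm.")
result.plot_token_attr(show=True)
```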
Changes to Requirements
- We have dropped support for Python < 3.8 and PyTorch < 1.10 (PRs #1460, #1298, #1305)
- We plan to deprecate Captum Insights in the next major release (PR #1498)
Improvements to Type Annotations
Greatly improved typing throughout the library, now supporting and complying with the latest versions of both pyre and mypy type checking (PRs #1371, #1318, #1319, #1324, #1247, #1270, #1299, #1330, #1356, #1359, #1377, #1389, #1381, #1382, #1383, #1406, #1405, #1404, #1403, #1402, #1401, #1400, #1399, #1398, #1397, #1396, #1395, #1394, #1393, #1392, #1391, #1390, #1385, #1412, #1409, #1411, #1418, #1416, #1415, #1414, #1421, #1424, #1365, #1427, #1425, #1428, #1433, #1434, #1431, #1437, #1438, #1439, #1441, #1448, #1453, #1455, #1459, #1457, #1458, #1461, #1462, #1463, #1464, #1465, #1466, #1467, #1469, #1470, #1471, #1472, #1474, #1475, #1476, #1477, #1479, #1480, #1481, #1482, #1503, #1502)
Minor Changes and Fixes
- Added a fix to IntegratedGradients to fully support the MPS backend (PR #1227)
- Added support for the latest version of the black code formatter (PR #1241)
- Improved the test case coverage, logic, stability, and speed across Captum, especially for layer-based attribution methods, LLM attribution, and captum.influence methods and utilities (PRs #1250, #1251, #1253, #1258, #1243, #1249, #1252, #1259, #1260, #1262, #1264, #1265, #1272, #1300, #1301, #1302, #1323, #1352, #1362, #1364, #1388, #1408, #1410, #1419, #1422, #1436, #1454, #1484, #1485, #1492)
- Improved LLM attribution plotting aesthetics and text readability (PRs #1348, #1349, #1351, #1354, #1355, #1360, #1417)
- Freed autograd graphs between LLM attribution calls (PR #1347)
- Fixed a data type bug in the Titanic tutorial (tutorials/Titanic_Basic_Interpret.ipynb) (PR #1331)
- Fixed multiple device-related bugs for feature ablation/permutation masks and LLM attribution (PRs #1245, #1307)
- Reduced the complexity of various functions throughout Captum (PRs #1368, #1372, #1369, #1370, #1374, #1375, #1376, #1378, #1380, #1384, #1407)
- Fixed a bug in the tutorial parsing script (PR #1268)