Commit e4a0de9

4320 update docstrings for gradient based saliency (#5268)
Signed-off-by: Wenqi Li <[email protected]>

Fixes #4320. It seems we also have sufficient modules to close #3542.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <[email protected]>
1 parent c68e922 commit e4a0de9

File tree

2 files changed, +43 -0 lines


docs/source/visualize.rst

Lines changed: 8 additions & 0 deletions
@@ -25,7 +25,15 @@ Occlusion sensitivity
 .. automodule:: monai.visualize.occlusion_sensitivity
     :members:
 
+Gradient-based saliency maps
+----------------------------
+
+.. automodule:: monai.visualize.gradient_based
+    :members:
+
+
 Utilities
 ---------
+
 .. automodule:: monai.visualize.utils
     :members:

monai/visualize/gradient_based.py

Lines changed: 35 additions & 0 deletions
@@ -45,12 +45,29 @@ def backward(ctx, grad_output):
 
 
 class _GradReLU(torch.nn.Module):
+    """
+    A customized ReLU whose backward pass is modified for guided backpropagation (https://arxiv.org/abs/1412.6806).
+    """
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out: torch.Tensor = _AutoGradReLU.apply(x)
         return out
 
 
 class VanillaGrad:
+    """
+    Given an input image ``x``, calling this class will perform the forward pass, then set to zero
+    all activations except one (defined by ``index``) and propagate back to the image to achieve a gradient-based
+    saliency map.
+
+    If ``index`` is None, argmax of the output logits will be used.
+
+    See also:
+
+        - Simonyan et al. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps
+          (https://arxiv.org/abs/1312.6034)
+    """
+
     def __init__(self, model: torch.nn.Module) -> None:
         if not isinstance(model, ModelWithHooks):  # Convert to model with hooks if necessary
             self._model = ModelWithHooks(model, target_layer_names=(), register_backward=True)
@@ -83,7 +100,11 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
 
 class SmoothGrad(VanillaGrad):
     """
+    Compute an averaged sensitivity map over ``n_samples`` noisy versions (Gaussian additive noise)
+    of the input image ``x``.
+
     See also:
+
         - Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825
     """
 
@@ -126,12 +147,26 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
 
 
 class GuidedBackpropGrad(VanillaGrad):
+    """
+    Based on Springenberg and Dosovitskiy et al. https://arxiv.org/abs/1412.6806,
+    compute gradient-based saliency maps by propagating back only positive gradients through positive activations (see ``_AutoGradReLU``).
+
+    See also:
+
+        - Springenberg and Dosovitskiy et al. Striving for Simplicity: The All Convolutional Net
+          (https://arxiv.org/abs/1412.6806)
+    """
+
     def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
         with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
             return super().__call__(x, index, **kwargs)
 
 
 class GuidedBackpropSmoothGrad(SmoothGrad):
+    """
+    Compute gradient-based saliency maps based on both ``GuidedBackpropGrad`` and ``SmoothGrad``.
+    """
+
     def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
         with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
             return super().__call__(x, index, **kwargs)
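As a usage illustration (not part of this commit), here is a minimal sketch of how the documented classes might be called. The module path `monai.visualize.gradient_based`, the four class names, and the `__call__(x, index=None)` signature come from the diff above; the `TinyNet` toy model, input shape, and expected output are illustrative assumptions only.

```python
# Minimal sketch, assuming a toy single-channel 2D classifier; only the class names,
# module path and __call__(x, index=None) signature are taken from the diff above.
import torch

from monai.visualize.gradient_based import (
    GuidedBackpropGrad,
    GuidedBackpropSmoothGrad,
    SmoothGrad,
    VanillaGrad,
)


class TinyNet(torch.nn.Module):
    """Toy classifier (illustrative only): conv -> relu -> global pool -> linear."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()  # submodule named "relu" so guided backprop can swap it
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.conv(x))
        return self.fc(self.pool(x).flatten(1))


model = TinyNet().eval()
x = torch.rand(1, 1, 32, 32)  # batch of one single-channel 32x32 image

# Gradient of the argmax logit w.r.t. the input (Simonyan et al.).
vanilla_map = VanillaGrad(model)(x)

# Saliency for an explicit class index instead of the argmax.
class1_map = VanillaGrad(model)(x, index=1)

# Averaged saliency over Gaussian-noise-perturbed copies of x (Smilkov et al.);
# the number of noisy samples is configured via the constructor.
smooth_map = SmoothGrad(model)(x)

# Guided backpropagation: ReLU submodules are temporarily replaced with _GradReLU
# so that only positive gradients are propagated (Springenberg & Dosovitskiy et al.).
guided_map = GuidedBackpropGrad(model)(x)

# Guided backpropagation combined with SmoothGrad averaging.
guided_smooth_map = GuidedBackpropSmoothGrad(model)(x)

print(vanilla_map.shape)  # a saliency map, expected to align with the input's spatial size
```

The returned maps are plain `torch.Tensor` objects, so they can be inspected or plotted like any other image tensor.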
