Commit 07535f4 (1 parent: 45b7ae8)

Add AttentionExtract helper module

File tree: 5 files changed (+96, −5 lines)

  timm/models/__init__.py
  timm/models/_features.py
  timm/models/_features_fx.py
  timm/utils/__init__.py
  timm/utils/attention_extract.py

timm/models/__init__.py (1 addition, 1 deletion)

@@ -80,7 +80,7 @@
     set_pretrained_download_progress, set_pretrained_check_hash
 from ._factory import create_model, parse_model_name, safe_model_name
 from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
-from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \
     register_notrace_module, is_notrace_module, get_notrace_modules, \
     register_notrace_function, is_notrace_function, get_notrace_functions
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint

timm/models/_features.py (5 additions, 3 deletions)

@@ -158,7 +158,7 @@ class FeatureHooks:

     def __init__(
             self,
-            hooks: Sequence[str],
+            hooks: Sequence[Union[str, Dict]],
             named_modules: dict,
             out_map: Sequence[Union[int, str]] = None,
             default_hook_type: str = 'forward',
@@ -168,11 +168,13 @@ def __init__(
         self._handles = []
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
-            hook_name = h['module']
+            hook_name = h if isinstance(h, str) else h['module']
             m = modules[hook_name]
             hook_id = out_map[i] if out_map else hook_name
             hook_fn = partial(self._collect_output_hook, hook_id)
-            hook_type = h.get('hook_type', default_hook_type)
+            hook_type = default_hook_type
+            if isinstance(h, dict):
+                hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
                 handle = m.register_forward_pre_hook(hook_fn)
             elif hook_type == 'forward':
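
With this change, FeatureHooks accepts bare module-name strings alongside the existing dict entries. A minimal sketch of both spec styles (the toy model below is hypothetical, not from the commit):

    import torch
    from timm.models._features import FeatureHooks

    # Hypothetical toy model; named_modules() yields names '0' and '1'.
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

    # Existing dict form: explicit 'module' key, optional 'hook_type'.
    hooks_dict = FeatureHooks([{'module': '0', 'hook_type': 'forward'}], model.named_modules())

    # New string form: hook_type falls back to default_hook_type.
    hooks_str = FeatureHooks(['0'], model.named_modules(), default_hook_type='forward')

    out = model(torch.randn(1, 8))
    print(hooks_str.get_output(device=out.device))  # {'0': <output tensor of the Linear layer>}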

timm/models/_features_fx.py (10 additions, 1 deletion)

@@ -9,7 +9,9 @@
 from ._features import _get_feature_info, _get_return_layers

 try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
     from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
     has_fx_feature_extraction = True
 except ImportError:
     has_fx_feature_extraction = False
@@ -30,7 +32,7 @@

 __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
            'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
-           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
+           'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet']


 # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
@@ -92,6 +94,13 @@ def get_notrace_functions():
     return list(_autowrap_functions)


+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
+    )
+
+
 def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
     assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
     return _create_feature_extractor(
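
Since get_graph_node_names is now re-exported from timm.models (see the __init__.py change above), it can be used to discover which FX node names are available to match against. A minimal sketch, assuming a ViT built with non-fused attention so the softmax appears as an explicit graph node (the model name is illustrative):

    import fnmatch
    import timm
    from timm.layers import set_fused_attn
    from timm.models import get_graph_node_names

    set_fused_attn(False)  # keep attn.softmax as a traceable node rather than fused SDPA
    model = timm.create_model('vit_base_patch16_224', pretrained=False)

    # Returns (train_nodes, eval_nodes); the wrapper forwards timm's leaf-module and
    # autowrap definitions to torchvision's tracer so timm models trace cleanly.
    train_nodes, eval_nodes = get_graph_node_names(model)

    # Find the softmax nodes that AttentionExtract's default pattern targets.
    print(fnmatch.filter(eval_nodes, '*attn.softmax'))
    # e.g. ['blocks.0.attn.softmax', 'blocks.1.attn.softmax', ...]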

timm/utils/__init__.py (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 from .agc import adaptive_clip_grad
+from .attention_extract import AttentionExtract
 from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler

timm/utils/attention_extract.py (new file, +79 lines)

import fnmatch
from collections import OrderedDict
from typing import Union, Optional, List

import torch


class AttentionExtract(torch.nn.Module):
    # defaults should cover a significant number of timm models with attention maps.
    default_node_names = ['*attn.softmax']
    default_module_names = ['*attn_drop']

    def __init__(
            self,
            model: Union[torch.nn.Module],
            names: Optional[List[str]] = None,
            mode: str = 'eval',
            method: str = 'fx',
            hook_type: str = 'forward',
    ):
        """ Extract attention maps (or other activations) from a model by name.

        Args:
            model: Instantiated model to extract from.
            names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
            mode: 'train' or 'eval' model mode.
            method: 'fx' or 'hook' extraction method.
            hook_type: 'forward' or 'forward_pre' hooks used.
        """
        super().__init__()
        assert mode in ('train', 'eval')
        if mode == 'train':
            model = model.train()
        else:
            model = model.eval()

        assert method in ('fx', 'hook')
        if method == 'fx':
            # names are activation node names
            from timm.models._features_fx import get_graph_node_names, GraphExtractNet

            node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
            matched = []
            names = names or self.default_node_names
            for n in names:
                matched.extend(fnmatch.filter(node_names, n))
            if not matched:
                raise RuntimeError(f'No node names found matching {names}.')

            self.model = GraphExtractNet(model, matched)
            self.hooks = None
        else:
            # names are module names
            assert hook_type in ('forward', 'forward_pre')
            from timm.models._features import FeatureHooks

            module_names = [n for n, m in model.named_modules()]
            matched = []
            names = names or self.default_module_names
            for n in names:
                matched.extend(fnmatch.filter(module_names, n))
            if not matched:
                raise RuntimeError(f'No module names found matching {names}.')

            self.model = model
            self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)

        self.names = matched
        self.mode = mode
        self.method = method

    def forward(self, x):
        if self.hooks is not None:
            self.model(x)
            output = self.hooks.get_output(device=x.device)
        else:
            output = self.model(x)
            output = OrderedDict(zip(self.names, output))
        return output
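
A minimal usage sketch of the new helper (the model name and printed shapes are illustrative, not from the commit; non-fused attention is assumed so the softmax nodes exist in the traced graph):

    import timm
    import torch
    from timm.layers import set_fused_attn
    from timm.utils import AttentionExtract

    # The default '*attn.softmax' node pattern matches ViT-style attention blocks.
    set_fused_attn(False)
    model = timm.create_model('vit_small_patch16_224', pretrained=False)
    extractor = AttentionExtract(model, method='fx')

    x = torch.randn(1, 3, 224, 224)
    maps = extractor(x)  # OrderedDict mapping node name -> attention tensor

    for name, attn in maps.items():
        print(name, tuple(attn.shape))  # e.g. blocks.0.attn.softmax (1, 6, 197, 197)

Passing method='hook' instead registers FeatureHooks on modules matching the '*attn_drop' default pattern, which can serve as a fallback for models that are not cleanly FX-traceable.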
