Skip to content

Commit f81cbdc

Browse files
authored
Merge pull request #2274 from huggingface/bulk_runner_tweaks
Better all res resolution for bulk runner
2 parents 6c42299 + fa4a1e5 commit f81cbdc

File tree

3 files changed

+49
-19
lines changed

3 files changed

+49
-19
lines changed

bulk_runner.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Callable, List, Tuple, Union
2222

2323

24-
from timm.models import is_model, list_models, get_pretrained_cfg
24+
from timm.models import is_model, list_models, get_pretrained_cfg, get_arch_pretrained_cfgs
2525

2626

2727
parser = argparse.ArgumentParser(description='Per-model process launcher')
@@ -98,23 +98,44 @@ def _get_model_cfgs(
9898
num_classes=None,
9999
expand_train_test=False,
100100
include_crop=True,
101+
expand_arch=False,
101102
):
102-
model_cfgs = []
103-
for n in model_names:
104-
pt_cfg = get_pretrained_cfg(n)
105-
if num_classes is not None and getattr(pt_cfg, 'num_classes', 0) != num_classes:
106-
continue
107-
model_cfgs.append((n, pt_cfg.input_size[-1], pt_cfg.crop_pct))
108-
if expand_train_test and pt_cfg.test_input_size is not None:
109-
if pt_cfg.test_crop_pct is not None:
110-
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.test_crop_pct))
103+
model_cfgs = set()
104+
105+
for name in model_names:
106+
if expand_arch:
107+
pt_cfgs = get_arch_pretrained_cfgs(name).values()
108+
else:
109+
pt_cfg = get_pretrained_cfg(name)
110+
pt_cfgs = [pt_cfg] if pt_cfg is not None else []
111+
112+
for cfg in pt_cfgs:
113+
if cfg.input_size is None:
114+
continue
115+
if num_classes is not None and getattr(cfg, 'num_classes', 0) != num_classes:
116+
continue
117+
118+
# Add main configuration
119+
size = cfg.input_size[-1]
120+
if include_crop:
121+
model_cfgs.add((name, size, cfg.crop_pct))
111122
else:
112-
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.crop_pct))
123+
model_cfgs.add((name, size))
124+
125+
# Add test configuration if required
126+
if expand_train_test and cfg.test_input_size is not None:
127+
test_size = cfg.test_input_size[-1]
128+
if include_crop:
129+
test_crop = cfg.test_crop_pct or cfg.crop_pct
130+
model_cfgs.add((name, test_size, test_crop))
131+
else:
132+
model_cfgs.add((name, test_size))
133+
134+
# Format the output
113135
if include_crop:
114-
model_cfgs = [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)]
136+
return [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)]
115137
else:
116-
model_cfgs = [(n, {'img-size': r}) for n, r, cp in sorted(model_cfgs)]
117-
return model_cfgs
138+
return [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
118139

119140

120141
def main():
@@ -132,17 +153,16 @@ def main():
132153
model_cfgs = _get_model_cfgs(model_names, num_classes=1000, expand_train_test=True)
133154
elif args.model_list == 'all_res':
134155
model_names = list_models()
135-
model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False)
156+
model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False, expand_arch=True)
136157
elif not is_model(args.model_list):
137158
# model name doesn't exist, try as wildcard filter
138159
model_names = list_models(args.model_list)
139160
model_cfgs = [(n, None) for n in model_names]
140161

141162
if not model_cfgs and os.path.exists(args.model_list):
142163
with open(args.model_list) as f:
143-
model_cfgs = []
144164
model_names = [line.rstrip() for line in f]
145-
_get_model_cfgs(
165+
model_cfgs = _get_model_cfgs(
146166
model_names,
147167
#num_classes=1000,
148168
expand_train_test=True,

timm/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,5 @@
9595
from ._prune import adapt_model_from_string
9696
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
9797
register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
98-
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
98+
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value, \
99+
get_arch_pretrained_cfgs

timm/models/_registry.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
__all__ = [
1717
'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
1818
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
19-
'get_pretrained_cfg_value', 'is_model_pretrained'
19+
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_pretrained_cfgs_for_arch'
2020
]
2121

2222
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
@@ -341,3 +341,12 @@ def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
341341
"""
342342
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
343343
return getattr(cfg, cfg_key, None)
344+
345+
346+
def get_arch_pretrained_cfgs(model_name: str) -> Dict[str, PretrainedCfg]:
347+
""" Get all pretrained cfgs for a given architecture.
348+
"""
349+
arch_name, _ = split_model_name_tag(model_name)
350+
model_names = _model_with_tags[arch_name]
351+
cfgs = {m: _model_pretrained_cfgs[m] for m in model_names}
352+
return cfgs

0 commit comments

Comments
 (0)