21
21
from typing import Callable , List , Tuple , Union
22
22
23
23
24
- from timm .models import is_model , list_models , get_pretrained_cfg
24
+ from timm .models import is_model , list_models , get_pretrained_cfg , get_arch_pretrained_cfgs
25
25
26
26
27
27
parser = argparse .ArgumentParser (description = 'Per-model process launcher' )
@@ -98,23 +98,44 @@ def _get_model_cfgs(
98
98
num_classes = None ,
99
99
expand_train_test = False ,
100
100
include_crop = True ,
101
+ expand_arch = False ,
101
102
):
102
- model_cfgs = []
103
- for n in model_names :
104
- pt_cfg = get_pretrained_cfg (n )
105
- if num_classes is not None and getattr (pt_cfg , 'num_classes' , 0 ) != num_classes :
106
- continue
107
- model_cfgs .append ((n , pt_cfg .input_size [- 1 ], pt_cfg .crop_pct ))
108
- if expand_train_test and pt_cfg .test_input_size is not None :
109
- if pt_cfg .test_crop_pct is not None :
110
- model_cfgs .append ((n , pt_cfg .test_input_size [- 1 ], pt_cfg .test_crop_pct ))
103
+ model_cfgs = set ()
104
+
105
+ for name in model_names :
106
+ if expand_arch :
107
+ pt_cfgs = get_arch_pretrained_cfgs (name ).values ()
108
+ else :
109
+ pt_cfg = get_pretrained_cfg (name )
110
+ pt_cfgs = [pt_cfg ] if pt_cfg is not None else []
111
+
112
+ for cfg in pt_cfgs :
113
+ if cfg .input_size is None :
114
+ continue
115
+ if num_classes is not None and getattr (cfg , 'num_classes' , 0 ) != num_classes :
116
+ continue
117
+
118
+ # Add main configuration
119
+ size = cfg .input_size [- 1 ]
120
+ if include_crop :
121
+ model_cfgs .add ((name , size , cfg .crop_pct ))
111
122
else :
112
- model_cfgs .append ((n , pt_cfg .test_input_size [- 1 ], pt_cfg .crop_pct ))
123
+ model_cfgs .add ((name , size ))
124
+
125
+ # Add test configuration if required
126
+ if expand_train_test and cfg .test_input_size is not None :
127
+ test_size = cfg .test_input_size [- 1 ]
128
+ if include_crop :
129
+ test_crop = cfg .test_crop_pct or cfg .crop_pct
130
+ model_cfgs .add ((name , test_size , test_crop ))
131
+ else :
132
+ model_cfgs .add ((name , test_size ))
133
+
134
+ # Format the output
113
135
if include_crop :
114
- model_cfgs = [(n , {'img-size' : r , 'crop-pct' : cp }) for n , r , cp in sorted (model_cfgs )]
136
+ return [(n , {'img-size' : r , 'crop-pct' : cp }) for n , r , cp in sorted (model_cfgs )]
115
137
else :
116
- model_cfgs = [(n , {'img-size' : r }) for n , r , cp in sorted (model_cfgs )]
117
- return model_cfgs
138
+ return [(n , {'img-size' : r }) for n , r in sorted (model_cfgs )]
118
139
119
140
120
141
def main ():
@@ -132,17 +153,16 @@ def main():
132
153
model_cfgs = _get_model_cfgs (model_names , num_classes = 1000 , expand_train_test = True )
133
154
elif args .model_list == 'all_res' :
134
155
model_names = list_models ()
135
- model_cfgs = _get_model_cfgs (model_names , expand_train_test = True , include_crop = False )
156
+ model_cfgs = _get_model_cfgs (model_names , expand_train_test = True , include_crop = False , expand_arch = True )
136
157
elif not is_model (args .model_list ):
137
158
# model name doesn't exist, try as wildcard filter
138
159
model_names = list_models (args .model_list )
139
160
model_cfgs = [(n , None ) for n in model_names ]
140
161
141
162
if not model_cfgs and os .path .exists (args .model_list ):
142
163
with open (args .model_list ) as f :
143
- model_cfgs = []
144
164
model_names = [line .rstrip () for line in f ]
145
- _get_model_cfgs (
165
+ model_cfgs = _get_model_cfgs (
146
166
model_names ,
147
167
#num_classes=1000,
148
168
expand_train_test = True ,
0 commit comments