4 files changed: +13 -5 lines

File 1 of 4

@@ -120,6 +120,8 @@
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -224,6 +226,7 @@ def __init__(
         device='cuda',
         torchscript=False,
         torchcompile=None,
+        torchcompile_mode=None,
         aot_autograd=False,
         reparam=False,
         precision='float32',
@@ -278,7 +281,7 @@ def __init__(
         elif torchcompile:
             assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            self.model = torch.compile(self.model, backend=torchcompile)
+            self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
             self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
File 2 of 4

@@ -114,6 +114,8 @@
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -216,7 +218,7 @@ def main():
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
File 3 of 4

@@ -161,6 +161,8 @@
                    help='Head initialization scale')
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
+group.add_argument('--torchcompile-mode', type=str, default=None,
+                   help="torch.compile mode (default: None).")
 
 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
@@ -627,7 +629,7 @@ def main():
     if args.torchcompile:
         # torch compile should be done after DDP
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
 
     # create the train and eval datasets
     if args.data and not args.data_dir:
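The hunk above preserves the script's noted ordering: torch.compile is applied only after the model has been wrapped for distributed training. A minimal sketch of that ordering, assuming a process group is already initialized (the model and device id below are placeholders, not from the PR):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already run.
model = torch.nn.Linear(32, 32).cuda()
model = DDP(model, device_ids=[0])  # wrap for distributed training first
# ...then compile the wrapped model, forwarding the new mode flag (None = default mode)
model = torch.compile(model, backend='inductor', mode=None)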
File 4 of 4

@@ -139,7 +139,8 @@
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
-
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -246,7 +247,7 @@ def validate(args):
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
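Across all four scripts the change is the same: a new --torchcompile-mode CLI flag whose value is forwarded as the mode argument of torch.compile, with the default of None preserving the previous behavior. A self-contained sketch of the resulting call, with a placeholder model standing in for the scripts' models (per the PyTorch docs, valid modes include 'default', 'reduce-overhead', and 'max-autotune'):

import torch
import torch.nn as nn

# Placeholder model for illustration only.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))

# Equivalent of passing --torchcompile inductor --torchcompile-mode reduce-overhead
compiled = torch.compile(model, backend='inductor', mode='reduce-overhead')

x = torch.randn(8, 64)
y = compiled(x)  # first call triggers compilation; later calls reuse the compiled graph

Because the flag defaults to None, existing invocations that pass only --torchcompile behave exactly as before.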