Skip to content

Commit 4d4bdd6

Browse files
committed
Add --torchcompile-mode arg to train, validate, inference, benchmark scripts
1 parent 14d55a7 commit 4d4bdd6

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

benchmark.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@
120120
parser.add_argument('--reparam', default=False, action='store_true',
121121
help='Reparameterize model')
122122
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
123+
parser.add_argument('--torchcompile-mode', type=str, default=None,
124+
help="torch.compile mode (default: None).")
123125

124126
# codegen (model compilation) options
125127
scripting_group = parser.add_mutually_exclusive_group()
@@ -224,6 +226,7 @@ def __init__(
224226
device='cuda',
225227
torchscript=False,
226228
torchcompile=None,
229+
torchcompile_mode=None,
227230
aot_autograd=False,
228231
reparam=False,
229232
precision='float32',
@@ -278,7 +281,7 @@ def __init__(
278281
elif torchcompile:
279282
assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
280283
torch._dynamo.reset()
281-
self.model = torch.compile(self.model, backend=torchcompile)
284+
self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
282285
self.compiled = True
283286
elif aot_autograd:
284287
assert has_functorch, "functorch is needed for --aot-autograd"

inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@
114114
parser.add_argument('--fuser', default='', type=str,
115115
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
116116
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
117+
parser.add_argument('--torchcompile-mode', type=str, default=None,
118+
help="torch.compile mode (default: None).")
117119

118120
scripting_group = parser.add_mutually_exclusive_group()
119121
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -216,7 +218,7 @@ def main():
216218
elif args.torchcompile:
217219
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
218220
torch._dynamo.reset()
219-
model = torch.compile(model, backend=args.torchcompile)
221+
model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
220222
elif args.aot_autograd:
221223
assert has_functorch, "functorch is needed for --aot-autograd"
222224
model = memory_efficient_fusion(model)

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@
161161
help='Head initialization scale')
162162
group.add_argument('--head-init-bias', default=None, type=float,
163163
help='Head initialization bias value')
164+
group.add_argument('--torchcompile-mode', type=str, default=None,
165+
help="torch.compile mode (default: None).")
164166

165167
# scripting / codegen
166168
scripting_group = group.add_mutually_exclusive_group()
@@ -627,7 +629,7 @@ def main():
627629
if args.torchcompile:
628630
# torch compile should be done after DDP
629631
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
630-
model = torch.compile(model, backend=args.torchcompile)
632+
model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
631633

632634
# create the train and eval datasets
633635
if args.data and not args.data_dir:

validate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@
139139
parser.add_argument('--reparam', default=False, action='store_true',
140140
help='Reparameterize model')
141141
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
142-
142+
parser.add_argument('--torchcompile-mode', type=str, default=None,
143+
help="torch.compile mode (default: None).")
143144

144145
scripting_group = parser.add_mutually_exclusive_group()
145146
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -246,7 +247,7 @@ def validate(args):
246247
elif args.torchcompile:
247248
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
248249
torch._dynamo.reset()
249-
model = torch.compile(model, backend=args.torchcompile)
250+
model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
250251
elif args.aot_autograd:
251252
assert has_functorch, "functorch is needed for --aot-autograd"
252253
model = memory_efficient_fusion(model)

0 commit comments

Comments (0)