Skip to content

Commit 884c5c8

Browse files
janselpytorchmergebot
authored andcommitted
Pass torch.compile mode/options to all backends (pytorch#99645)
Pull Request resolved: pytorch#99645 Approved by: https://github.com/anijain2305
1 parent 7295ab6 commit 884c5c8

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

test/inductor/test_config.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,40 @@ def test_invalid_backend(self):
190190
lambda: torch.compile(dummy_fn, backend="does_not_exist")(torch.randn(10)),
191191
)
192192

193+
def test_non_inductor_backend(self):
194+
def assert_options(expected_mode=None, expected_options=None):
195+
def backend(gm, _, *, mode=None, options=None):
196+
nonlocal call_count
197+
self.assertEqual(mode, expected_mode)
198+
self.assertEqual(options, expected_options)
199+
call_count += 1
200+
return gm
201+
202+
return backend
203+
204+
inp = torch.randn(8)
205+
206+
def fn(x):
207+
return x + 1
208+
209+
for mode, options in [
210+
(None, None),
211+
("fast-mode", None),
212+
(None, {"foo": "bar"}),
213+
]:
214+
call_count = 0
215+
torch.compile(
216+
fn, backend=assert_options(mode, options), mode=mode, options=options
217+
)(inp)
218+
torch._dynamo.reset()
219+
self.assertEqual(call_count, 1)
220+
221+
# TypeError: eager() got an unexpected keyword argument 'mode'
222+
self.assertRaises(
223+
torch._dynamo.exc.BackendCompilerFailed,
224+
lambda: torch.compile(fn, backend="eager", mode="nope")(inp),
225+
)
226+
193227

194228
if __name__ == "__main__":
195229
run_tests()

torch/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,35 @@ def reset(self):
15361536
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
15371537
reset_cudagraph_trees()
15381538

1539+
class _TorchCompileWrapper:
1540+
def __init__(self, backend, mode, options, dynamic):
1541+
from torch._dynamo.backends.registry import lookup_backend
1542+
1543+
if isinstance(backend, str):
1544+
self.compiler_name = backend
1545+
elif hasattr(backend, "__name__"):
1546+
self.compiler_name = backend.__name__
1547+
else:
1548+
self.compiler_name = str(backend)
1549+
self.dynamic = dynamic
1550+
self.compiler_fn = lookup_backend(backend)
1551+
self.kwargs = {}
1552+
# only pass the args if they non-empty
1553+
if mode and mode != "default":
1554+
self.kwargs["mode"] = mode
1555+
if options:
1556+
self.kwargs["options"] = options
1557+
1558+
def __eq__(self, other):
1559+
return (isinstance(other, _TorchCompileWrapper) and
1560+
self.compiler_fn == other.compiler_fn and
1561+
self.kwargs == other.kwargs and
1562+
self.dynamic == other.dynamic)
1563+
1564+
def __call__(self, model_, inputs_):
1565+
return self.compiler_fn(model_, inputs_, **self.kwargs)
1566+
1567+
15391568
def compile(model: Optional[Callable] = None, *,
15401569
fullgraph: builtins.bool = False,
15411570
dynamic: builtins.bool = False,
@@ -1600,6 +1629,8 @@ def fn(model: Callable):
16001629
mode = "default"
16011630
if backend == "inductor":
16021631
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
1632+
else:
1633+
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
16031634

16041635
return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
16051636

torch/_dynamo/repro/after_dynamo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def add_paths(exc):
111111
return compiled_gm
112112

113113
debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined]
114-
114+
if hasattr(unconfigured_compiler_fn, "compiler_name"):
115+
debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name
115116
return debug_wrapper
116117

117118

0 commit comments

Comments
 (0)