Skip to content

Commit d2e6750

Browse files
mcremon-metafacebook-github-bot
authored and committed
FF jarvis (#178)
Summary: Pull Request resolved: #178 Title Edit: commandeered the diff from tugsbayasgalan to land it while he's on PTO and unblock the Jarvis CI. Reviewed By: JacobSzwejbka Differential Revision: D48838464 fbshipit-source-id: f0cacd8cf4f1081c23a922775866159ad7e8e703
1 parent bba84db commit d2e6750

File tree

5 files changed

+143
-18
lines changed

5 files changed

+143
-18
lines changed

exir/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
from executorch.exir.capture import (
10+
_capture_legacy_do_not_use,
1011
capture,
1112
capture_multiple,
1213
CaptureConfig,
@@ -38,6 +39,7 @@
3839
"EmitterOutput",
3940
"capture",
4041
"capture_multiple",
42+
"_capture_legacy_do_not_use",
4143
"CallSpec",
4244
"ExportedProgram",
4345
"ExirExportedProgram",

exir/capture/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python_library(
2222
"//executorch/exir:error",
2323
"//executorch/exir:tracer",
2424
"//executorch/exir/program:lib",
25+
"//executorch/exir/program:program",
2526
],
2627
)
2728

exir/capture/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66

77
# pyre-strict
88

9-
from executorch.exir.capture._capture import capture, capture_multiple
9+
from executorch.exir.capture._capture import (
10+
_capture_legacy_do_not_use,
11+
capture,
12+
capture_multiple,
13+
)
14+
1015
from executorch.exir.capture._config import (
1116
CaptureConfig,
1217
EdgeCompileConfig,
@@ -15,6 +20,7 @@
1520

1621
__all__ = [
1722
"capture",
23+
"_capture_legacy_do_not_use",
1824
"capture_multiple",
1925
"CaptureConfig",
2026
"EdgeCompileConfig",

exir/capture/_capture.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
from executorch.exir.capture._config import CaptureConfig
1616
from executorch.exir.error import ExportError, ExportErrorType, InternalError
1717
from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
18+
from executorch.exir.program._program import (
19+
HackedUpExportedProgramDONOTUSE,
20+
transform_exported_program,
21+
)
1822
from executorch.exir.tracer import (
1923
_default_decomposition_table,
2024
dispatch_trace,
@@ -43,6 +47,38 @@
4347
)
4448

4549

50+
@compatibility(is_backward_compatible=False)
51+
def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
52+
"""
53+
This is a legacy API that should be avoided. Prefer to use capture() instead.
54+
"""
55+
warnings.warn(
56+
"This function is now deprecated, please use `capture` instead. "
57+
"See https://github.com/pytorch/functorch for more details.",
58+
DeprecationWarning,
59+
)
60+
61+
graph_module = dispatch_trace(f, args)
62+
flat_args = tuple(pytree.tree_flatten(args)[0])
63+
in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
64+
65+
_instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
66+
graph_module._apply(torch.Tensor.contiguous)
67+
68+
ep = HackedUpExportedProgramDONOTUSE(
69+
graph_module,
70+
graph_module.graph,
71+
ExportGraphSignature([], [], [], [], {}, {}, {}, None),
72+
CallSpec(in_spec, out_spec),
73+
{},
74+
{},
75+
[],
76+
[],
77+
None,
78+
)
79+
return ExirExportedProgram(ep, False)
80+
81+
4682
@compatibility(is_backward_compatible=False)
4783
def capture(
4884
f: Callable[..., Any],
@@ -183,23 +219,7 @@ def convert_to_fake(x):
183219
flatten_output(graph_module)
184220

185221
else:
186-
warnings.warn(
187-
"exir.capture with pt2_mode=False is deprecated. Please use the default () instead."
188-
)
189-
if not config.enable_functionalization:
190-
raise InternalError(
191-
"Can only disable functionalization under exir.capture() pt2 mode."
192-
)
193-
if config.enable_dynamic_shape:
194-
raise InternalError(
195-
"Can only enable dynamic shape tracing under exir.capture() pt2 mode."
196-
)
197-
if config.enable_aot:
198-
raise InternalError(
199-
"Using AOT mode is not supported for leagacy capture mode, please use instead."
200-
)
201-
graph_module = dispatch_trace(f, args)
202-
in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
222+
raise InternalError("pt2=False path is officially deprecated")
203223

204224
_instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
205225
graph_module._apply(torch.Tensor.contiguous)

exir/program/_program.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,107 @@
2929
EXIREdgeDialectVerifier,
3030
)
3131
from torch._export import ExportedProgram
32+
from torch.fx import _pytree as fx_pytree
3233
from torch.fx._compatibility import compatibility
34+
from torch.utils import _pytree as pytree
3335

3436
Val = Any
3537

3638

39+
# Stub to ease migration from `transform` to private `_transform`
40+
def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
41+
if hasattr(ep, "_transform"):
42+
return ep._transform(*passes)
43+
else:
44+
return ep.transform(*passes)
45+
46+
47+
class HackedUpExportedProgramDONOTUSE(ExportedProgram):
48+
def __init__(
49+
self,
50+
root,
51+
graph,
52+
graph_signature,
53+
call_spec,
54+
state_dict,
55+
range_constraints,
56+
equality_constraints,
57+
module_call_graph,
58+
example_inputs,
59+
):
60+
super().__init__(
61+
root,
62+
graph,
63+
graph_signature,
64+
call_spec,
65+
state_dict,
66+
range_constraints,
67+
equality_constraints,
68+
module_call_graph,
69+
example_inputs,
70+
)
71+
72+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
73+
import torch._export.error as error
74+
from torch._export import combine_args_kwargs
75+
76+
if self.call_spec.in_spec is not None:
77+
user_args = combine_args_kwargs(args, kwargs)
78+
try:
79+
args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec) # type: ignore[assignment]
80+
except Exception:
81+
_, received_spec = pytree.tree_flatten(user_args)
82+
raise error.InternalError(
83+
"Trying to flatten user inputs with exported input tree spec: \n"
84+
f"{self.call_spec.in_spec}\n"
85+
"but actually got inputs with tree spec of: \n"
86+
f"{received_spec}"
87+
)
88+
89+
ordered_params = tuple(
90+
self.state_dict[name] for name in self.graph_signature.parameters
91+
)
92+
ordered_buffers = tuple(
93+
self.state_dict[name] for name in self.graph_signature.buffers
94+
)
95+
96+
with torch.no_grad():
97+
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
98+
# See: torch/_functorch/aot_autograd.py#L1034
99+
res = torch.fx.Interpreter(self.graph_module).run(
100+
*ordered_params, *ordered_buffers, *args, enable_io_processing=False
101+
)
102+
103+
if self.call_spec.out_spec is not None:
104+
mutation = self.graph_signature.buffers_to_mutate
105+
num_mutated = len(mutation)
106+
mutated_buffers = res[:num_mutated]
107+
108+
# Exclude dependency token from final result.
109+
assertion_dep_token = self.graph_signature.assertion_dep_token
110+
if assertion_dep_token is not None:
111+
assertion_dep_token_index = list(assertion_dep_token.keys())[0]
112+
res = res[:assertion_dep_token_index]
113+
114+
res = res[num_mutated:]
115+
try:
116+
res = pytree.tree_unflatten(res, self.call_spec.out_spec)
117+
except Exception:
118+
_, received_spec = pytree.tree_flatten(res)
119+
raise error.InternalError(
120+
"Trying to flatten user outputs with exported output tree spec: \n"
121+
f"{self.call_spec.out_spec}\n"
122+
"but actually got outputs with tree spec of: \n"
123+
f"{received_spec}"
124+
)
125+
finally:
126+
ix = 0
127+
for buffer in self.graph_signature.buffers_to_mutate.values():
128+
self.state_dict[buffer] = mutated_buffers[ix]
129+
ix += 1
130+
return res
131+
132+
37133
@compatibility(is_backward_compatible=False)
38134
class ExirExportedProgram:
39135
def __init__(

0 commit comments

Comments
 (0)