Skip to content

Commit 37f1555

Browse files
tugsbayasgalan and facebook-github-bot
authored and committed
FF jarvis (#178)
Summary: Title Reviewed By: mcremon-meta Differential Revision: D48838464
1 parent d93a7d0 commit 37f1555

File tree

4 files changed

+131
-18
lines changed

4 files changed

+131
-18
lines changed

exir/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
from executorch.exir.capture import (
10+
_capture_legacy_do_not_use,
1011
capture,
1112
capture_multiple,
1213
CaptureConfig,
@@ -38,6 +39,7 @@
3839
"EmitterOutput",
3940
"capture",
4041
"capture_multiple",
42+
"_capture_legacy_do_not_use",
4143
"CallSpec",
4244
"ExportedProgram",
4345
"ExirExportedProgram",

exir/capture/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66

77
# pyre-strict
88

9-
from executorch.exir.capture._capture import capture, capture_multiple
9+
from executorch.exir.capture._capture import (
10+
_capture_legacy_do_not_use,
11+
capture,
12+
capture_multiple,
13+
)
14+
1015
from executorch.exir.capture._config import (
1116
CaptureConfig,
1217
EdgeCompileConfig,
@@ -15,6 +20,7 @@
1520

1621
__all__ = [
1722
"capture",
23+
"_capture_legacy_do_not_use",
1824
"capture_multiple",
1925
"CaptureConfig",
2026
"EdgeCompileConfig",

exir/capture/_capture.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from executorch.exir.capture._config import CaptureConfig
1616
from executorch.exir.error import ExportError, ExportErrorType, InternalError
1717
from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
18+
from executorch.exir.program._program import HackedUpExportedProgramDONOTUSE
1819
from executorch.exir.tracer import (
1920
_default_decomposition_table,
2021
dispatch_trace,
@@ -43,6 +44,38 @@
4344
)
4445

4546

47+
@compatibility(is_backward_compatible=False)
48+
def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
49+
"""
50+
This is a legacy API that should be avoided. Prefer to use capture() instead.
51+
"""
52+
warnings.warn(
53+
"This function is now deprecated, please use `capture` instead. "
54+
"See https://github.com/pytorch/functorch for more details.",
55+
DeprecationWarning,
56+
)
57+
58+
graph_module = dispatch_trace(f, args)
59+
flat_args = tuple(pytree.tree_flatten(args)[0])
60+
in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
61+
62+
_instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
63+
graph_module._apply(torch.Tensor.contiguous)
64+
65+
ep = HackedUpExportedProgramDONOTUSE(
66+
graph_module,
67+
graph_module.graph,
68+
ExportGraphSignature([], [], [], [], {}, {}, {}, None),
69+
CallSpec(in_spec, out_spec),
70+
{},
71+
{},
72+
[],
73+
[],
74+
None,
75+
)
76+
return ExirExportedProgram(ep, False)
77+
78+
4679
@compatibility(is_backward_compatible=False)
4780
def capture(
4881
f: Callable[..., Any],
@@ -183,23 +216,7 @@ def convert_to_fake(x):
183216
flatten_output(graph_module)
184217

185218
else:
186-
warnings.warn(
187-
"exir.capture with pt2_mode=False is deprecated. Please use the default () instead."
188-
)
189-
if not config.enable_functionalization:
190-
raise InternalError(
191-
"Can only disable functionalization under exir.capture() pt2 mode."
192-
)
193-
if config.enable_dynamic_shape:
194-
raise InternalError(
195-
"Can only enable dynamic shape tracing under exir.capture() pt2 mode."
196-
)
197-
if config.enable_aot:
198-
raise InternalError(
199-
"Using AOT mode is not supported for leagacy capture mode, please use instead."
200-
)
201-
graph_module = dispatch_trace(f, args)
202-
in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
219+
raise InternalError("pt2=False path is officially deprecated")
203220

204221
_instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
205222
graph_module._apply(torch.Tensor.contiguous)

exir/program/_program.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,99 @@
2929
EXIREdgeDialectVerifier,
3030
)
3131
from torch._export import ExportedProgram
32+
from torch.fx import _pytree as fx_pytree
3233
from torch.fx._compatibility import compatibility
34+
from torch.utils import _pytree as pytree
3335

3436
Val = Any
3537

3638

39+
class HackedUpExportedProgramDONOTUSE(ExportedProgram):
40+
def __init__(
41+
self,
42+
root,
43+
graph,
44+
graph_signature,
45+
call_spec,
46+
state_dict,
47+
range_constraints,
48+
equality_constraints,
49+
module_call_graph,
50+
example_inputs,
51+
):
52+
super().__init__(
53+
root,
54+
graph,
55+
graph_signature,
56+
call_spec,
57+
state_dict,
58+
range_constraints,
59+
equality_constraints,
60+
module_call_graph,
61+
example_inputs,
62+
)
63+
64+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
65+
import torch._export.error as error
66+
from torch._export import combine_args_kwargs
67+
68+
if self.call_spec.in_spec is not None:
69+
user_args = combine_args_kwargs(args, kwargs)
70+
try:
71+
args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec) # type: ignore[assignment]
72+
except Exception:
73+
_, received_spec = pytree.tree_flatten(user_args)
74+
raise error.InternalError(
75+
"Trying to flatten user inputs with exported input tree spec: \n"
76+
f"{self.call_spec.in_spec}\n"
77+
"but actually got inputs with tree spec of: \n"
78+
f"{received_spec}"
79+
)
80+
81+
ordered_params = tuple(
82+
self.state_dict[name] for name in self.graph_signature.parameters
83+
)
84+
ordered_buffers = tuple(
85+
self.state_dict[name] for name in self.graph_signature.buffers
86+
)
87+
88+
with torch.no_grad():
89+
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
90+
# See: torch/_functorch/aot_autograd.py#L1034
91+
res = torch.fx.Interpreter(self.graph_module).run(
92+
*ordered_params, *ordered_buffers, *args, enable_io_processing=False
93+
)
94+
95+
if self.call_spec.out_spec is not None:
96+
mutation = self.graph_signature.buffers_to_mutate
97+
num_mutated = len(mutation)
98+
mutated_buffers = res[:num_mutated]
99+
100+
# Exclude dependency token from final result.
101+
assertion_dep_token = self.graph_signature.assertion_dep_token
102+
if assertion_dep_token is not None:
103+
assertion_dep_token_index = list(assertion_dep_token.keys())[0]
104+
res = res[:assertion_dep_token_index]
105+
106+
res = res[num_mutated:]
107+
try:
108+
res = pytree.tree_unflatten(res, self.call_spec.out_spec)
109+
except Exception:
110+
_, received_spec = pytree.tree_flatten(res)
111+
raise error.InternalError(
112+
"Trying to flatten user outputs with exported output tree spec: \n"
113+
f"{self.call_spec.out_spec}\n"
114+
"but actually got outputs with tree spec of: \n"
115+
f"{received_spec}"
116+
)
117+
finally:
118+
ix = 0
119+
for buffer in self.graph_signature.buffers_to_mutate.values():
120+
self.state_dict[buffer] = mutated_buffers[ix]
121+
ix += 1
122+
return res
123+
124+
37125
@compatibility(is_backward_compatible=False)
38126
class ExirExportedProgram:
39127
def __init__(

0 commit comments

Comments
 (0)