FF jarvis #178

Closed. Wants to merge 1 commit.

Changes from all commits
2 changes: 2 additions & 0 deletions exir/__init__.py
@@ -7,6 +7,7 @@
 from typing import Any
 
 from executorch.exir.capture import (
+    _capture_legacy_do_not_use,
     capture,
     capture_multiple,
     CaptureConfig,
@@ -38,6 +39,7 @@
     "EmitterOutput",
     "capture",
     "capture_multiple",
+    "_capture_legacy_do_not_use",
     "CallSpec",
     "ExportedProgram",
     "ExirExportedProgram",
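With the two added lines above, the legacy entry point becomes importable from the package root alongside `capture`. A minimal sanity check, assuming an executorch build that includes this patch:

import executorch.exir as exir

# Both the public API and the legacy escape hatch resolve from the root module.
assert callable(exir.capture)
assert callable(exir._capture_legacy_do_not_use)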
8 changes: 7 additions & 1 deletion exir/capture/__init__.py
@@ -6,7 +6,12 @@
 
 # pyre-strict
 
-from executorch.exir.capture._capture import capture, capture_multiple
+from executorch.exir.capture._capture import (
+    _capture_legacy_do_not_use,
+    capture,
+    capture_multiple,
+)
+
 from executorch.exir.capture._config import (
     CaptureConfig,
     EdgeCompileConfig,
@@ -15,6 +20,7 @@
 
 __all__ = [
     "capture",
+    "_capture_legacy_do_not_use",
     "capture_multiple",
     "CaptureConfig",
     "EdgeCompileConfig",
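One subtlety worth noting: names beginning with an underscore are normally skipped by `from module import *`, but an explicit `__all__` overrides that, so the listing above deliberately re-exports the legacy name. A small illustrative check, again assuming this patch is installed:

import executorch.exir.capture as capture_pkg

# __all__ controls star-imports; listing the underscore-prefixed name keeps
# the legacy symbol visible even though its name marks it as private.
assert "_capture_legacy_do_not_use" in capture_pkg.__all__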
51 changes: 34 additions & 17 deletions exir/capture/_capture.py
@@ -15,6 +15,7 @@
 from executorch.exir.capture._config import CaptureConfig
 from executorch.exir.error import ExportError, ExportErrorType, InternalError
 from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
+from executorch.exir.program._program import HackedUpExportedProgramDONOTUSE
 from executorch.exir.tracer import (
     _default_decomposition_table,
     dispatch_trace,
@@ -43,6 +44,38 @@
 )
 
 
+@compatibility(is_backward_compatible=False)
+def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
+    """
+    This is a legacy API that should be avoided. Prefer to use capture() instead.
+    """
+    warnings.warn(
+        "This function is now deprecated, please use `capture` instead. "
+        "See https://github.com/pytorch/functorch for more details.",
+        DeprecationWarning,
+    )
+
+    graph_module = dispatch_trace(f, args)
+    flat_args = tuple(pytree.tree_flatten(args)[0])
+    in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
+
+    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
+    graph_module._apply(torch.Tensor.contiguous)
+
+    ep = HackedUpExportedProgramDONOTUSE(
+        graph_module,
+        graph_module.graph,
+        ExportGraphSignature([], [], [], [], {}, {}, {}, None),
+        CallSpec(in_spec, out_spec),
+        {},
+        {},
+        [],
+        [],
+        None,
+    )
+    return ExirExportedProgram(ep, False)
+
+
 @compatibility(is_backward_compatible=False)
 def capture(
     f: Callable[..., Any],
@@ -183,23 +216,7 @@ def convert_to_fake(x):
         flatten_output(graph_module)
 
     else:
-        warnings.warn(
-            "exir.capture with pt2_mode=False is deprecated. Please use the default () instead."
-        )
-        if not config.enable_functionalization:
-            raise InternalError(
-                "Can only disable functionalization under exir.capture() pt2 mode."
-            )
-        if config.enable_dynamic_shape:
-            raise InternalError(
-                "Can only enable dynamic shape tracing under exir.capture() pt2 mode."
-            )
-        if config.enable_aot:
-            raise InternalError(
-                "Using AOT mode is not supported for leagacy capture mode, please use instead."
-            )
-        graph_module = dispatch_trace(f, args)
-        in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
+        raise InternalError("pt2=False path is officially deprecated")
 
     _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
     graph_module._apply(torch.Tensor.contiguous)
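For orientation, here is a hypothetical call into the new legacy entry point; the toy module and inputs are illustrative only, not part of the PR. Note that after this change the `pt2_mode=False` branch of `capture()` raises `InternalError` unconditionally rather than falling back to `dispatch_trace`:

import torch
from executorch.exir import _capture_legacy_do_not_use

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Emits a DeprecationWarning and returns an ExirExportedProgram wrapping
# the dispatch-traced graph via HackedUpExportedProgramDONOTUSE.
prog = _capture_legacy_do_not_use(AddOne(), (torch.randn(2, 2),))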
88 changes: 88 additions & 0 deletions exir/program/_program.py
@@ -29,11 +29,99 @@
     EXIREdgeDialectVerifier,
 )
 from torch._export import ExportedProgram
+from torch.fx import _pytree as fx_pytree
 from torch.fx._compatibility import compatibility
+from torch.utils import _pytree as pytree
 
 Val = Any
 
+
+class HackedUpExportedProgramDONOTUSE(ExportedProgram):
+    def __init__(
+        self,
+        root,
+        graph,
+        graph_signature,
+        call_spec,
+        state_dict,
+        range_constraints,
+        equality_constraints,
+        module_call_graph,
+        example_inputs,
+    ):
+        super().__init__(
+            root,
+            graph,
+            graph_signature,
+            call_spec,
+            state_dict,
+            range_constraints,
+            equality_constraints,
+            module_call_graph,
+            example_inputs,
+        )
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        import torch._export.error as error
+        from torch._export import combine_args_kwargs
+
+        if self.call_spec.in_spec is not None:
+            user_args = combine_args_kwargs(args, kwargs)
+            try:
+                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
+            except Exception:
+                _, received_spec = pytree.tree_flatten(user_args)
+                raise error.InternalError(
+                    "Trying to flatten user inputs with exported input tree spec: \n"
+                    f"{self.call_spec.in_spec}\n"
+                    "but actually got inputs with tree spec of: \n"
+                    f"{received_spec}"
+                )
+
+        ordered_params = tuple(
+            self.state_dict[name] for name in self.graph_signature.parameters
+        )
+        ordered_buffers = tuple(
+            self.state_dict[name] for name in self.graph_signature.buffers
+        )
+
+        with torch.no_grad():
+            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
+            # See: torch/_functorch/aot_autograd.py#L1034
+            res = torch.fx.Interpreter(self.graph_module).run(
+                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
+            )
+
+        if self.call_spec.out_spec is not None:
+            mutation = self.graph_signature.buffers_to_mutate
+            num_mutated = len(mutation)
+            mutated_buffers = res[:num_mutated]
+
+            # Exclude dependency token from final result.
+            assertion_dep_token = self.graph_signature.assertion_dep_token
+            if assertion_dep_token is not None:
+                assertion_dep_token_index = list(assertion_dep_token.keys())[0]
+                res = res[:assertion_dep_token_index]
+
+            res = res[num_mutated:]
+            try:
+                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
+            except Exception:
+                _, received_spec = pytree.tree_flatten(res)
+                raise error.InternalError(
+                    "Trying to flatten user outputs with exported output tree spec: \n"
+                    f"{self.call_spec.out_spec}\n"
+                    "but actually got outputs with tree spec of: \n"
+                    f"{received_spec}"
+                )
+            finally:
+                ix = 0
+                for buffer in self.graph_signature.buffers_to_mutate.values():
+                    self.state_dict[buffer] = mutated_buffers[ix]
+                    ix += 1
+        return res
+
 
 @compatibility(is_backward_compatible=False)
 class ExirExportedProgram:
     def __init__(
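The `__call__` override above leans on two mechanisms: pytree specs to flatten user inputs and unflatten results, and a calling convention that prepends parameters and buffers to the user-supplied arguments. A standalone sketch of the pytree round trip, independent of this class:

import torch
from torch.utils import _pytree as pytree

# Flatten a nested argument structure into leaves plus a reusable spec,
# mirroring what __call__ does with call_spec.in_spec.
user_args = {"x": torch.ones(2), "y": (torch.zeros(2), torch.ones(2))}
leaves, spec = pytree.tree_flatten(user_args)

# Operate on the flat leaves, as the fx Interpreter does...
outputs = [t + 1 for t in leaves]

# ...then rebuild the original structure, as __call__ does with out_spec.
rebuilt = pytree.tree_unflatten(outputs, spec)
assert torch.equal(rebuilt["x"], torch.full((2,), 2.0))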