
Commit f571cd7

JacobSzwejbka authored and facebook-github-bot committed
emitter handles lifted graphs
Summary: Just needed some extra logic for placeholder nodes. Later we can add some assertions in the getattr node handler once lifted graphs are all that exist.

Reviewed By: mcr229

Differential Revision: D47888934

fbshipit-source-id: 361019d4e59fc60e146950d98406132ffed00f83
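Background, for readers new to lifted graphs: with lifting enabled (enable_aot=True, _unlift=False), exported parameters and buffers are no longer getattr nodes in the graph but extra placeholder inputs, and the graph signature records which placeholder corresponds to which state_dict entry. A minimal illustrative sketch, not taken from this diff (placeholder names and the export entry point vary across PyTorch versions):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(M(), (torch.randn(2, 4),))
# Maps lifted placeholder names to state_dict keys, e.g.
# {"p_linear_weight": "linear.weight", "p_linear_bias": "linear.bias"}
print(ep.graph_signature.inputs_to_parameters)

The emitter change below consumes exactly this mapping so that such placeholders are emitted as constant tensors instead of user inputs.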
1 parent 6cef6c7 commit f571cd7

File tree: 4 files changed (+73, -6 lines)

exir/emit/_emit_program.py

Lines changed: 7 additions & 0 deletions
@@ -149,6 +149,13 @@ def emit_program(
 
     # emit each entry point in order according to name.
     for name, exported_program in sorted(methods.items()):
+        if (
+            exported_program.graph_signature.buffers_to_mutate
+        ):  # see if we are mutating any state
+            raise ExportError(
+                ExportErrorType.INVALID_INPUT_TYPE,
+                "Buffers cannot be modified in executorch.",
+            )
         # create empty state
         emitter_state = _EmitterState(
             values=[],
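To make the new guard concrete, here is a hypothetical module (not part of this diff) that would trip it: forward() mutates a registered buffer, so after functionalized export graph_signature.buffers_to_mutate is non-empty and emit_program raises the ExportError above.

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # persistent state carried across calls
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        self.count.add_(1)  # in-place buffer mutation -> buffers_to_mutate
        return x + self.count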

exir/emit/_emitter.py

Lines changed: 19 additions & 2 deletions
@@ -39,7 +39,6 @@
 import executorch.extension.pytree as ex_pytree
 import torch
 import torch.fx
-from executorch.exir import delegate
 from executorch.exir.common import add_cursor_to_graph
 from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
 from executorch.exir.dialects.backend._ops import BackendOpOverload
@@ -1270,6 +1269,7 @@ def __init__(
     ) -> None:
         super().__init__(exported_program.graph_module, emitter_state, program_state)
         self.name = name
+        self.exported_program = exported_program
 
         self.inputs: List[int] = []
         self.outputs: List[int] = []
@@ -1301,13 +1301,30 @@ def placeholder(
         https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
         """
         spec = self.node.meta["spec"]
+        const_tensor = False
+        if isinstance(target, str) and (
+            target in self.exported_program.graph_signature.inputs_to_parameters
+            or target in self.exported_program.graph_signature.inputs_to_buffers
+        ):
+
+            fqn = (
+                self.exported_program.graph_signature.inputs_to_parameters[target]
+                if target in self.exported_program.graph_signature.inputs_to_parameters
+                else self.exported_program.graph_signature.inputs_to_buffers[target]
+            )
+            spec = TensorSpec.from_tensor(
+                self.exported_program.state_dict[fqn], const=True
+            )
+            const_tensor = True
         evalue = (
             self._tensor_spec_to_evalue(spec)
             if isinstance(spec, TensorSpec)
             else self._constant_to_evalue(spec, None)
         )
         value = self._emit_evalue(evalue)
-        self.inputs.append(value.id)
+        if not const_tensor:
+            self.inputs.append(value.id)
+
         return value
 
     def output(
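The placeholder handling above boils down to a small lookup: if the target is recorded in the graph signature as a lifted parameter or buffer, resolve its fully qualified name, build a const TensorSpec from the state dict entry, and keep the resulting value out of the plan's user inputs. A standalone sketch with stub types (the function and argument names here are illustrative, not the emitter's API):

from typing import Any, Dict, Optional, Tuple

def resolve_lifted_placeholder(
    target: str,
    inputs_to_parameters: Dict[str, str],
    inputs_to_buffers: Dict[str, str],
    state_dict: Dict[str, Any],
) -> Tuple[Optional[Any], bool]:
    # Returns (constant_tensor, is_const) for a placeholder target.
    if target in inputs_to_parameters:
        return state_dict[inputs_to_parameters[target]], True
    if target in inputs_to_buffers:
        return state_dict[inputs_to_buffers[target]], True
    return None, False  # a genuine user input; it stays in self.inputs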

exir/emit/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -18,5 +18,6 @@ python_unittest(
         "//executorch/exir/passes:const_prop_pass",
         "//executorch/exir/tests:lib",
         "//executorch/exir/tests:models",
+        "//executorch/extension/pybindings:portable",  # @manual
     ],
 )

exir/emit/test/test_emit.py

Lines changed: 46 additions & 4 deletions
@@ -15,11 +15,11 @@
 import executorch.exir.schema as schema
 import executorch.exir.tests.models as models
 import torch
-from executorch.exir import CaptureConfig, ExecutorchProgram
+from executorch.exir import CaptureConfig, EdgeCompileConfig, ExecutorchProgram
 from executorch.exir.emit import emit_program
 from executorch.exir.error import InternalError
 from executorch.exir.passes.const_prop_pass import ConstPropPass
-from executorch.exir.print_program import print_program  # noqa
+from executorch.exir.print_program import pretty_print, print_program  # noqa
 from executorch.exir.schema import (
     Bool,
     EValue,
@@ -36,10 +36,19 @@
     Tensor,
 )
 from executorch.exir.tests.common import register_additional_test_aten_ops
-from executorch.exir.tests.models import Mul
+from executorch.exir.tests.models import (
+    Emformer,
+    FeedForwardBlock,
+    MLP,
+    Mul,
+    ScaledDotProductAttention,
+    ScaledDotProductAttentionModularized,
+)
 from executorch.exir.tracer import ExirDynamoConfig
+from executorch.extension.pybindings.portable import (  # pyre-ignore
+    _load_for_executorch_from_buffer,
+)
 from functorch.experimental import control_flow
-from torch.fx.passes.infra.pass_base import PassResult
 
 
 class TestEmit(unittest.TestCase):
@@ -555,6 +564,39 @@ def f(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
             len(program.execution_plan[0].chains[0].instructions[0].instr_args.args), 4
         )
 
+    def test_lifted(self) -> None:
+        def test_model(eager_module):
+            inputs = eager_module.get_random_inputs()
+            eager_output = eager_module.forward(*inputs)
+            capture_config = exir.CaptureConfig(
+                pt2_mode=True,
+                enable_functionalization=True,
+                enable_dynamic_shape=True,
+                enable_aot=True,
+                _unlift=False,
+            )
+
+            aten_dialect = exir.capture(
+                eager_module,
+                eager_module.get_random_inputs(),
+                capture_config,
+            )
+
+            edge_dialect = aten_dialect.to_edge()
+
+            executorch_dialect = edge_dialect.to_executorch()
+
+            pretty_print(executorch_dialect.program)
+
+            executorch_module = _load_for_executorch_from_buffer(
+                executorch_dialect.buffer
+            )
+            et_output = executorch_module.forward(inputs)
+            self.assertTrue(torch.allclose(eager_output, et_output[0], atol=1e-04))
+
+        test_model(MLP())
+        # test_model(Emformer()) cannot run without bernoulli.out being added
+
     def test_emit_multiple_out(self) -> None:
         def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
             return torch.topk(x, 5)
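In short, test_lifted round-trips a lifted capture of MLP through to_edge() and to_executorch(), loads the serialized program with the portable pybindings runtime, and checks the output against eager mode at 1e-04 tolerance; the Emformer case stays commented out until bernoulli.out is available.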
