Skip to content

Commit b554b41

Browse files
hsharma35 authored and facebook-github-bot committed
Fix Graph builder for higher order ops.
Summary: Allow graph builder to create higher order ops like call_map. Differential Revision: D68231732
1 parent a727b55 commit b554b41

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

backends/cadence/aot/graph_builder.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from typing import Optional, Sequence, Union
77

88
import torch
9-
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
9+
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
10+
from torch._dispatch.python import enable_python_dispatcher
1011
from torch._subclasses import FakeTensor, FakeTensorMode
1112
from torch.fx.node import Argument, Target
1213
from torch.utils import _pytree as pytree
@@ -80,6 +81,22 @@ def call_operator(
8081
kwargs = {}
8182
return super().call_operator(op, args, kwargs, meta)
8283

84+
def call_submodule(
85+
self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...]
86+
) -> PassResult:
87+
return ExportPass().call(graph_module)
88+
89+
def _fx(
90+
self,
91+
kind: str,
92+
target: torch.fx.node.Target,
93+
args: tuple[Argument, ...],
94+
kwargs: dict[str, Argument],
95+
meta: NodeMetadata,
96+
) -> ProxyValue:
97+
with self.fake_tensor_mode, enable_python_dispatcher():
98+
return super()._fx(kind, target, args, kwargs, meta)
99+
83100

84101
def single_op_builder(
85102
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],

backends/cadence/aot/tests/test_graph_builder.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

33

4+
from typing import Sequence
5+
46
import executorch.backends.cadence.aot.ops_registrations # noqa
57
import torch
68
from executorch.backends.cadence.aot.graph_builder import (
@@ -9,7 +11,7 @@
911
)
1012
from executorch.backends.cadence.aot.pass_utils import count_node
1113
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import ExportPass
14+
from executorch.exir.pass_base import ExportPass, NodeMetadata
1315
from later.unittest import TestCase
1416

1517

@@ -68,3 +70,27 @@ def test_graph_with_single_im2row(self) -> None:
6870
# Check graph has a single im2row node.
6971
self.assertEqual(len([gm.graph.nodes]), 1)
7072
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
73+
74+
class TestHigherOrderOps(TestCase):
75+
def _get_inner_graph(self, x_shape: Sequence[int]):
76+
builder = GraphBuilder()
77+
x = builder.placeholder("x", torch.randn(*x_shape))
78+
add = builder.call_operator(
79+
exir_ops.edge.aten.add.Tensor,
80+
(x, x), # pyre-ignore
81+
)
82+
builder.output([x, add])
83+
gm = builder.get_graph_module()
84+
# Check if graph module is valid by running exportpass on it.
85+
gm = ExportPass().call(gm).graph_module
86+
return gm
87+
88+
def test_call_map(self):
89+
builder = GraphBuilder()
90+
x_shape = (4, 8, 8)
91+
x = builder.placeholder("x", torch.randn(*x_shape))
92+
map_node = builder.call_map(self._get_inner_graph(x_shape[1:]), [x], [], NodeMetadata({}))
93+
builder.output([map_node])
94+
gm = builder.get_graph_module()
95+
# Check if graph module is valid by running exportpass on it.
96+
ExportPass().call(gm).graph_module

0 commit comments

Comments
 (0)