Skip to content

Commit 68d0605

Browse files
hsharma35 and YIWENX14
authored and committed
Fix Graph builder for higher order ops.
Differential Revision: D68231732 Pull Request resolved: #7684
1 parent 096d2db commit 68d0605

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

backends/cadence/aot/graph_builder.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from typing import Optional, Sequence, Union
77

88
import torch
9-
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
9+
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
10+
from torch._dispatch.python import enable_python_dispatcher
1011
from torch._subclasses import FakeTensor, FakeTensorMode
1112
from torch.fx.node import Argument, Target
1213
from torch.utils import _pytree as pytree
@@ -80,6 +81,22 @@ def call_operator(
8081
kwargs = {}
8182
return super().call_operator(op, args, kwargs, meta)
8283

84+
def call_submodule(
    self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...]
) -> PassResult:
    """Process a higher-order op's submodule by running a fresh ExportPass over it.

    Returns the PassResult produced by ExportPass.call on the submodule's graph.
    """
    # NOTE(review): `inputs` is accepted (required by the ExportPass
    # call_submodule interface) but not used here — the submodule is
    # re-traced from its GraphModule alone. Confirm this is intentional.
    return ExportPass().call(graph_module)
88+
89+
def _fx(
    self,
    kind: str,
    target: torch.fx.node.Target,
    args: tuple[Argument, ...],
    kwargs: dict[str, Argument],
    meta: NodeMetadata,
) -> ProxyValue:
    """Emit an FX node, forcing fake-tensor + python-dispatcher context.

    Wraps the parent implementation so that tracing of every node kind
    (including higher-order ops such as `map`) happens under
    `self.fake_tensor_mode` with the python dispatcher enabled.
    """
    # presumably self.fake_tensor_mode is a FakeTensorMode set up by the
    # enclosing builder — TODO confirm against the class __init__ (not
    # visible in this view).
    with self.fake_tensor_mode, enable_python_dispatcher():
        return super()._fx(kind, target, args, kwargs, meta)
99+
83100

84101
def single_op_builder(
85102
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],

backends/cadence/aot/tests/test_graph_builder.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

3+
# pyre-strict
4+
5+
6+
from typing import Sequence
37

48
import executorch.backends.cadence.aot.ops_registrations # noqa
59
import torch
@@ -9,7 +13,7 @@
913
)
1014
from executorch.backends.cadence.aot.pass_utils import count_node
1115
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import ExportPass
16+
from executorch.exir.pass_base import ExportPass, NodeMetadata
1317
from later.unittest import TestCase
1418

1519

@@ -68,3 +72,30 @@ def test_graph_with_single_im2row(self) -> None:
6872
# Check graph has a single im2row node.
6973
self.assertEqual(len([gm.graph.nodes]), 1)
7074
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
75+
76+
77+
class TestHigherOrderOps(TestCase):
    def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule:
        """Build a tiny graph (x, x + x) to serve as the body of a map op."""
        inner = GraphBuilder()
        inp = inner.placeholder("x", torch.randn(*x_shape))
        doubled = inner.call_operator(
            exir_ops.edge.aten.add.Tensor,
            (inp, inp),  # pyre-ignore
        )
        inner.output([inp, doubled])
        # Sanity check: the built module must survive a round of ExportPass.
        return ExportPass().call(inner.get_graph_module()).graph_module

    def test_call_map(self) -> None:
        """A graph containing a higher-order `map` node should validate."""
        outer = GraphBuilder()
        shape = (4, 8, 8)
        inp = outer.placeholder("x", torch.randn(*shape))
        # Map the inner graph over the leading dimension of `inp`.
        mapped = outer.call_map(
            self._get_inner_graph(shape[1:]), [inp], [], NodeMetadata({})
        )
        outer.output([mapped])
        # Validity check: ExportPass should run without raising.
        ExportPass().call(outer.get_graph_module()).graph_module

0 commit comments

Comments
 (0)