Skip to content

Commit 97cb892

Browse files
authored
add graph builder in oss for fuse ops
Differential Revision: D65911233 Pull Request resolved: #6877
1 parent fbc384c commit 97cb892

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,18 @@ python_library(
132132
],
133133
)
134134

135+
python_library(
136+
name = "graph_builder",
137+
srcs = [
138+
"graph_builder.py",
139+
],
140+
typing = True,
141+
deps = [
142+
"fbcode//caffe2:torch",
143+
"fbcode//executorch/exir:pass_base",
144+
],
145+
)
146+
135147
python_library(
136148
name = "fuse_ops",
137149
srcs = [

backends/cadence/aot/graph_builder.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
import logging
6+
from typing import Optional, Sequence, Union
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
10+
from torch._subclasses import FakeTensor, FakeTensorMode
11+
from torch.fx.node import Argument, Target
12+
from torch.utils import _pytree as pytree
13+
14+
15+
class GraphBuilder(ExportPass):
16+
"""Utility class for creating a graph module with user-specified ops.
17+
18+
This class allows us to create test graph modules with any ops we want
19+
directly, rather than relying on decomposition or passes.
20+
21+
Usage:
22+
builder = GraphBuilder()
23+
# To insert placeholders, use builder.placeholder.
24+
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
25+
# To insert an op, use builder.call_operator.
26+
op = builder.call_operator(
27+
some_op
28+
(x, other_args, ...),
29+
)
30+
# Insert outputs as a list of ProxyValues using builder.output.
31+
builder.output([op])
32+
# Get GraphModule from builder.
33+
gm = builder.get_graph_module()
34+
"""
35+
36+
def __init__(self) -> None:
37+
self.exporter = ExportPass()
38+
self.tracer: ExportPass.ExportTracer = self.ExportTracer(
39+
self, torch.fx.graph.CodeGen()
40+
)
41+
self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
42+
self.tracer.fake_tensor_mode = self.fake_tensor_mode
43+
44+
# This will be called to create nodes in tracer.
45+
self.interpreter = torch.fx.Interpreter(
46+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
47+
)
48+
49+
# pyre-ignore[14]: Inconsistent override.
50+
def placeholder(
51+
self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
52+
) -> ProxyValue:
53+
if not isinstance(fake_tensor, FakeTensor):
54+
fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
55+
logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
56+
placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
57+
return placeholder
58+
59+
# pyre-ignore[14]: Inconsistent override.
60+
def output(self, results: list[ProxyValue]) -> ProxyValue:
61+
logging.info(f"Creating outputs {results}")
62+
return super().output(results, NodeMetadata({}))
63+
64+
def get_graph_module(self) -> torch.fx.GraphModule:
65+
return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
66+
67+
def call_operator(
68+
self,
69+
op, # pyre-ignore
70+
args: tuple[Argument, ...],
71+
kwargs: Optional[dict[str, Argument]] = None,
72+
meta: Optional[NodeMetadata] = None,
73+
) -> ProxyValue:
74+
if meta is None:
75+
meta = NodeMetadata({})
76+
if kwargs is None:
77+
kwargs = {}
78+
return super().call_operator(op, args, kwargs, meta)
79+
80+
81+
def single_op_builder(
82+
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
83+
op: Target,
84+
args: Sequence[Argument],
85+
kwargs: Optional[dict[str, Argument]] = None,
86+
) -> torch.fx.GraphModule:
87+
"""Create a graph module with a single op.
88+
89+
Args:
90+
placeholders: Placeholders to be used as inputs to the GraphModule.
91+
op: The op to be inserted.
92+
args: The args to be passed to the op.
93+
kwargs: The kwargs to be passed to the op.
94+
95+
Returns:
96+
A graph module with a single op
97+
"""
98+
builder = GraphBuilder()
99+
op_to_placeholder_dict = {
100+
p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
101+
}
102+
proxy_args, proxy_kwargs = pytree.tree_map_only(
103+
(torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
104+
)
105+
node = builder.call_operator(op, proxy_args, proxy_kwargs)
106+
builder.output([node])
107+
return builder.get_graph_module()

0 commit comments

Comments
 (0)