split to_cadence_edge_executorch API to to_cadence and to_executorch_gen_etrecord #6880

Merged (3 commits) on Nov 19, 2024

Changes from all commits
29 changes: 29 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -132,6 +132,18 @@ python_library(
     ],
 )

+python_library(
+    name = "graph_builder",
+    srcs = [
+        "graph_builder.py",
+    ],
+    typing = True,
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/exir:pass_base",
+    ],
+)
+
 python_library(
     name = "fuse_ops",
     srcs = [
@@ -150,3 +162,20 @@ python_library(
         "//executorch/exir/passes:spec_prop_pass",
     ],
 )
+
+python_unittest(
+    name = "test_graph_builder",
+    srcs = [
+        "tests/test_graph_builder.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:graph_builder",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//later:lib",
+        ":ops_registrations"
+    ],
+)
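As a hedged illustration of how downstream targets would consume the new library (the test_my_pass target and its source file are hypothetical):

python_unittest(
    name = "test_my_pass",  # hypothetical target
    srcs = ["tests/test_my_pass.py"],  # hypothetical source
    typing = True,
    deps = [
        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot:pass_utils",
    ],
)

The checked-in test target itself would run with the usual Buck invocation, e.g. buck test //executorch/backends/cadence/aot:test_graph_builder (the exact command depends on the local setup).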
21 changes: 20 additions & 1 deletion backends/cadence/aot/compiler.py
@@ -196,7 +196,26 @@ def export_to_edge(
 # Export the model and lower it to an EdgeProgramManager (in edge IR), and
 # apply passes specific to Cadence DSP execution. Return both to print the
 # differences.
-def export_to_cadence_edge_executorch(
+def export_to_cadence(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    dump_graphs: bool = False,
+    output_dir: Optional[str] = None,
+    opt_level: int = 1,
+) -> EdgeProgramManager:
+    edge_prog_manager = export_to_edge(model, inputs)
+    cadence_passes = get_cadence_passes(opt_level)
+
+    # Run a couple of required passes for quant/dequant ops.
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def export_to_executorch_gen_etrecord(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
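With the split, callers that only need the edge-dialect program can stop at export_to_cadence and inspect the result before lowering any further. A minimal sketch, assuming a plain float module is acceptable input here (the full Cadence flow in export_example.py quantizes first):

import torch

from executorch.backends.cadence.aot.compiler import export_to_cadence


class Relu(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


# Returns an EdgeProgramManager with the Cadence passes already applied.
edge_prog = export_to_cadence(Relu(), (torch.randn(2, 3),), opt_level=1)
print(edge_prog.exported_program().graph_module.graph)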
6 changes: 3 additions & 3 deletions backends/cadence/aot/export_example.py
@@ -16,7 +16,7 @@

 from executorch.backends.cadence.aot.compiler import (
     convert_pt2,
-    export_to_cadence_edge_executorch,
+    export_to_executorch_gen_etrecord,
     fuse_pt2,
 )

@@ -86,8 +86,8 @@ def export_model(
     quantized_model = fuse_pt2(converted_model, quantizer)

     # Get edge program after Cadence specific passes
-    exec_prog: ExecutorchProgramManager = export_to_cadence_edge_executorch(
-        quantized_model, example_inputs, working_dir
+    exec_prog: ExecutorchProgramManager = export_to_executorch_gen_etrecord(
+        quantized_model, example_inputs, output_dir=working_dir
     )

     logging.info("Final exported graph:\n")
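For other call sites of the removed API, the migration mirrors this hunk; a hedged before/after sketch (names taken from the example above):

# Before (removed API, output dir passed positionally):
# exec_prog = export_to_cadence_edge_executorch(quantized_model, example_inputs, working_dir)

# After (renamed API, output_dir passed by keyword):
exec_prog = export_to_executorch_gen_etrecord(
    quantized_model, example_inputs, output_dir=working_dir
)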
107 changes: 107 additions & 0 deletions backends/cadence/aot/graph_builder.py
@@ -0,0 +1,107 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import logging
from typing import Optional, Sequence, Union

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Argument, Target
from torch.utils import _pytree as pytree


class GraphBuilder(ExportPass):
    """Utility class for creating a graph module with user-specified ops.

    This class allows us to create test graph modules with any ops we want
    directly, rather than relying on decomposition or passes.

    Usage:
        builder = GraphBuilder()
        # To insert placeholders, use builder.placeholder.
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        # To insert an op, use builder.call_operator.
        op = builder.call_operator(
            some_op,
            (x, other_args, ...),
        )
        # Insert outputs as a list of ProxyValues using builder.output.
        builder.output([op])
        # Get the GraphModule from the builder.
        gm = builder.get_graph_module()
    """

    def __init__(self) -> None:
        self.exporter = ExportPass()
        self.tracer: ExportPass.ExportTracer = self.ExportTracer(
            self, torch.fx.graph.CodeGen()
        )
        self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        self.tracer.fake_tensor_mode = self.fake_tensor_mode

        # This will be called to create nodes in the tracer.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )

    # pyre-ignore[14]: Inconsistent override.
    def placeholder(
        self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
    ) -> ProxyValue:
        if not isinstance(fake_tensor, FakeTensor):
            fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
        logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
        placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
        return placeholder

    # pyre-ignore[14]: Inconsistent override.
    def output(self, results: list[ProxyValue]) -> ProxyValue:
        logging.info(f"Creating outputs {results}")
        return super().output(results, NodeMetadata({}))

    def get_graph_module(self) -> torch.fx.GraphModule:
        return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: Optional[dict[str, Argument]] = None,
        meta: Optional[NodeMetadata] = None,
    ) -> ProxyValue:
        if meta is None:
            meta = NodeMetadata({})
        if kwargs is None:
            kwargs = {}
        return super().call_operator(op, args, kwargs, meta)


def single_op_builder(
    placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
    op: Target,
    args: Sequence[Argument],
    kwargs: Optional[dict[str, Argument]] = None,
) -> torch.fx.GraphModule:
    """Create a graph module with a single op.

    Args:
        placeholders: Placeholders to be used as inputs to the GraphModule.
        op: The op to be inserted.
        args: The args to be passed to the op.
        kwargs: The kwargs to be passed to the op.

    Returns:
        A graph module with a single op.
    """
    builder = GraphBuilder()
    op_to_placeholder_dict = {
        p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
    }
    proxy_args, proxy_kwargs = pytree.tree_map_only(
        (torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
    )
    node = builder.call_operator(op, proxy_args, proxy_kwargs)
    builder.output([node])
    return builder.get_graph_module()
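To make the builder concrete, here is a hedged usage sketch with a core-ATen edge op rather than a Cadence op (exir_ops.edge.aten.add.Tensor is part of the standard edge dialect; the tests below exercise cadence.im2row instead):

import torch

from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.exir.dialects._ops import ops as exir_ops

x = torch.randn(4, 4)
y = torch.randn(4, 4)

# Builds: placeholder p_0, placeholder p_1 -> aten.add.Tensor -> output.
gm = single_op_builder(
    placeholders=(x, y),
    op=exir_ops.edge.aten.add.Tensor,
    args=(x, y),
)
print(gm.graph)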
9 changes: 9 additions & 0 deletions backends/cadence/aot/pass_utils.py
@@ -89,3 +89,12 @@ def get_node_names_list_from_gm(
             continue
         graph_nodes.append(node.name)
     return graph_nodes
+
+
+def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
+    """Count the number of nodes with target `target` in the graph."""
+    total = 0
+    for node in graph_module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            total += 1
+    return total
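count_node matches on call_function targets, so it works on any FX graph, not just edge-dialect ones. A small hedged sketch using torch.fx.symbolic_trace (where x + y is recorded with operator.add as the node target):

import operator

import torch

from executorch.backends.cadence.aot.pass_utils import count_node


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


# symbolic_trace records the addition as a call_function node targeting operator.add.
gm = torch.fx.symbolic_trace(Add())
assert count_node(gm, operator.add) == 1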
70 changes: 70 additions & 0 deletions backends/cadence/aot/tests/test_graph_builder.py
@@ -0,0 +1,70 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.graph_builder import (
    GraphBuilder,
    single_op_builder,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from later.unittest import TestCase


class TestGraphBuilder(TestCase):
    def test_graph_with_single_im2row(self) -> None:
        # Create a graph with a single im2row node.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        pad_value = builder.placeholder("pad", torch.randn(1))
        channels_last = False
        im2row = builder.call_operator(
            exir_ops.edge.cadence.im2row.default,
            # pyre-ignore
            (
                x,
                (2, 2),
                (1, 1),
                (0, 0),
                (1, 1),
                pad_value,
                channels_last,
            ),
        )
        builder.output([im2row])
        gm = builder.get_graph_module()
        # Check that the graph module is valid by running ExportPass on it.
        gm = ExportPass().call(gm).graph_module

        # Check that the graph has a single im2row node: it should be the only
        # call_function node (placeholders and the output are other node kinds).
        self.assertEqual(
            len([n for n in gm.graph.nodes if n.op == "call_function"]), 1
        )
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)


class TestSingleOpBuilderUtility(TestCase):
    def test_graph_with_single_im2row(self) -> None:
        # Create a graph with a single im2row node.
        x = torch.randn(1, 3, 224, 224)
        pad_value = torch.randn(1)
        channels_last = False
        gm = single_op_builder(
            (x, pad_value),
            exir_ops.edge.cadence.im2row.default,
            (
                x,
                (2, 2),
                (1, 1),
                (0, 0),
                (1, 1),
                pad_value,
                channels_last,
            ),
        )
        # Check that the graph module is valid by running ExportPass on it.
        gm = ExportPass().call(gm).graph_module

        # Check that the graph has a single im2row node: it should be the only
        # call_function node (placeholders and the output are other node kinds).
        self.assertEqual(
            len([n for n in gm.graph.nodes if n.op == "call_function"]), 1
        )
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)