
Commit 2f21fe6

cccclai authored and facebook-github-bot committed
example quantizer and delegate backend
Summary: An example to show how to add a new backend. We've demoed this, and it may be easier to put up a generic implementation so it's easier to get started on adding a quantizer and delegate. This demo includes: an example quantizer, an example partitioner, an example backend, example operators available on the backend, and passes for memory format permutation.

Reviewed By: tarun292

Differential Revision: D49120351

fbshipit-source-id: 1118820b365af9b2cb1eec62477a915e6bb90f6d
1 parent 89f3e39 commit 2f21fe6

23 files changed: +1094 −0 lines changed

backends/example/README.md

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
This folder is an example backend to lower MobileNetV2. It covers the AOT side and showcases how to quantize and lower a MobileNetV2 model to the example backend. The serialization format is simply a string for demo purposes, since the choice of serialization format is up to the backend.

The folder structure includes:

- example_quantizer
- example_partitioner
- example_backend
- example_operators. All of them are assumed to be runnable on the example backend.
  - The OpBase defined in op_base.py is just a draft idea; it can be defined more comprehensively depending on the TOSA operator definitions.
- example_backend_delegate_passes. It includes passes that might be helpful in the backend. Right now there are two passes: merge_to_dim_pass.py and permute_memory_formats_pass.py. They are examples that show how to represent memory format permutation and how to represent operators with a different memory format (like channels last).
  - merge_to_dim_pass.py only handles one merging case. More cases need to be covered but should be straightforward.

## High Level Flow

In the following diagrams, we show how to quantize a MobileNetV2 model and lower it to ExampleBackend.

### Quantize and Delegate

We can define patterns based on the operators supported by the backend; these patterns are then used by both the quantizer and the delegate.

![](./diagrams/quantize_delegate.png)
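
A rough sketch of the quantization step is shown below. The class name `ExampleQuantizer` and the exact capture/export call are assumptions (they are not spelled out in this commit), so treat this as illustrative rather than the exact API.

```python
# Hedged sketch: quantize MobileNetV2 with the example quantizer via the PT2E flow.
# `ExampleQuantizer` and the capture call below are assumptions; the exact entry
# points depend on the PyTorch/ExecuTorch version in use.
import torch
import torchvision
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.example.example_quantizer import ExampleQuantizer

model = torchvision.models.mobilenet_v2(weights="DEFAULT").eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Capture the eager model into an FX graph before quantization.
captured = torch.export.export(model, example_inputs).module()

prepared = prepare_pt2e(captured, ExampleQuantizer())  # annotate per backend patterns
prepared(*example_inputs)                               # calibrate
quantized = convert_pt2e(prepared)                      # fold observers into q/dq ops
```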

### Partitioner and Backend

The partitioner tags the nodes to be lowered to the backend; the backend then receives all tagged nodes and preprocesses them as a delegate.

![](./diagrams/delegate.png)
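
Continuing from the quantization sketch above, a minimal sketch of the lowering step might look like the following. `ExamplePartitioner` is an assumed class name, and the `to_edge`/`to_backend` signatures have changed across ExecuTorch versions, so this is illustrative only.

```python
# Hedged sketch: convert the quantized graph to the edge dialect, let the
# partitioner tag supported subgraphs, and replace them with lowered delegates.
import torch
from executorch import exir
from executorch.backends.example.example_partitioner import ExamplePartitioner

edge = exir.to_edge(torch.export.export(quantized, example_inputs))

# to_backend replaces each tagged subgraph with a call to the delegate produced
# by the backend's preprocess().
edge = edge.to_backend(ExamplePartitioner())
executorch_program = edge.to_executorch()
```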

### Memory format permute

Some operators may perform better in a memory format other than contiguous. One way to express this is to insert `to_dim_op` nodes to describe the memory format permutation, and to merge a pair of adjacent `to_dim_op` nodes when they are opposites of each other.

![](./diagrams/memory_permute.png)
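
The two passes under example_backend_delegate_passes implement exactly this, and example_backend.py runs them inside `preprocess`. A minimal sketch of applying them directly, assuming `edge_program` is an edge-dialect ExportedProgram:

```python
# Sketch mirroring TosaBackend.preprocess in example_backend.py: insert to_dim
# ops around supported patterns, then merge adjacent opposite to_dim pairs.
import copy

from executorch.backends.example.example_backend_delegate_passes.merge_to_dim_pass import (
    MergeToDimPass,
)
from executorch.backends.example.example_backend_delegate_passes.permute_memory_formats_pass import (
    PermuteMemoryFormatsPass,
)

transformed = copy.deepcopy(edge_program)  # edge_program: assumed edge-dialect ExportedProgram
transformed._transform(
    PermuteMemoryFormatsPass(),  # insert to_dim(channels_last) / to_dim(contiguous)
    MergeToDimPass(),            # cancel adjacent, opposite to_dim pairs
)
print(transformed.graph)
```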

backends/example/TARGETS

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

python_library(
    name = "example_quantizer",
    srcs = [
        "example_quantizer.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/example/example_operators:example_operators_lib",
    ],
)

python_library(
    name = "example_backend",
    srcs = [
        "example_backend.py",
    ],
    deps = [
        "//executorch/backends/example/example_backend_delegate_passes:lib",
        "//executorch/exir/backend:backend_details",
        "//executorch/exir/backend:compile_spec_schema",
    ],
)

python_library(
    name = "example_partitioner",
    srcs = [
        "example_partitioner.py",
    ],
    deps = [
        ":example_backend",
        "//caffe2:torch",
        "//executorch/backends/example/example_operators:example_operators_lib",
        "//executorch/exir:graph_module",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//executorch/exir/dialects:lib",
    ],
)

python_unittest(
    name = "test_example_delegate",
    srcs = [
        "test_example_delegate.py",
    ],
    deps = [
        ":example_partitioner",
        ":example_quantizer",
        "//caffe2:torch",
        "//executorch/exir:delegate",
        "//executorch/exir:lib",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//pytorch/vision:torchvision",
    ],
)
Binary files: PNG diagram images added (248 KB and 309 KB shown); content not displayed.

backends/example/example_backend.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import final, List

from executorch.backends.example.example_backend_delegate_passes.merge_to_dim_pass import (
    MergeToDimPass,
)
from executorch.backends.example.example_backend_delegate_passes.permute_memory_formats_pass import (
    PermuteMemoryFormatsPass,
)

from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec


@final
class TosaBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        print("entering the lowerable parts in TosaBackend.preprocess....")

        copy_edge_program = copy.deepcopy(edge_program)
        copy_edge_program._transform(
            PermuteMemoryFormatsPass(),
            MergeToDimPass(),
        )
        processed_bytes = str(copy_edge_program.graph)
        return PreprocessResult(bytes(processed_bytes, encoding="utf8"))
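
For context, a hedged sketch of how this backend could be invoked directly by its registered name; the `to_backend(backend_id, ...)` overload used here and the availability of `edge_exported_program` are assumptions, not part of this commit.

```python
# Hedged usage sketch: lower an edge program to the example backend by name.
# The backend id is the BackendDetails subclass name ("TosaBackend" here);
# `edge_exported_program` is assumed to be an edge-dialect ExportedProgram.
from executorch.exir.backend.backend_api import to_backend

lowered = to_backend("TosaBackend", edge_exported_program, [])  # no compile specs
print(lowered.processed_bytes)  # the stringified graph produced by preprocess()
```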
backends/example/example_backend_delegate_passes/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
    name = "lib",
    srcs = [
        "merge_to_dim_pass.py",
        "permute_memory_formats_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/example/example_operators:example_operators_lib",
        "//executorch/exir:dim_order_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ],
)
backends/example/example_backend_delegate_passes/merge_to_dim_pass.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult


class MergeToDimPass(ExportPass):
    """
    This pass merges a pair of adjacent to_dim ops when they cancel each other out,
    e.g. after PermuteMemoryFormatsPass has inserted to_dim ops around patterns whose
    pattern_op.permuate_memory_format is set to True.
    Example:
    # Done for 1 to 1
    before pass: x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out
    after pass:  x -> to_dim(channel_last) -> conv -> conv -> to_dim(contiguous) -> out

    # Not Done for 1 to N
    before pass: x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out
                                                            |-------------> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out
    after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out
                                                            |-------------> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out

    # Not Done for N to 1
    before pass: x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out
                 y -> to_dim(channel_last) -> conv -> to_dim(contiguous) ---------|
    after pass:  x -> to_dim(channel_last) -> conv -> conv -> to_dim(contiguous) -> out
                 y -> to_dim(channel_last) -> conv -----|

    # Not Done for N to N
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
                # print(node, node.args, list(node.users), list(list(node.users)[0].args))
                if len(node.users) == 1 and len(list(node.users)[0].args) == 2:
                    args_map = {}
                    node_kwargs = node.args[-1]
                    node_users = list(node.users)

                    in_to_dim_node_dim_order = node_kwargs["dim_order"]
                    in_to_dim_node_dtype = node_kwargs["dtype"]
                    out_to_dim_node = node_users[0]
                    out_to_dim_node_kwargs = out_to_dim_node.args[-1]
                    out_to_dim_node_dim_order = out_to_dim_node_kwargs["dim_order"]
                    out_to_dim_node_dtype = out_to_dim_node_kwargs["dtype"]

                    if (
                        in_to_dim_node_dtype == out_to_dim_node_dtype
                        and in_to_dim_node_dim_order
                        == get_dim_order(torch.channels_last, 4)
                        and out_to_dim_node_dim_order
                        == get_dim_order(torch.contiguous_format, 4)
                    ):
                        out_to_dim_node_users = list(out_to_dim_node.users)
                        assert len(out_to_dim_node_users) == 1
                        out_to_dim_node_user = out_to_dim_node_users[0]
                        args_map[out_to_dim_node] = node.args[0]
                        out_to_dim_node_user_new_args = [
                            args_map[out_to_dim_node] if arg in args_map else arg
                            for arg in out_to_dim_node_user.args
                        ]
                        print("out_to_dim_node_user.args: ", out_to_dim_node_user.args)
                        print(
                            "out_to_dim_node_user_new_args: ",
                            out_to_dim_node_user_new_args,
                        )
                        out_to_dim_node_user.args = tuple(out_to_dim_node_user_new_args)

                        graph_module.graph.erase_node(out_to_dim_node)
                        graph_module.graph.erase_node(node)
        # TODO: Handle other merging rules, including 1->N, N->1, N->N
        return PassResult(graph_module, True)
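
The merge only fires for 4-D tensors with matching dtype where the first to_dim is channels_last and the second is contiguous. As a quick sanity check, the dim orders being compared come from get_dim_order (a sketch; exact return container may vary):

```python
# Sketch: the two dim orders compared in the condition above.
import torch
from executorch.exir.dim_order_utils import get_dim_order

print(get_dim_order(torch.contiguous_format, 4))  # expected: [0, 1, 2, 3]
print(get_dim_order(torch.channels_last, 4))      # expected: [0, 2, 3, 1]
```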
backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions


class PermuteMemoryFormatsPass(ExportPass):
    """
    This pass will insert to_dim ops around a pattern if it satisfies the requirement,
    i.e. pattern_op.permuate_memory_format is set to True.
    Example 1:
    before pass: x -> conv -> out
    after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out

    before pass: x -> conv -> conv -> out
    after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out

    before pass: x -> conv -> linear -> out
    after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> linear -> to_dim(contiguous) -> out
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for pattern in list(module_to_annotator.keys()):
            pattern_op = module_to_annotator[pattern]
            if pattern_op.permuate_memory_format:
                partitions = find_sequential_partitions(
                    graph_module,
                    pattern,
                )
                for partition in partitions:
                    # Some unpacking logic to get a flattened exit-nodes list
                    output_nodes = [
                        node
                        for node in partition[0].output_nodes
                        if node.op != "placeholder"
                    ]
                    exit_nodes = [output_node.users for output_node in output_nodes]
                    exit_nodes = list(chain.from_iterable(exit_nodes))

                    """
                    # Step 1. Insert a to_dim op when exiting the pattern.
                    # For example, if the pattern is conv, x -> conv -> out will become x -> conv -> to_dim(contiguous) -> out when permuting the memory format.
                    # For x -> conv -> conv -> out, it will become x -> conv -> to_dim(contiguous) -> conv -> to_dim(contiguous) -> out.
                    """
                    for exit_node in exit_nodes:
                        with graph_module.graph.inserting_before(exit_node):
                            # Handle the case when the pattern output is also the graph output,
                            # e.g. x -> conv -> out will become x -> conv -> to_dim(contiguous) -> out
                            if exit_node.op == "output":
                                exit_to_dim_op = graph_module.graph.call_function(
                                    exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                                    exit_node.args,
                                    {
                                        "dtype": torch.float64,
                                        "dim_order": get_dim_order(
                                            torch.contiguous_format, 4
                                        ),
                                    },
                                )
                                # Insert the to_dim op; it will be the return op
                                _ = graph_module.graph.output(exit_to_dim_op)
                                # Remove the old return op.
                                graph_module.graph.erase_node(exit_node)
                            # Handle the case when the pattern output is an intermediate output,
                            # e.g. x -> conv -> conv -> out will become x -> conv -> to_dim(contiguous) -> conv -> out
                            elif exit_node.op == "call_function":
                                exit_node_args = []
                                for exit_node_arg in exit_node.args:
                                    if (
                                        isinstance(exit_node_arg, torch.fx.Node)
                                        and exit_node_arg.op != "placeholder"
                                    ):
                                        exit_to_dim_op = graph_module.graph.call_function(
                                            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                                            (exit_node_arg,),
                                            {
                                                "dtype": torch.float64,
                                                "dim_order": get_dim_order(
                                                    torch.contiguous_format, 4
                                                ),
                                            },
                                        )
                                        exit_node_args.append(exit_to_dim_op)
                                    else:
                                        exit_node_args.append(exit_node_arg)
                                exit_node.args = list(exit_node_args)

                    """
                    # Step 2. Insert a to_dim op when entering the pattern. After the first step, we already have to_dim(contiguous) when exiting the pattern.
                    # Now we need to insert to_dim(channel_last) when entering the pattern.
                    # For example, if the pattern is conv, x -> conv -> to_dim(contiguous) -> out will become x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out.
                    # For x -> conv -> to_dim(contiguous) -> conv -> to_dim(contiguous) -> out, it will become x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out.
                    """
                    # Create the input_node -> to_dim_op map (e.g. for the conv pattern
                    # in x -> conv -> out, the key is x).
                    input_node_map = {}  # key: input_node, value: to_dim_op
                    to_dim_op_set = set()
                    for input_node in partition[0].input_nodes:
                        with graph_module.graph.inserting_after(input_node):
                            to_dim_op = graph_module.graph.call_function(
                                # Insert the to_dim op and update input_node_map
                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                                (
                                    input_node,
                                    {
                                        "dtype": torch.float64,
                                        "dim_order": get_dim_order(
                                            torch.channels_last, 4
                                        ),
                                    },
                                ),
                            )
                            input_node_map[input_node] = to_dim_op
                            to_dim_op_set.add(to_dim_op)

                    # Update the args to the new to_dim op; skip if it's already set
                    for input_node in partition[0].input_nodes:
                        for user in list(input_node.users):
                            # If user is in to_dim_op_set, its arg is already set to the to_dim op
                            if user not in to_dim_op_set:
                                user_new_arg = [
                                    input_node_map[user_arg]
                                    if user_arg in input_node_map
                                    else user_arg
                                    for user_arg in user.args
                                ]
                                # Update the input node's users' args
                                user.args = tuple(user_new_arg)

        # Ensure the graph is still valid
        graph_module.graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
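
The pattern lookup that drives this pass comes from find_sequential_partitions. A rough sketch of how such a call behaves is below; the concrete pattern tuples live in example_operators/ops.py, so the pattern used here is an assumption.

```python
# Sketch: find partitions for an assumed conv -> relu pattern in an FX graph.
# `graph_module` is assumed to be a torch.fx.GraphModule obtained from export.
import torch
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

partitions = find_sequential_partitions(graph_module, [torch.nn.Conv2d, torch.nn.ReLU])
for partition in partitions:
    # Each entry is a sequence of SourcePartitions; the pass above uses the first
    # one's input_nodes/output_nodes to decide where to insert to_dim ops.
    print(partition[0].input_nodes, partition[0].output_nodes)
```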
backends/example/example_operators/TARGETS

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
    name = "example_operators_lib",
    srcs = [
        "adaptive_avg_pool2d.py",
        "add.py",
        "conv2d.py",
        "conv_relu.py",
        "dropout.py",
        "flatten.py",
        "linear.py",
        "op_base.py",
        "ops.py",
        "utils.py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)
