
Commit a5ea63a

digantdesai authored and facebook-github-bot committed

Add preliminary support for lifted graphs (#199)
Summary:
Pull Request resolved: #199
Reviewed By: mcr229
Differential Revision: D48710125
1 parent 770e4cc commit a5ea63a

21 files changed: +560 -308 lines

backends/xnnpack/README.md

Lines changed: 66 additions & 24 deletions
@@ -1,49 +1,91 @@
 # Executorch XNNPACK Delegate
 
-This subtree contains the XNNPACK Delegate implementation for Executorch. XNNPACK is an optimized library of neural network inference operators for ARM and x86 platforms. It is an open source projected used by PyTorch. The delegate is the mechanism for leveraging the XNNPACK Library to accelerate operators running on CPU.
+This subtree contains the XNNPACK Delegate implementation for Executorch.
+XNNPACK is an optimized library of neural network inference operators for ARM
+and x86 CPUs. It is an open source project used by PyTorch. The delegate is the
+mechanism for leveraging the XNNPACK library to accelerate operators running on
+CPU.
 
 ## Layout
-- `runtime/` : Runtime logic use at inference. This contains all the cpp files used to build the runtime graph and execute the XNNPACK model
-- `partition/`: Partitioner is used to identify operators in model's graph that are suitable for lowering to XNNPACK delegate
-  - `support_patterns.py`: Contains list of captured graph patterns that are suitable for XNNPack
-  - `xnnpack_partitioner.py`: Contains partitioner that tags graph patterns for XNNPACK lowering
-- `passes/`: Contains passes which are used before preprocessing to prepare the graph for XNNPACK lowering
+- `runtime/`: Runtime logic used at inference. This contains all the cpp files
+  used to build the runtime graph and execute the XNNPACK model
+- `partition/`: Partitioner is used to identify operators in a model's graph
+  that are suitable for lowering to the XNNPACK delegate
+  - `xnnpack_partitioner.py`: Contains the partitioner that tags graph patterns
+    for XNNPACK lowering
+  - `configs.py`: Contains lists of ops/modules for XNNPACK lowering
+- `passes/`: Contains passes which are used before preprocessing to prepare the
+  graph for XNNPACK lowering
 - `operators`: the directory to store all of op visitors
-  - `node_visitor.py`: Implementation of serializing each lowerable operator node
+  - `node_visitor.py`: Implementation of serializing each lowerable operator
+    node
   - ...
-- `serialization/`: Contains files related to serializing the XNNPACK graph representation of the PyTorch model
+- `serialization/`: Contains files related to serializing the XNNPACK graph
+  representation of the PyTorch model
   - `schema.fbs`: Flatbuffer schema of serialization format
-  - `xnnpack_graph_schema.py`: Python dataclasses mirroring the flatbuffer schema
-  - `xnnpack_graph_serialize`: Implementation for serializing dataclasses from graph schema to flatbuffer
+  - `xnnpack_graph_schema.py`: Python dataclasses mirroring the flatbuffer
+    schema
+  - `xnnpack_graph_serialize`: Implementation for serializing dataclasses
+    from graph schema to flatbuffer
 - `test/`: Tests for XNNPACK Delegate
-  - `test_xnnpack.py`: end-to-end tests operator implementation of the xnnpack delegate
-  - `test_xnnpack_passes.py`: Tests for graph passes used by xnnpack
-- `xnnpack_preprocess.py`: Contains preprocess implementation which is called by `to_backend` on the graph or subgraph of a model returning a preprocessed blob responsible for executing the graph or subgraph at runtime
+- `xnnpack_preprocess.py`: Contains the preprocess implementation which is
+  called by `to_backend` on the graph or subgraph of a model, returning a
+  preprocessed blob responsible for executing the graph or subgraph at runtime
 
 ## Help & Improvements
-If you have problems or questions, or have suggestions for ways to make implementation and testing better, please contact [Max Ren](https://fb.workplace.com/profile.php?id=100045762936437), [Digant Desai](https://fb.workplace.com/profile.php?id=100068306324819), or [Kimish Patel](https://fb.workplace.com/profile.php?id=100030094785558) on the PyTorch Edge team.
+If you have problems or questions, or have suggestions for ways to make
+implementation and testing better, please reach out to the PyTorch Edge team or
+create an issue on [github](https://github.com/pytorch/executorch/issues).
 
 ## Contributing
 
-Please follow the following these steps and guidelines when adding a new operator implementation to this library. The goals of these guidelines are to
-- Make it straightforward to add new XNNPack operators.
-- Ensure that the newly added operators are of high quality, and are easy to maintain
-- Make it easy for users to find available available operator implementations, and to trust in their quality and behavioral stability.
+Please follow these steps and guidelines when adding a new operator
+implementation to this library. The goals of these guidelines are to
+- Make it straightforward to add new XNNPACK operators.
+- Ensure that the newly added operators are of high quality, and are easy to
+  maintain
+- Make it easy for users to find available operator implementations, and to
+  trust in their quality and behavioral stability.
 
 ### AoT and Serialization Overview
 #### Serialization:
-XNNPACK delegate uses flatbuffer to serialize its nodes and values. In order to add [preprocessing](https://www.internalfb.com/code/fbsource/[d9018f0841600b95256187b9a08aeab2aa8b3c11]/fbcode/executorch/backends/xnnpack/xnnpack_preprocess.py?lines=357) support for a new operator, we must add the operator in both the flatbuffer [schema](https://www.internalfb.com/code/fbsource/[9a71ca4ec2a5284867562112946ac61f5660b881]/fbcode/executorch/backends/xnnpack/serialization/schema.fbs), as well as the mirrored python [data class](https://www.internalfb.com/code/fbsource/[9a71ca4ec2a5284867562112946ac61f5660b881]/fbcode/executorch/backends/xnnpack/serialization/xnnpack_graph_schema.py). These tables are based on the arguments to the XNNPACK Subgraph APIs. These APIs can be found [here](https://www.internalfb.com/code/fbsource/[9a71ca4ec2a5284867562112946ac61f5660b881]/fbcode/xplat/third-party/XNNPACK/XNNPACK/include/xnnpack.h?lines=722-729). We essentially serialize all the static arguments we need to call `define_{new operator}()`.
+The XNNPACK delegate uses flatbuffer to serialize its nodes and values. In
+order to add
+[preprocessing](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/xnnpack_preprocess.py)
+support for a new operator, we must add the operator to both the flatbuffer
+[schema](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/schema.fbs)
+and the mirrored python [data
+class](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/xnnpack_graph_schema.py).
+These tables are based on the arguments to the XNNPACK Subgraph APIs. These
+APIs can be found
+[here](https://github.com/google/xnnpack/blob/master/include/xnnpack.h). We
+essentially serialize all the static arguments we need to call `define_{new
+operator}()`.
 
 #### AoT Preprocess:
-To add logic to preprocess new operators for the XNNPACK Delegate, we can create new node_visitors that perform the serialization of the new operator. An example can be found [here](https://www.internalfb.com/code/fbsource/[d9018f0841600b95256187b9a08aeab2aa8b3c11]/fbcode/executorch/backends/xnnpack/serialization/node_visitor.py?lines=286-314). The function of these node_visitors is to serialize all the data we define to need in the schema above.
+To add logic to preprocess new operators for the XNNPACK Delegate, we can
+create new node_visitors that perform the serialization of the new operator. An
+example can be found in
+[node_visitor.py](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/operators/node_visitor.py).
+The function of these node_visitors is to serialize all the data we defined in
+the schema above.
 
 #### AoT Partitioner:
-Xnnpack Partitioner is used to selected the pattern (like the linear module graph) in a big graph such that the selected nodes will be delegated to xnnpack. To support a new op (for example, sigmoid), add the corresponding pattern to [partition/support_pattern.py](https://www.internalfb.com/code/fbsource/fbcode/executorch/backends/xnnpack/partition/support_patterns.py?lines=121-130), which captures the sigmoid op. Then expand the [self.pattern in xnnpack_partitioner.py](https://www.internalfb.com/code/fbsource/[8a7869f9d150dd6272b56d04e2d65029a92a1550]/fbcode/executorch/backends/xnnpack/partition/xnnpack_partitioner.py?lines=23-25) with the new pattern.
+The XnnpackPartitioner is used to select patterns (like the linear module
+graph) in a big graph such that the selected nodes will be delegated to
+XNNPACK. To support a new op (for example, sigmoid), add the corresponding op
+or module to
+[configs.py](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/partition/configs.py),
+which captures the sigmoid op.
 
 #### How does it work?
-- Tag the nodes: in the xnnpack partitioner, there is a field called [self.patterns](https://www.internalfb.com/code/fbsource/[50683ef7e3e9baf61e1d7719e19990db3a26bbfe]/fbcode/executorch/backends/xnnpack/partition/xnnpack_partitioner.py?lines=21-29)(), which lists all ops that are supported by the current xnnpack backend in executorch. When call [xnnpackpartitioner.partition()](https://www.internalfb.com/code/fbsource/[50683ef7e3e9baf61e1d7719e19990db3a26bbfe]/fbcode/executorch/backends/xnnpack/partition/xnnpack_partitioner.py?lines=42), it will tag all the nodes that matches the patterns listed in self.pattern
-- Lower the nodes; when we call `to_backend(graph_module, XnnpackPartitioner)`, it will loop through all the tagged nodes, and lower the group with the same tag.
+- Tag the nodes: the XNNPACK partitioner's config lists all ops that are
+  supported by the current XNNPACK backend in executorch. When we call
+  `XnnpackPartitioner.partition()`, it tags all the nodes that match the
+  patterns listed in the config.
+- Lower the nodes: when we call `to_backend(graph_module, XnnpackPartitioner)`,
+  it loops through all the tagged nodes and lowers each group with the same
+  tag.
 
 
 #### Adding Tests for newly minted operators
-To test newly added operators, we can add unit tests in: [test_xnnpack.py](https://www.internalfb.com/code/fbsource/[d9018f0841600b95256187b9a08aeab2aa8b3c11]/fbcode/executorch/backends/xnnpack/test/test_xnnpack.py)
+To test newly added operators, we can add unit tests in the
+[tests directory](https://github.com/pytorch/executorch/tree/main/backends/xnnpack/test).
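For orientation, the tag-and-lower flow in the "How does it work?" section of the README boils down to the following hedged sketch. This is not part of the commit: `edge_graph_module` stands in for a model already captured to Edge IR, and exact capture entry points vary across executorch versions.

```python
# Hedged sketch of the partitioner flow described in the README above.
# Assumes `edge_graph_module` is a graph module already captured to Edge IR.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir.backend.backend_api import to_backend

# Tagging: XnnpackPartitioner.partition() marks every node whose op or
# module appears in partition/configs.py with a delegation tag.
# Lowering: to_backend() groups the tagged nodes and replaces each group
# with a call into the XNNPACK delegate.
lowered_graph_module = to_backend(edge_graph_module, XnnpackPartitioner)
```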

backends/xnnpack/operators/node_visitor.py

Lines changed: 61 additions & 27 deletions
@@ -30,15 +30,14 @@
 from executorch.backends.xnnpack.utils.utils import (
     check_or_raise,
     get_input_node,
+    get_param_tensor,
+    is_param_node,
     PERM_NCHW_TO_NHWC,
 )
 
-from executorch.backends.xnnpack.utils.xnnpack_constants import (
-    XNN_INVALID_VALUE_ID,
-    XNN_VALUE_FLAG_EXTERNAL_INPUT,
-    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
-)
+from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
 from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import ExportedProgram
 
 XNN_TYPE_MAP = {
     torch.float32: XNNDatatype.xnn_datatype_fp32,
@@ -75,8 +74,21 @@ class NodeVisitor:
     serializing them using the xnnpack serialization schema defined
     """
 
-    def __init__(self, external_ids) -> None:
-        self.external_ids = external_ids or {}
+    def __init__(
+        self,
+        exported_program: ExportedProgram,
+        external_ids: Dict,
+    ) -> None:
+        self._external_ids = external_ids or {}
+        self._exported_program = exported_program or None
+
+    @property
+    def external_ids(self) -> Dict:
+        return self._external_ids
+
+    @property
+    def exported_program(self) -> ExportedProgram:
+        return self._exported_program
 
     def is_graph_input(self, tensor: torch.fx.Node) -> bool:
         """
@@ -85,7 +97,9 @@ def is_graph_input(self, tensor: torch.fx.Node) -> bool:
         Args:
             tensor: EdgeIR Tensor that is being checked for graph input
         """
-        return tensor.op == "placeholder"
+        return tensor.op == "placeholder" and not is_param_node(
+            self.exported_program, tensor
+        )
 
     def is_graph_output(self, tensor: torch.fx.Node) -> bool:
         """
@@ -130,30 +144,49 @@ def gen_ids_and_flags(
         # This will break if we change the way q/dq are partitioned
 
         # Tensor can still be input if its quantizing node is an input
-        if self.is_graph_input(tensor) or (
-            quant_params.is_input if quant_params else False
-        ):
+        is_input = self.is_graph_input(tensor) or (
+            quant_params.is_input
+            and not is_param_node(self.exported_program, quant_params.q_input)
+            if quant_params
+            else False
+        )
+
+        # Tensor can still be output if its quantizing node is an output
+        is_output = self.is_graph_output(tensor) or (
+            quant_params.is_output if quant_params else False
+        )
+
+        if is_input:
             tensor_input = tensor
-            if quant_params:
-                if quant_params.is_input and not self.is_graph_input(tensor):
-                    tensor_input = quant_params.q_input
+            if (
+                quant_params
+                and quant_params.is_input
+                and not is_param_node(self.exported_program, quant_params.q_input)
+                and not self.is_graph_input(tensor)
+            ):
+                tensor_input = quant_params.q_input
+
             assert (
                 tensor_input in self.external_ids.keys()
             ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}"
+
             ext_id = self.external_ids[tensor_input].external_id
             xnn_graph.input_ids.append(id_out)
             flag = self.external_ids[tensor_input].io_type
-        # Tensor can still be output if its quantizing node is an output
-        elif self.is_graph_output(tensor) or (
-            quant_params.is_output if quant_params else False
-        ):
+
+        elif is_output:
             tensor_output = tensor
-            if quant_params:
-                if quant_params.is_output and not self.is_graph_output(tensor):
-                    tensor_output = list(tensor.users)[0]
+            if (
+                quant_params
+                and quant_params.is_output
+                and not self.is_graph_output(tensor)
+            ):
+                tensor_output = list(tensor.users)[0]
+
             assert (
                 tensor_output in self.external_ids.keys()
-            ), f"Tensor {tensor_output} is_output: ext_ids: {self.external_ids.keys()}"
+            ), f"Tensor {tensor_output} is_output. ext_ids: {self.external_ids.keys()}"
+
             ext_id = self.external_ids[tensor_output].external_id
             xnn_graph.output_ids.append(id_out)
             flag = self.external_ids[tensor_output].io_type
@@ -331,7 +364,7 @@ def get_serialized_buffer(
         """
         # The get_attr node is the input to quant_params.
        get_attr_node = tensor if quant_params is None else quant_params.q_input
-        if get_attr_node.op != "get_attr":
+        if not is_param_node(self.exported_program, get_attr_node):
             check_or_raise(
                 not swap_nc_for_depthwise_weights,
                 "Swapping N and C dimensions is only valid for constant data tensors",
@@ -343,9 +376,10 @@ def get_serialized_buffer(
             "Internal Error: const_buffer and buffer_sizes length mismatch",
         )
         buffer_idx = len(xnn_graph.constant_buffer)
-        const_val = getattr(
-            get_attr_node.graph.owning_module, get_attr_node.target
-        ).contiguous()
+        const_val = get_param_tensor(self.exported_program, get_attr_node)
+        assert const_val is not None and isinstance(const_val, torch.Tensor)
+        const_val = const_val.contiguous()
+
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
@@ -358,9 +392,9 @@ def get_serialized_buffer(
                 dims=((1, 0) + tuple(range(2, const_val.dim())))
             ).contiguous()
         if convert_to_nhwc:
+            # pyre-ignore[28] Unexpected keyword argument `memory_format`
             const_val = const_val.to(memory_format=torch.channels_last)
 
-        # pyre-ignore
         array_type = ctypes.c_char * const_val.untyped_storage().nbytes()
         array = ctypes.cast(
             const_val.untyped_storage().data_ptr(),
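The node_visitor.py changes above are the core of the lifted-graph support: once parameters are lifted, they arrive as `placeholder` nodes rather than `get_attr` nodes, so the visitor needs the `ExportedProgram` (via `is_param_node` / `get_param_tensor`) to tell user inputs apart from weights. A minimal illustration of that distinction, assuming the current `torch.export` API (which postdates this commit):

```python
# Minimal sketch: in a lifted graph, weights/biases are placeholders too,
# so `node.op == "placeholder"` alone no longer identifies a graph input.
import torch
from torch.export import export


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


ep = export(M(), (torch.randn(1, 4),))
# Maps placeholder names to parameter fully-qualified names.
lifted_params = set(ep.graph_signature.inputs_to_parameters)
for node in ep.graph.nodes:
    if node.op == "placeholder":
        kind = "lifted parameter" if node.name in lifted_params else "user input"
        print(node.name, "->", kind)
```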

backends/xnnpack/passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ python_library(
         "prelu_reshape_pass.py",
         "remove_getitem_op.py",
         "tag_implicit_q_dq_pass.py",
+        "xnnpack_pass.py",
     ],
     deps = [
         "//caffe2:torch",

backends/xnnpack/passes/__init__.py

Lines changed: 55 additions & 13 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import List, Optional
+
 from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import (
     ChannelsLastTaggedReshapePass,
 )
@@ -15,19 +17,59 @@
 from executorch.backends.xnnpack.passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
 from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
+from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass
+
+from executorch.exir.pass_base import ExportPass
 
-from executorch.exir.passes import PassManager
 from executorch.exir.passes.const_prop_pass import ConstPropPass
+from torch._export.pass_base import PassType
 
-xnnpack_delegation_passes = PassManager(
-    passes=[
-        ConvertToLinearPass(),
-        ConstPropPass(),
-        FuseBatchNormWithConvPass(),
-        RemoveGetItemPass(),
-        Conv1dUnsqueezePass(),
-        PReLUReshapePass(),
-        ChannelsLastTaggedReshapePass(),
-        TagImplicitQDqPass(),
-    ]
-)
+from torch.export import ExportedProgram
+
+
+class XNNPACKPassManager:
+    def __init__(
+        self, exported_program: ExportedProgram, passes: Optional[List[PassType]] = None
+    ) -> None:
+        """
+        A helper class to run multiple XNNPack passes on a program.
+        If the passes list is empty, all XNNPACK passes will be run.
+        Otherwise, only the passes in the given list will be run.
+        """
+        self._exported_program = exported_program
+
+        if not passes:
+            # All the XNNPACK passes
+            self.passes = [
+                ConvertToLinearPass,
+                ConstPropPass,
+                FuseBatchNormWithConvPass,
+                RemoveGetItemPass,
+                Conv1dUnsqueezePass,
+                PReLUReshapePass,
+                ChannelsLastTaggedReshapePass,
+                TagImplicitQDqPass,
+            ]
+        else:
+            self.passes = passes
+
+    @property
+    def exported_program(self) -> ExportedProgram:
+        return self._exported_program
+
+    def transform(self) -> ExportedProgram:
+        """
+        Returns a transformed ExportedProgram.
+        """
+        ep = self.exported_program
+        for pass_ in self.passes:
+            if issubclass(pass_, XNNPACKPass):
+                transform_pass = pass_(ep)
+            elif issubclass(pass_, ExportPass):
+                transform_pass = pass_()
+            else:
+                raise RuntimeError(
+                    f"Expecting a subclass of XNNPACKPass or ExportPass, but got pass: {pass_} with type: {type(pass_)}"
+                )
+            ep = ep._transform(transform_pass)
+        return ep
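Usage of the new `XNNPACKPassManager` (which replaces the old module-level `xnnpack_delegation_passes`) looks roughly like the sketch below; `exported_program` is assumed to be an `ExportedProgram` produced by an upstream export step.

```python
# Hedged usage sketch for the XNNPACKPassManager defined above; assumes
# `exported_program` is a torch.export.ExportedProgram built earlier.
from executorch.backends.xnnpack.passes import (
    ConvertToLinearPass,
    XNNPACKPassManager,
)

# Run the full default XNNPACK pass pipeline:
transformed = XNNPACKPassManager(exported_program).transform()

# Or run only a chosen subset of passes, in order:
transformed = XNNPACKPassManager(
    exported_program, passes=[ConvertToLinearPass]
).transform()
```

Passing the pass classes (rather than instances, as the old `PassManager` did) lets the manager construct `XNNPACKPass` subclasses with the `ExportedProgram` they now require.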
