
Add preliminary support for lifted graphs #199

Closed
wants to merge 2 commits
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/nightly.txt
@@ -1 +1 @@
-dev20230831
+dev20230907
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-2.1.0
+2.2.0
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/vision.txt
@@ -1 +1 @@
-0.16.0
+0.17.0
90 changes: 66 additions & 24 deletions backends/xnnpack/README.md
@@ -1,49 +1,91 @@
# Executorch XNNPACK Delegate

This subtree contains the XNNPACK Delegate implementation for Executorch. XNNPACK is an optimized library of neural network inference operators for ARM and x86 platforms. It is an open source projected used by PyTorch. The delegate is the mechanism for leveraging the XNNPACK Library to accelerate operators running on CPU.
This subtree contains the XNNPACK Delegate implementation for Executorch.
XNNPACK is an optimized library of neural network inference operators for ARM
and x86 CPUs. It is an open source project used by PyTorch. The delegate is the
mechanism for leveraging the XNNPACK library to accelerate operators running on
CPU.

## Layout
- `runtime/` : Runtime logic use at inference. This contains all the cpp files used to build the runtime graph and execute the XNNPACK model
- `partition/`: Partitioner is used to identify operators in model's graph that are suitable for lowering to XNNPACK delegate
- `support_patterns.py`: Contains list of captured graph patterns that are suitable for XNNPack
- `xnnpack_partitioner.py`: Contains partitioner that tags graph patterns for XNNPACK lowering
- `passes/`: Contains passes which are used before preprocessing to prepare the graph for XNNPACK lowering
- `runtime/` : Runtime logic used at inference. This contains all the cpp files
used to build the runtime graph and execute the XNNPACK model
- `partition/`: Partitioner is used to identify operators in model's graph that
are suitable for lowering to XNNPACK delegate
- `xnnpack_partitioner.py`: Contains partitioner that tags graph patterns
for XNNPACK lowering
- `configs.py`: Contains lists of op/modules for XNNPACK lowering
- `passes/`: Contains passes which are used before preprocessing to prepare the
graph for XNNPACK lowering
- `operators`: the directory for all op visitors
- `node_visitor.py`: Implementation of serializing each lowerable operator node
- `node_visitor.py`: Implementation of serializing each lowerable operator
node
- ...
- `serialization/`: Contains files related to serializing the XNNPACK graph representation of the PyTorch model
- `serialization/`: Contains files related to serializing the XNNPACK graph
representation of the PyTorch model
- `schema.fbs`: Flatbuffer schema of serialization format
- `xnnpack_graph_schema.py`: Python dataclasses mirroring the flatbuffer schema
- `xnnpack_graph_serialize`: Implementation for serializing dataclasses from graph schema to flatbuffer
- `xnnpack_graph_schema.py`: Python dataclasses mirroring the flatbuffer
schema
- `xnnpack_graph_serialize`: Implementation for serializing dataclasses
from graph schema to flatbuffer
- `test/`: Tests for XNNPACK Delegate
- `test_xnnpack.py`: end-to-end tests of the operator implementations in the xnnpack delegate
- `test_xnnpack_passes.py`: Tests for graph passes used by xnnpack
- `xnnpack_preprocess.py`: Contains preprocess implementation which is called by `to_backend` on the graph or subgraph of a model returning a preprocessed blob responsible for executing the graph or subgraph at runtime
- `xnnpack_preprocess.py`: Contains preprocess implementation which is called
by `to_backend` on the graph or subgraph of a model returning a preprocessed
blob responsible for executing the graph or subgraph at runtime

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make implementation and testing better, please contact [Max Ren](https://fb.workplace.com/profile.php?id=100045762936437), [Digant Desai](https://fb.workplace.com/profile.php?id=100068306324819), or [Kimish Patel](https://fb.workplace.com/profile.php?id=100030094785558) on the PyTorch Edge team.
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [github](https://github.com/pytorch/executorch/issues).

## Contributing

Please follow the following these steps and guidelines when adding a new operator implementation to this library. The goals of these guidelines are to
- Make it straightforward to add new XNNPack operators.
- Ensure that the newly added operators are of high quality, and are easy to maintain
- Make it easy for users to find available available operator implementations, and to trust in their quality and behavioral stability.
Please follow these steps and guidelines when adding a new operator
implementation to this library. The goals of these guidelines are to
- Make it straightforward to add new XNNPACK operators.
- Ensure that the newly added operators are of high quality, and are easy to
maintain.
- Make it easy for users to find available operator implementations, and to
trust in their quality and behavioral stability.

### AoT and Serialization Overview
#### Serialization:
XNNPACK delegate uses flatbuffer to serialize its nodes and values. In order to add [preprocessing](https://www.internalfb.com/code/fbsource/[d9018f0841600b95256187b9a08aeab2aa8b3c11]/fbcode/executorch/backends/xnnpack/xnnpack_preprocess.py?lines=357) support for a new operator, we must add the operator in both the flatbuffer [schema](https://www.internalfb.com/code/fbsource/[9a71ca4ec2a5284867562112946ac61f5660b881]/fbcode/executorch/backends/xnnpack/serialization/schema.fbs), as well as the mirrored python [data class](https://www.internalfb.com/code/fbsource/[9a71ca4ec2a5284867562112946ac61f5660b881]/fbcode/executorch/backends/xnnpack/serialization/xnnpack_graph_schema.py). These tables are based on the arguments to the XNNPACK Subgraph APIs. These APIs can be found [here](https://www.internalfb.com/code/fbsource/[9a71ca4ec2a5284867562112946ac61f5660b881]/fbcode/xplat/third-party/XNNPACK/XNNPACK/include/xnnpack.h?lines=722-729). We essentially serialize all the static arguments we need to call `define_{new operator}()`.
The XNNPACK delegate uses flatbuffers to serialize its nodes and values. In
order to add
[preprocessing](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/xnnpack_preprocess.py)
support for a new operator, we must add the operator to both the flatbuffer
[schema](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/schema.fbs)
and the mirrored python
[data class](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/xnnpack_graph_schema.py).
These tables are based on the arguments to the XNNPACK Subgraph APIs, which can
be found in
[xnnpack.h](https://github.com/google/xnnpack/blob/master/include/xnnpack.h).
We essentially serialize all the static arguments we need to call
`define_{new operator}()`.

#### AoT Preprocess:
To add logic to preprocess new operators for the XNNPACK Delegate, we can create new node_visitors that perform the serialization of the new operator. An example can be found [here](https://www.internalfb.com/code/fbsource/[d9018f0841600b95256187b9a08aeab2aa8b3c11]/fbcode/executorch/backends/xnnpack/serialization/node_visitor.py?lines=286-314). The function of these node_visitors is to serialize all the data we define to need in the schema above.
To add logic to preprocess new operators for the XNNPACK delegate, we can
create new node visitors that perform the serialization of the new operator. An
example can be found in
[node_visitor.py](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/operators/node_visitor.py).
The function of these node visitors is to serialize all the data we defined in
the schema above.
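
A minimal visitor sketch, assuming the registration decorator and
`define_node` hook from `operators/node_visitor.py` (serialization details
elided):

```python
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,  # assumed registration decorator
)


@register_node_visitor
class SigmoidVisitor(NodeVisitor):
    target = "aten.sigmoid.default"  # edge op handled by this visitor

    def define_node(self, node, xnn_graph, vals_to_ids, debug_handle) -> None:
        # Define ids for the input/output values, then append the
        # hypothetical XNNSigmoid table from the schema sketch above.
        ...
```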

#### AoT Partitioner:
Xnnpack Partitioner is used to selected the pattern (like the linear module graph) in a big graph such that the selected nodes will be delegated to xnnpack. To support a new op (for example, sigmoid), add the corresponding pattern to [partition/support_pattern.py](https://www.internalfb.com/code/fbsource/fbcode/executorch/backends/xnnpack/partition/support_patterns.py?lines=121-130), which captures the sigmoid op. Then expand the [self.pattern in xnnpack_partitioner.py](https://www.internalfb.com/code/fbsource/[8a7869f9d150dd6272b56d04e2d65029a92a1550]/fbcode/executorch/backends/xnnpack/partition/xnnpack_partitioner.py?lines=23-25) with the new pattern.
The XnnpackPartitioner is used to select patterns (like the linear module
graph) in a big graph such that the selected nodes will be delegated to
XNNPACK. To support a new op (for example, sigmoid), add the corresponding op
or module to
[configs.py](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/partition/configs.py)
so the partitioner can capture it.
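
Conceptually the change is a one-line addition to one of the lists in
`configs.py`; the list name below is an assumption:

```python
# backends/xnnpack/partition/configs.py (sketch; actual list names may differ)
from executorch.exir.dialects._ops import ops as exir_ops

SUPPORTED_OPS = [
    # ... existing supported ops ...
    exir_ops.edge.aten.sigmoid.default,  # newly supported op
]
```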

#### How does it work?
- Tag the nodes: in the xnnpack partitioner, there is a field called [self.patterns](https://www.internalfb.com/code/fbsource/[50683ef7e3e9baf61e1d7719e19990db3a26bbfe]/fbcode/executorch/backends/xnnpack/partition/xnnpack_partitioner.py?lines=21-29)(), which lists all ops that are supported by the current xnnpack backend in executorch. When call [xnnpackpartitioner.partition()](https://www.internalfb.com/code/fbsource/[50683ef7e3e9baf61e1d7719e19990db3a26bbfe]/fbcode/executorch/backends/xnnpack/partition/xnnpack_partitioner.py?lines=42), it will tag all the nodes that matches the patterns listed in self.pattern
- Lower the nodes; when we call `to_backend(graph_module, XnnpackPartitioner)`, it will loop through all the tagged nodes, and lower the group with the same tag.
- Tag the nodes: the XNNPACK partitioner's config lists all ops that are
supported by the current XNNPACK backend in Executorch. When we call
`XnnpackPartitioner.partition()`, it tags all the nodes that match the
patterns listed in the config.
- Lower the nodes: when we call `to_backend(graph_module, XnnpackPartitioner)`,
it loops through all the tagged nodes and lowers each group of nodes that
share the same tag (see the sketch after this list).
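
A minimal sketch of this tag-and-lower flow (export/capture boilerplate
omitted; exact entry points may differ by release):

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir.backend.backend_api import to_backend

# partition() tags XNNPACK-supported nodes; to_backend then lowers each
# group of nodes that shares a tag into a single delegated XNNPACK blob.
lowered = to_backend(edge_graph_module, XnnpackPartitioner)
```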


#### Adding Tests for newly minted operators
To test newly added operators, we can add unit tests in: [test_xnnpack.py](https://www.internalfb.com/code/fbsource/[d9018f0841600b95256187b9a08aeab2aa8b3c11]/fbcode/executorch/backends/xnnpack/test/test_xnnpack.py)
To test newly added operators, we can add unit tests in the
[test directory](https://github.com/pytorch/executorch/tree/main/backends/xnnpack/test).
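
A new operator test might look like the sketch below; the lowering/runtime
helpers are elided because the exact test utilities vary, so see the existing
tests in `test/` for the real harness:

```python
import unittest

import torch


class TestSigmoid(unittest.TestCase):
    def test_fp32_sigmoid(self):
        class SigmoidModule(torch.nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        inputs = (torch.randn(1, 3, 8, 8),)
        reference = SigmoidModule()(*inputs)
        # Lower SigmoidModule to XNNPACK, execute the delegated program on
        # the same inputs, and compare against `reference` within a
        # tolerance (helpers elided).
        self.assertEqual(reference.shape, inputs[0].shape)  # placeholder check
```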
88 changes: 61 additions & 27 deletions backends/xnnpack/operators/node_visitor.py
@@ -30,15 +30,14 @@
from executorch.backends.xnnpack.utils.utils import (
check_or_raise,
get_input_node,
get_param_tensor,
is_param_node,
PERM_NCHW_TO_NHWC,
)

from executorch.backends.xnnpack.utils.xnnpack_constants import (
XNN_INVALID_VALUE_ID,
XNN_VALUE_FLAG_EXTERNAL_INPUT,
XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram

XNN_TYPE_MAP = {
torch.float32: XNNDatatype.xnn_datatype_fp32,
@@ -75,8 +74,21 @@ class NodeVisitor:
serializing them using the xnnpack serialization schema defined
"""

def __init__(self, external_ids) -> None:
self.external_ids = external_ids or {}
def __init__(
self,
exported_program: ExportedProgram,
external_ids: Dict,
) -> None:
self._external_ids = external_ids or {}
self._exported_program = exported_program or None

@property
def external_ids(self) -> Dict:
return self._external_ids

@property
def exported_program(self) -> ExportedProgram:
return self._exported_program

def is_graph_input(self, tensor: torch.fx.Node) -> bool:
"""
@@ -85,7 +97,9 @@ def is_graph_input(self, tensor: torch.fx.Node) -> bool:
Args:
tensor: EdgeIR Tensor that is being checked for graph input
"""
return tensor.op == "placeholder"
return tensor.op == "placeholder" and not is_param_node(
self.exported_program, tensor
)

def is_graph_output(self, tensor: torch.fx.Node) -> bool:
"""
@@ -130,30 +144,49 @@ def gen_ids_and_flags(
# This will break if we change the way q/dq are partitioned

# Tensor can still be input if its quantizing node is an input
if self.is_graph_input(tensor) or (
quant_params.is_input if quant_params else False
):
is_input = self.is_graph_input(tensor) or (
quant_params.is_input
and not is_param_node(self.exported_program, quant_params.q_input)
if quant_params
else False
)

# Tensor can still be output if its quantizing node is an output
is_output = self.is_graph_output(tensor) or (
quant_params.is_output if quant_params else False
)

if is_input:
tensor_input = tensor
if quant_params:
if quant_params.is_input and not self.is_graph_input(tensor):
tensor_input = quant_params.q_input
if (
quant_params
and quant_params.is_input
and not is_param_node(self.exported_program, quant_params.q_input)
and not self.is_graph_input(tensor)
):
tensor_input = quant_params.q_input

assert (
tensor_input in self.external_ids.keys()
), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}"

ext_id = self.external_ids[tensor_input].external_id
xnn_graph.input_ids.append(id_out)
flag = self.external_ids[tensor_input].io_type
# Tensor can still be output if its quantizing node is an output
elif self.is_graph_output(tensor) or (
quant_params.is_output if quant_params else False
):

elif is_output:
tensor_output = tensor
if quant_params:
if quant_params.is_output and not self.is_graph_output(tensor):
tensor_output = list(tensor.users)[0]
if (
quant_params
and quant_params.is_output
and not self.is_graph_output(tensor)
):
tensor_output = list(tensor.users)[0]

assert (
tensor_output in self.external_ids.keys()
), f"Tensor {tensor_output} is_output: ext_ids: {self.external_ids.keys()}"
), f"Tensor {tensor_output} is_output. ext_ids: {self.external_ids.keys()}"

ext_id = self.external_ids[tensor_output].external_id
xnn_graph.output_ids.append(id_out)
flag = self.external_ids[tensor_output].io_type
@@ -331,7 +364,7 @@ def get_serialized_buffer(
"""
# The get_attr node is the input to quant_params.
get_attr_node = tensor if quant_params is None else quant_params.q_input
if get_attr_node.op != "get_attr":
if not is_param_node(self.exported_program, get_attr_node):
check_or_raise(
not swap_nc_for_depthwise_weights,
"Swapping N and C dimensions is only valid for constant data tensors",
@@ -343,9 +376,10 @@
"Internal Error: const_buffer and buffer_sizes length mismatch",
)
buffer_idx = len(xnn_graph.constant_buffer)
const_val = getattr(
get_attr_node.graph.owning_module, get_attr_node.target
).contiguous()
const_val = get_param_tensor(self.exported_program, get_attr_node)
assert const_val is not None and isinstance(const_val, torch.Tensor)
const_val = const_val.contiguous()

# Quantize buffer if static data is indeed quantized
if quant_params is not None and not quant_params.is_dynamic:
const_val = quant_params.quantize_tensor(const_val).contiguous()
@@ -358,9 +392,9 @@
dims=((1, 0) + tuple(range(2, const_val.dim())))
).contiguous()
if convert_to_nhwc:
# pyre-ignore[28] Unexpected keyword argument `memory_format`
const_val = const_val.to(memory_format=torch.channels_last)

# pyre-ignore
array_type = ctypes.c_char * const_val.untyped_storage().nbytes()
array = ctypes.cast(
const_val.untyped_storage().data_ptr(),
1 change: 1 addition & 0 deletions backends/xnnpack/passes/TARGETS
@@ -11,6 +11,7 @@ python_library(
"prelu_reshape_pass.py",
"remove_getitem_op.py",
"tag_implicit_q_dq_pass.py",
"xnnpack_pass.py",
],
deps = [
"//caffe2:torch",
68 changes: 55 additions & 13 deletions backends/xnnpack/passes/__init__.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
)
@@ -15,19 +17,59 @@
from executorch.backends.xnnpack.passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass

from executorch.exir.pass_base import ExportPass

from executorch.exir.passes import PassManager
from executorch.exir.passes.const_prop_pass import ConstPropPass
from torch._export.pass_base import PassType

xnnpack_delegation_passes = PassManager(
passes=[
ConvertToLinearPass(),
ConstPropPass(),
FuseBatchNormWithConvPass(),
RemoveGetItemPass(),
Conv1dUnsqueezePass(),
PReLUReshapePass(),
ChannelsLastTaggedReshapePass(),
TagImplicitQDqPass(),
]
)
from torch.export import ExportedProgram


class XNNPACKPassManager:
def __init__(
self, exported_program: ExportedProgram, passes: Optional[List[PassType]] = None
) -> None:
"""
A helper class to run multiple XNNPACK passes on a program.
If the passes list is empty, all XNNPACK passes will be run.
Otherwise, only the passes in the list will be run.
"""
self._exported_program = exported_program

if not passes:
# All the XNNPACK passes
self.passes = [
ConvertToLinearPass,
ConstPropPass,
FuseBatchNormWithConvPass,
RemoveGetItemPass,
Conv1dUnsqueezePass,
PReLUReshapePass,
ChannelsLastTaggedReshapePass,
TagImplicitQDqPass,
]
else:
self.passes = passes

@property
def exported_program(self) -> ExportedProgram:
return self._exported_program

def transform(self) -> ExportedProgram:
"""
Returns a transformed ExportedProgram
"""
ep = self.exported_program
for pass_ in self.passes:
if issubclass(pass_, XNNPACKPass):
transform_pass = pass_(ep)
elif issubclass(pass_, ExportPass):
transform_pass = pass_()
else:
raise RuntimeError(
f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}"
)
ep = ep._transform(transform_pass)
return ep
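
For reference, a usage sketch of the new pass manager, assuming an
already-exported program and the pass classes imported as shown in this
module:

```python
# Run every default XNNPACK pass over an ExportedProgram:
transformed_ep = XNNPACKPassManager(exported_program).transform()

# Or run only a chosen subset, in order:
transformed_ep = XNNPACKPassManager(
    exported_program, passes=[ConvertToLinearPass, ConstPropPass]
).transform()
```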