Skip to content

always partition used static attr #354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/qnnpack/QNNPackBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ class QnnpackBackend final : public PyTorchBackendInterface {
weights_zp->buffer()->data(),
ScalarType::QUInt8,
runtime_allocator,
0,
pre_pad_bytes, // Not necessary to prepad but suppresses asan errors:
// D42179009
&zp_buf);

// Create + copy Weight Scales Tensor
Expand All @@ -152,7 +153,8 @@ class QnnpackBackend final : public PyTorchBackendInterface {
weights_scale->buffer()->data(),
ScalarType::Float,
runtime_allocator,
0,
pre_pad_bytes, // Not necessary to prepad but suppresses asan errors:
// D42179009
&scale_buf);

// Create Quantized Input Tensor
Expand Down
13 changes: 5 additions & 8 deletions backends/qnnpack/partition/qnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Callable, Dict, List, Optional, Union

import torch
from typing import Dict, List, Optional, Union

from executorch.backends.qnnpack.partition.support_patterns import (
get_dynamic_quant_addmm_with_view_copy_graph,
Expand All @@ -16,14 +14,13 @@
get_dynamic_quant_mm_without_view_copy_graph,
)
from executorch.backends.qnnpack.qnnpack_preprocess import QnnpackBackend
from executorch.backends.transforms.addmm_mm_to_linear import (
apply_addmm_mm_to_linear_transform,
)
from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from torch._export.pass_base import PassType
from torch.export import ExportedProgram
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

Expand Down Expand Up @@ -69,7 +66,7 @@ def __init__(
self,
delegate_name,
patterns,
transforms: Optional[List[Callable[[torch.fx.Graph], torch.fx.Graph]]] = None,
transforms: Optional[List[PassType]] = None,
):
"""
@param transforms: Optional list of transforms that will be applied to the graph before running the partitioner.
Expand Down Expand Up @@ -157,5 +154,5 @@ def __init__(self) -> None:
get_dynamic_quant_mm_without_view_copy_graph(dynamic_shape=True),
]
super().__init__(
QnnpackBackend.__name__, qnnp_patterns, [apply_addmm_mm_to_linear_transform]
QnnpackBackend.__name__, qnnp_patterns, [AddmmToLinearTransform()]
)
1 change: 0 additions & 1 deletion backends/qnnpack/serialization/qnnpack_graph_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

def convert_to_flatbuffer(qnn_dynamic_linear: QNNDynamicLinear) -> bytes:
qnnpack_graph_json = json.dumps(qnn_dynamic_linear, cls=_DataclassEncoder)

with tempfile.TemporaryDirectory() as d:
schema_path = os.path.join(d, "schema.fbs")
with open(schema_path, "wb") as schema_file:
Expand Down
15 changes: 7 additions & 8 deletions backends/qnnpack/test/test_qnnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@

EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False)

# TODO(T158653285)
@unittest.expectedFailure

class TestQnnbackends(unittest.TestCase):
k_dim = 5
input_dims = (1, 4, k_dim)
Expand Down Expand Up @@ -89,7 +88,7 @@ def test_qnnpack_per_channel_dynamic_mm(self):
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
).check(
"executorch_exir_dialects_edge__ops_aten_t_copy_default"
"executorch_exir_dialects_edge__ops_aten_permute_copy_default"
).check(
"executorch_exir_dialects_edge__ops_aten_mm"
).run(
Expand Down Expand Up @@ -170,7 +169,7 @@ def test_qnnpack_per_channel_dynamic_qlinear(self):
).check(
"aten_view_copy_default"
).check(
"aten_t_copy_default"
"aten_permute_copy_default"
).check(
"aten_addmm_default"
).check(
Expand Down Expand Up @@ -245,7 +244,7 @@ def test_qnnpack_per_tensor_dynamic_mm(self):
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
).check(
"executorch_exir_dialects_edge__ops_aten_t_copy_default"
"executorch_exir_dialects_edge__ops_aten_permute_copy_default"
).check(
"executorch_exir_dialects_edge__ops_aten_mm"
).run(
Expand Down Expand Up @@ -326,7 +325,7 @@ def test_qnnpack_per_tensor_dynamic_qlinear(self):
).check(
"aten_view_copy_default"
).check(
"aten_t_copy_default"
"aten_permute_copy_default"
).check(
"aten_addmm_default"
).check(
Expand Down Expand Up @@ -400,7 +399,7 @@ def test_qnnpack_per_channel_dynamic_mm_with_dynamic_shape(self):
).check(
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
).check(
"executorch_exir_dialects_edge__ops_aten_t_copy_default"
"executorch_exir_dialects_edge__ops_aten_permute_copy_default"
).check(
"executorch_exir_dialects_edge__ops_aten_mm"
).run(
Expand Down Expand Up @@ -482,7 +481,7 @@ def test_qnnpack_per_channel_dynamic_qlinear_via_partitioner(self):
).check(
"aten_view_copy_default"
).check(
"aten_t_copy_default"
"aten_permute_copy_default"
).check(
"aten_addmm_default"
).check(
Expand Down
2 changes: 0 additions & 2 deletions backends/qnnpack/test/test_qnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def get_actual_dyanmic_quantized_graph(
return dynamic_quantized_exir_graph.graph


# TODO(T158653285)
@unittest.expectedFailure
class TestQnnbackends(unittest.TestCase):
def test_dynamic_quantize_addmm_with_view_copy_partitioner(self):
example_inputs = (torch.rand(5, 1, 256),)
Expand Down
1 change: 1 addition & 0 deletions backends/transforms/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python_library(
srcs = ["addmm_mm_to_linear.py"],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir:sym_util",
"//executorch/exir/dialects:lib",
],
Expand Down
17 changes: 15 additions & 2 deletions backends/transforms/addmm_mm_to_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from executorch.exir.sym_util import eval_shape

Expand Down Expand Up @@ -105,7 +106,10 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
with graph.inserting_after(node):
if node.target == ops.aten.addmm.default:
weight_t_node = node.args[2]
if weight_t_node.target != ops.aten.t_copy.default:
if weight_t_node.target not in [
ops.aten.t_copy.default,
ops.aten.permute_copy.default,
]:
raise RuntimeError(
f"Weight input to addmm must be tranposed but found {weight_t_node}"
)
Expand All @@ -120,7 +124,10 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
)
else:
weight_t_node = node.args[1]
if weight_t_node.target != ops.aten.t_copy.default:
if weight_t_node.target not in [
ops.aten.t_copy.default,
ops.aten.permute_copy.default,
]:
raise RuntimeError(
f"Weight input to addmm must be tranposed but found {weight_t_node}"
)
Expand All @@ -145,3 +152,9 @@ def apply_addmm_mm_to_linear_transform(graph: torch.fx.Graph) -> torch.fx.Graph:
graph = replace_addmm_mm_with_linear(graph)
graph = replace_linear_view_copy_input_output(graph)
return graph


class AddmmToLinearTransform(ExportPass):
    """Pass-style wrapper around ``apply_addmm_mm_to_linear_transform``.

    Lets the graph-level addmm/mm -> linear rewrite be used wherever a pass
    object operating on a ``torch.fx.GraphModule`` is expected.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the addmm/mm -> linear rewrite on ``graph_module``.

        Args:
            graph_module: Module whose ``graph`` is handed to
                ``apply_addmm_mm_to_linear_transform``.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module`` with its graph
            replaced by the transform's return value. The second field is
            always ``True`` (presumably the "modified" flag -- confirm against
            ``PassResult``), regardless of whether any node was rewritten.
        """
        transformed_graph = apply_addmm_mm_to_linear_transform(graph_module.graph)
        # Re-assign the returned graph to the module rather than relying on
        # in-place mutation alone.
        graph_module.graph = transformed_graph
        return PassResult(graph_module, True)
6 changes: 0 additions & 6 deletions backends/xnnpack/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,6 @@ def define_tensor(

# convert tensor shape must reflect memory format, default is contiguous, so
# only permute shape if we are converting the tensor to nhwc format
if tensor.target in (
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.t_copy.default,
):
# We ignore transpose nodes and reverse the dims to before it
dims = dims[::-1]
if swap_nc_for_depthwise_weights:
dims = [dims[1], dims[0]] + dims[2:]
if convert_to_nhwc:
Expand Down
11 changes: 1 addition & 10 deletions backends/xnnpack/operators/op_addmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from executorch.backends.xnnpack.utils.xnnpack_constants import (
XNN_FLAG_TRANSPOSE_WEIGHTS,
)
from executorch.exir.dialects._ops import ops as exir_ops


@register_node_visitor
Expand Down Expand Up @@ -56,15 +55,7 @@ def define_node(
# output
output_id = vals_to_ids[node]

flag = (
0
if get_input_node(node, 2).target
in (
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.t_copy.default,
)
else XNN_FLAG_TRANSPOSE_WEIGHTS
)
flag = XNN_FLAG_TRANSPOSE_WEIGHTS

ser_node = XNode(
xnode_union=XNNFullyConnected(
Expand Down
3 changes: 3 additions & 0 deletions backends/xnnpack/partition/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.leaky_relu.default,
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
]

SUPPORTED_MODULES = [
Expand Down Expand Up @@ -95,7 +96,9 @@
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.t_copy.default,
exir_ops.edge.aten.leaky_relu.default,
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
]

SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {
Expand Down
Loading