Skip to content

Commit 6a63b70

Browse files
mcr229facebook-github-bot
authored andcommitted
Fixing xnnpack+qnnpack test (#355)
Summary: Fixing all qnnpack tests. In addition to the broken test linked by mergen, there were many tests set to expect to fail. This diff fixes all tests. Reviewed By: digantdesai, kirklandsign Differential Revision: D49068312
1 parent 7b29899 commit 6a63b70

File tree

8 files changed

+32
-25
lines changed

8 files changed

+32
-25
lines changed

backends/qnnpack/QNNPackBackend.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ class QnnpackBackend final : public PyTorchBackendInterface {
140140
weights_zp->buffer()->data(),
141141
ScalarType::QUInt8,
142142
runtime_allocator,
143-
0,
143+
pre_pad_bytes, // Not necessary to prepad but surpresses asan errors:
144+
// D42179009
144145
&zp_buf);
145146

146147
// Create + copy Weight Scales Tensor
@@ -152,7 +153,8 @@ class QnnpackBackend final : public PyTorchBackendInterface {
152153
weights_scale->buffer()->data(),
153154
ScalarType::Float,
154155
runtime_allocator,
155-
0,
156+
pre_pad_bytes, // Not necessary to prepad but surpresses asan errors:
157+
// D42179009
156158
&scale_buf);
157159

158160
// Create Quantized Input Tensor

backends/qnnpack/partition/qnnpack_partitioner.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8-
from typing import Callable, Dict, List, Optional, Union
9-
10-
import torch
8+
from typing import Dict, List, Optional, Union
119

1210
from executorch.backends.qnnpack.partition.support_patterns import (
1311
get_dynamic_quant_addmm_with_view_copy_graph,
@@ -16,14 +14,13 @@
1614
get_dynamic_quant_mm_without_view_copy_graph,
1715
)
1816
from executorch.backends.qnnpack.qnnpack_preprocess import QnnpackBackend
19-
from executorch.backends.transforms.addmm_mm_to_linear import (
20-
apply_addmm_mm_to_linear_transform,
21-
)
17+
from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
2218
from executorch.exir.backend.partitioner import (
2319
DelegationSpec,
2420
Partitioner,
2521
PartitionResult,
2622
)
23+
from torch._export.pass_base import PassType
2724
from torch.export import ExportedProgram
2825
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
2926

@@ -69,7 +66,7 @@ def __init__(
6966
self,
7067
delegate_name,
7168
patterns,
72-
transforms: Optional[List[Callable[[torch.fx.Graph], torch.fx.Graph]]] = None,
69+
transforms: Optional[List[PassType]] = None,
7370
):
7471
"""
7572
@param transforms: Optional list of transforms that will be applied to the graph before running the partitioner.
@@ -157,5 +154,5 @@ def __init__(self) -> None:
157154
get_dynamic_quant_mm_without_view_copy_graph(dynamic_shape=True),
158155
]
159156
super().__init__(
160-
QnnpackBackend.__name__, qnnp_patterns, [apply_addmm_mm_to_linear_transform]
157+
QnnpackBackend.__name__, qnnp_patterns, [AddmmToLinearTransform()]
161158
)

backends/qnnpack/serialization/qnnpack_graph_serialize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
def convert_to_flatbuffer(qnn_dynamic_linear: QNNDynamicLinear) -> bytes:
2323
qnnpack_graph_json = json.dumps(qnn_dynamic_linear, cls=_DataclassEncoder)
24-
2524
with tempfile.TemporaryDirectory() as d:
2625
schema_path = os.path.join(d, "schema.fbs")
2726
with open(schema_path, "wb") as schema_file:

backends/qnnpack/test/test_qnnpack.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@
5151

5252
EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False)
5353

54-
# TODO(T158653285)
55-
@unittest.expectedFailure
54+
5655
class TestQnnbackends(unittest.TestCase):
5756
k_dim = 5
5857
input_dims = (1, 4, k_dim)
@@ -89,7 +88,7 @@ def test_qnnpack_per_channel_dynamic_mm(self):
8988
).check(
9089
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
9190
).check(
92-
"executorch_exir_dialects_edge__ops_aten_t_copy_default"
91+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default"
9392
).check(
9493
"executorch_exir_dialects_edge__ops_aten_mm"
9594
).run(
@@ -170,7 +169,7 @@ def test_qnnpack_per_channel_dynamic_qlinear(self):
170169
).check(
171170
"aten_view_copy_default"
172171
).check(
173-
"aten_t_copy_default"
172+
"aten_permute_copy_default"
174173
).check(
175174
"aten_addmm_default"
176175
).check(
@@ -245,7 +244,7 @@ def test_qnnpack_per_tensor_dynamic_mm(self):
245244
).check(
246245
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
247246
).check(
248-
"executorch_exir_dialects_edge__ops_aten_t_copy_default"
247+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default"
249248
).check(
250249
"executorch_exir_dialects_edge__ops_aten_mm"
251250
).run(
@@ -326,7 +325,7 @@ def test_qnnpack_per_tensor_dynamic_qlinear(self):
326325
).check(
327326
"aten_view_copy_default"
328327
).check(
329-
"aten_t_copy_default"
328+
"aten_permute_copy_default"
330329
).check(
331330
"aten_addmm_default"
332331
).check(
@@ -400,7 +399,7 @@ def test_qnnpack_per_channel_dynamic_mm_with_dynamic_shape(self):
400399
).check(
401400
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
402401
).check(
403-
"executorch_exir_dialects_edge__ops_aten_t_copy_default"
402+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default"
404403
).check(
405404
"executorch_exir_dialects_edge__ops_aten_mm"
406405
).run(
@@ -482,7 +481,7 @@ def test_qnnpack_per_channel_dynamic_qlinear_via_partitioner(self):
482481
).check(
483482
"aten_view_copy_default"
484483
).check(
485-
"aten_t_copy_default"
484+
"aten_permute_copy_default"
486485
).check(
487486
"aten_addmm_default"
488487
).check(

backends/qnnpack/test/test_qnnpack_partitioner.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def get_actual_dyanmic_quantized_graph(
6666
return dynamic_quantized_exir_graph.graph
6767

6868

69-
# TODO(T158653285)
70-
@unittest.expectedFailure
7169
class TestQnnbackends(unittest.TestCase):
7270
def test_dynamic_quantize_addmm_with_view_copy_partitioner(self):
7371
example_inputs = (torch.rand(5, 1, 256),)

backends/transforms/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ python_library(
1919
srcs = ["addmm_mm_to_linear.py"],
2020
deps = [
2121
"//caffe2:torch",
22+
"//executorch/exir:pass_base",
2223
"//executorch/exir:sym_util",
2324
"//executorch/exir/dialects:lib",
2425
],

backends/transforms/addmm_mm_to_linear.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
910

1011
from executorch.exir.sym_util import eval_shape
1112

@@ -105,7 +106,10 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
105106
with graph.inserting_after(node):
106107
if node.target == ops.aten.addmm.default:
107108
weight_t_node = node.args[2]
108-
if weight_t_node.target != ops.aten.t_copy.default:
109+
if weight_t_node.target not in [
110+
ops.aten.t_copy.default,
111+
ops.aten.permute_copy.default,
112+
]:
109113
raise RuntimeError(
110114
f"Weight input to addmm must be tranposed but found {weight_t_node}"
111115
)
@@ -120,7 +124,10 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
120124
)
121125
else:
122126
weight_t_node = node.args[1]
123-
if weight_t_node.target != ops.aten.t_copy.default:
127+
if weight_t_node.target not in [
128+
ops.aten.t_copy.default,
129+
ops.aten.permute_copy.default,
130+
]:
124131
raise RuntimeError(
125132
f"Weight input to addmm must be tranposed but found {weight_t_node}"
126133
)
@@ -145,3 +152,9 @@ def apply_addmm_mm_to_linear_transform(graph: torch.fx.Graph) -> torch.fx.Graph:
145152
graph = replace_addmm_mm_with_linear(graph)
146153
graph = replace_linear_view_copy_input_output(graph)
147154
return graph
155+
156+
157+
class AddmmToLinearTransform(ExportPass):
158+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
159+
graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph)
160+
return PassResult(graph_module, True)

exir/tests/test_memory_planning.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,6 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
457457
)
458458
return quantized_model
459459

460-
# TODO(T158653285)
461-
@unittest.expectedFailure
462460
def test_asr_joiner(self) -> None:
463461
eager_model = self.quantize(ASRJoiner())
464462
inputs = eager_model.get_random_inputs()

0 commit comments

Comments
 (0)