
Commit 84fa013

Update base for Update on "[ET-VK] Simplifying conv1d op shader by changing it to process one output texel per thread."
This diff changes the conv1d shader to process one output texel per thread, increasing GPU occupancy and improving performance.

Differential Revision: [D74097560](https://our.internmc.facebook.com/intern/diff/D74097560/)

[ghstack-poisoned]
2 parents 0c03969 + cd3b53d
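
The summary above is about GPU work mapping: with one output texel per thread there is no per-thread loop over multiple outputs, so more independent threads are in flight to hide memory latency. A minimal Python sketch of that mapping (plain Python standing in for the GLSL compute shader; none of these names come from the actual shader code, which is not shown in this diff):

# Illustrative only: each "thread" (here, a loop iteration) computes exactly
# one output element, mirroring a one-output-texel-per-thread conv1d shader.
def conv1d_one_output_per_thread(inp, weight, stride=1):
    out_len = (len(inp) - len(weight)) // stride + 1
    out = [0.0] * out_len
    for thread_id in range(out_len):  # one thread per output texel
        acc = 0.0
        for k, w in enumerate(weight):
            acc += w * inp[thread_id * stride + k]
        out[thread_id] = acc
    return out

print(conv1d_one_output_per_thread([1.0, 2.0, 3.0, 4.0], [1.0, 0.5]))
# [2.0, 3.5, 5.0]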

File tree: 23 files changed, +426 −236 lines


backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,9 @@
 )
 
 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
@@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
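
The pass added above rewrites scaled dot-product attention before annotation. Mathematically, SDPA is softmax(Q·Kᵀ/√d)·V, so the decomposition amounts to expressing that formula through primitive ops the quantizer can annotate individually. A sketch of the underlying equivalence (this is the math, not the pass's actual implementation):

import math

import torch

def sdpa_decomposed(query, key, value):
    # softmax(Q @ K^T / sqrt(d)) @ V with attn_mask=None, dropout_p=0.0,
    # is_causal=False -- built from matmul/mul/softmax primitives that a
    # quantization annotator can see one by one.
    scale = 1.0 / math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, value)

q, k, v = (torch.randn(1, 3, 197, 64) for _ in range(3))
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(sdpa_decomposed(q, k, v), reference, atol=1e-4)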

backends/arm/_passes/decompose_softmax_pass.py

Lines changed: 5 additions & 1 deletion
@@ -8,7 +8,11 @@
 from executorch.exir.pass_base import ExportPass
 
 # For BI case
-torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
+torch_softmax = (
+    torch.ops.aten.softmax.int,
+    torch.ops.aten._safe_softmax.default,
+    torch.ops.aten.log_softmax.int,
+)
 # For MI case
 edge_softmax = (
     exir_ops.edge.aten._softmax.default,
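
The new torch.ops.aten._safe_softmax.default entry covers the safe-softmax variant that exported graphs can contain; as I understand it, recent PyTorch decomposes scaled_dot_product_attention through it, which is why it surfaces alongside this commit's SDPA changes. Its semantic difference from regular softmax is that rows masked entirely to -inf yield zeros instead of NaN. A rough sketch of that behavior (my model of the op, not its source):

import torch

def safe_softmax(x, dim=-1):
    # Rough model of aten._safe_softmax: identical to softmax except that
    # rows whose entries are all -inf produce zeros rather than NaN.
    out = torch.softmax(x, dim=dim)
    all_masked = x.eq(float("-inf")).all(dim=dim, keepdim=True)
    return torch.where(all_masked, torch.zeros_like(out), out)

x = torch.tensor([[0.0, 1.0], [float("-inf"), float("-inf")]])
print(torch.softmax(x, dim=-1))  # second row: nan, nan
print(safe_softmax(x, dim=-1))   # second row: 0., 0.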

backends/arm/test/models/test_conformer.py

Lines changed: 13 additions & 7 deletions
@@ -83,7 +83,6 @@ def test_conformer_tosa_BI(self):
             )
         )
 
-    @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u55_BI(self):
         tester = (
             ArmTester(
@@ -97,13 +96,20 @@ def test_conformer_u55_BI(self):
             .to_executorch()
             .serialize()
         )
+
         if conftest.is_option_enabled("corstone_fvp"):
-            tester.run_method_and_compare_outputs(
-                qtol=1.0,
-                rtol=1.0,
-                atol=5.0,
-                inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
-            )
+            try:
+                tester.run_method_and_compare_outputs(
+                    qtol=1.0,
+                    rtol=1.0,
+                    atol=5.0,
+                    inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
+                )
+                self.fail(
+                    "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
+                )
+            except Exception:
+                pass
 
     @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u85_BI(self):
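
Context for the change above: @unittest.expectedFailure is applied at decoration time, so it cannot depend on a runtime option like corstone_fvp; the test therefore inverts the check by hand. A standalone sketch of that pattern with hypothetical names; note that self.fail() raises AssertionError, which a bare except Exception would also swallow, so this sketch re-raises it explicitly:

import unittest

def run_known_broken_path():
    # Hypothetical stand-in for the FVP run that is currently expected to fail.
    raise RuntimeError("known issue; see tracking ticket")

class ConditionalExpectedFailure(unittest.TestCase):
    def test_expected_failure_when_option_enabled(self):
        option_enabled = True  # stand-in for conftest.is_option_enabled(...)
        if not option_enabled:
            self.skipTest("option disabled; nothing to check")
        try:
            run_known_broken_path()
            self.fail("Expected failure, but the run passed: remove the workaround.")
        except AssertionError:
            raise  # let self.fail() above propagate
        except Exception:
            pass  # the failure is currently expected

if __name__ == "__main__":
    unittest.main()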

backends/arm/test/ops/test_sdpa.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+
+class SDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, query, key, value):
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+
+
+input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+
+
+def test_sdpa_MI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineMI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check_count.exir")
+    pipeline.run()
+
+
+def test_sdpa_BI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineBI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.pop_stage("check_count.exir")
+    pipeline.pop_stage(
+        "run_method_and_compare_outputs"
+    )  # TODO: reference is not quantized
+    pipeline.run()
backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -347,6 +347,7 @@ python_unittest(
         ":compiler",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:remove_ops",

backends/cadence/aot/compiler.py

Lines changed: 33 additions & 6 deletions
@@ -151,7 +151,7 @@ def quantize_pt2(
     quantizer: Optional[CadenceQuantizer] = None,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
     If calibration data is provided, it will be used to calibrate the model. If
@@ -178,7 +178,9 @@ def quantize_pt2(
         logging.info("Graph after quantization and fusion:")
         logging.info(fused_gm.graph.print_tabular())
 
-    return fused_gm
+    program = torch.export.export(fused_gm, inputs, strict=True)
+
+    return program
 
 
 # Export the model and lower it to an ExportedProgram (in aten IR)
@@ -260,21 +262,43 @@ def quantize_and_export_to_edge(
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize and lower a model/inputs pair to edge IR.
+    """
     quantized_model = quantize_pt2(
         model,
         inputs,
         quantizer=quantizer,
         dump_graphs=dump_graphs,
     )
 
-    return export_to_edge(
+    return lower_ep_to_edge(
         quantized_model,
-        inputs,
         dump_graphs=dump_graphs,
         constant_methods=constant_methods,
     )
 
 
+def lower_ep_to_cadence(
+    program: ExportedProgram,
+    dump_graphs: bool = False,
+    opt_level: int = 1,
+) -> EdgeProgramManager:
+    """
+    Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
+    """
+    edge_prog_manager = lower_ep_to_edge(program, dump_graphs=dump_graphs)
+    cadence_passes = get_cadence_passes(opt_level)
+
+    # Run a couple required passes for quant/dequant ops
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
@@ -299,11 +323,14 @@ def quantize_and_export_to_cadence(
     dump_graphs: bool = False,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
+    optimization passes.
+    """
     quantized_model = quantize_pt2(model, inputs)
 
-    return export_to_cadence(
+    return lower_ep_to_cadence(
         quantized_model,
-        inputs,
         opt_level=opt_level,
         dump_graphs=dump_graphs,
    )
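
Taken together, these changes split quantization from lowering: quantize_pt2 now returns an ExportedProgram, which lower_ep_to_edge or the new lower_ep_to_cadence can consume directly. A hedged sketch of the resulting call flow (the model, inputs, and import path are illustrative assumptions, not taken from this diff):

import torch

from executorch.backends.cadence.aot.compiler import (
    lower_ep_to_cadence,
    quantize_pt2,
)

model = torch.nn.Linear(16, 8)   # illustrative model
inputs = (torch.randn(1, 16),)   # illustrative example inputs

# Trace, quantize and fuse; per this diff the result is now an
# ExportedProgram rather than a torch.fx.GraphModule.
program = quantize_pt2(model, inputs)

# Lower the ExportedProgram to edge IR and run the Cadence frontend
# passes; quantize_and_export_to_cadence wraps both steps in one call.
edge_manager = lower_ep_to_cadence(program, opt_level=1)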
