Skip to content

Commit 9dece67

Browse files
authored
Arm backend: Add DecomposeLinalgVectorNorm pass + tests (#10848)
Added decomposition of linalg vector norm. Signed-off-by: Elena Zhelezina <[email protected]>
1 parent b058afb commit 9dece67

File tree

6 files changed

+306
-0
lines changed

6 files changed

+306
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2525
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2626
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
27+
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
2728
from .decompose_linear_pass import DecomposeLinearPass # noqa
2829
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
2930
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DecomposeLayerNormPass,
3030
DecomposeLeakyReLUPass,
3131
DecomposeLinearPass,
32+
DecomposeLinearVectorNormPass,
3233
DecomposeMeanDimPass,
3334
DecomposeNotEqualPass,
3435
DecomposeSelectPass,
@@ -86,6 +87,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8687
self.add_pass(ConvertSplitToSlicePass())
8788
self.add_pass(ConvertMmToBmmPass())
8889
self.add_pass(DecomposeLinearPass())
90+
self.add_pass(DecomposeLinearVectorNormPass())
8991
self.add_pass(DecomposeMeanDimPass())
9092
self.add_pass(ConvertFullLikeToFullPass())
9193
self.add_pass(ConvertToClampPass())
@@ -133,6 +135,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
133135
self.add_pass(FuseBatchnorm2DPass(exported_program))
134136
self.add_pass(ConvertMmToBmmPass())
135137
self.add_pass(DecomposeLinearPass())
138+
self.add_pass(DecomposeLinearVectorNormPass())
136139
self.add_pass(DecomposeLeakyReLUPass())
137140
self.add_pass(DecomposeBatchNormPass())
138141
self.add_pass(DecomposeLayerNormPass())
@@ -207,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
207210
self.add_pass(DecomposeCosineSimilarityPass())
208211
self.add_pass(DecomposeDivPass())
209212
self.add_pass(DecomposeLeakyReLUPass())
213+
self.add_pass(DecomposeLinearVectorNormPass())
210214
self.add_pass(DecomposeSqrtPass())
211215
self.add_pass(DecomposeSiluPass())
212216

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
10+
class DecomposeLinearVectorNormPass(ExportPass):
    """
    Rewrite aten.linalg_vector_norm.default using more primitive aten ops.

    The pass is meant to run before quantization so that the primitive ops
    can be annotated; without it, aten.linalg_vector_norm is only decomposed
    later, during legalization to Edge IR.

    Supported norm orders and their decompositions:

        p == 1:  out = REDUCE_SUM(ABS(x), dims, keepdim)
        p == 2:  out = SQRT(REDUCE_SUM(MUL(x, x), dims, keepdim))

    Any other order raises ValueError. The general form
        out = POW(REDUCE_SUM(POW(ABS(x), p), dims, keepdim), 1/p)
    is not implemented: p would have to be wrapped into a Tensor, which
    needs a dtype that is not known from the FX graph.

    The reduction dimension(s) must be passed explicitly; a missing or None
    dim raises ValueError.
    """

    torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,)

    def call_operator(self, op, args, kwargs, meta):
        # Every node, rewritten or not, is emitted through the parent class.
        emit = super().call_operator

        if op not in self.torch_linalg_vector_norm:
            return emit(op, args, kwargs, meta)

        # Positional layout: (input, ord, dim, keepdim). Everything after the
        # input is optional, so pad the supplied values with the defaults
        # ord=2.0, dim=None, keepdim=False.
        input_tensor = args[0]
        supplied = tuple(args[1:4])
        fallbacks = (2.0, None, False)
        norm_order, norm_dim, keepdim = supplied + fallbacks[len(supplied):]

        if norm_order not in (1, 2):
            raise ValueError(
                f"The order of {norm_order}\n"
                f"is not supported for linalg_vector_norm operator"
            )

        if norm_dim is None:
            raise ValueError("The norm_dim for linalg_vector_norm is None.")

        # sum.dim_IntList takes a list of dims; accept a bare int as well.
        if isinstance(norm_dim, int):
            dims = [norm_dim]
        else:
            dims = list(norm_dim)

        # Element-wise stage: |x| for p == 1, x * x for p == 2.
        if norm_order == 1:
            elementwise = emit(torch.ops.aten.abs.default, (input_tensor,), {}, meta)
        else:
            elementwise = emit(
                torch.ops.aten.mul.Tensor, (input_tensor, input_tensor), {}, meta
            )

        # Reduction stage, shared by both orders.
        reduced = emit(
            torch.ops.aten.sum.dim_IntList, (elementwise, dims, keepdim), {}, meta
        )
        if norm_order == 1:
            return reduced

        # p == 2 finishes with a square root of the summed squares.
        return emit(torch.ops.aten.sqrt.default, (reduced,), {}, meta)

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
CUSTOM_EDGE_OPS = [
99
"linspace.default",
1010
"eye.default",
11+
"vector_norm.default",
1112
"hardsigmoid.default",
1213
"hardswish.default",
1314
"linear.default",
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
input_t = Tuple[torch.Tensor]
19+
20+
aten_op_q_decomposed_q = "torch.ops.quantized_decomposed.quantize_per_tensor.default"
21+
exir_op_q_decomposed = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
22+
23+
24+
class VectorNormModel(torch.nn.Module):
    """
    Minimal module that applies torch.linalg.vector_norm to its input.

    `ord` and `dim` are forwarded only when they are not None, so leaving
    `ord` unset uses torch.linalg.vector_norm's own default order
    (documented by PyTorch as 2). `keepdim` is always forwarded.
    """

    def __init__(
        self,
        ord=None,
        dim=1,
        keepdim=False,
    ):
        super().__init__()
        self.ord, self.dim, self.keepdim = ord, dim, keepdim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Drop the optional arguments that were left as None instead of
        # passing None through to torch.linalg.vector_norm.
        optional = {"ord": self.ord, "dim": self.dim}
        call_kwargs = {key: val for key, val in optional.items() if val is not None}
        return torch.linalg.vector_norm(x, keepdim=self.keepdim, **call_kwargs)
51+
52+
53+
# Test id -> (module under test, tuple of example inputs).
# Inputs are drawn with torch.rand at import time.
test_modules = {
    # ord left unset: torch.linalg.vector_norm's default order is used
    "default": (VectorNormModel(dim=1), (torch.rand(10, 4),)),
    "ord1": (VectorNormModel(ord=1, dim=1), (torch.rand(10, 4),)),
    "ord2": (VectorNormModel(ord=2, dim=1), (torch.rand(10, 20),)),
    # Norm computed along a specific dimension of a 3D tensor
    "dim_3d": (VectorNormModel(dim=2), (torch.rand(4, 5, 6),)),
}
60+
61+
62+
@common.parametrize("test_module", test_modules)
def test_vector_norm_tosa_MI(test_module):
    """Run each vector-norm module through the TOSA MI pipeline."""
    module, example_inputs = test_module

    # LinalgVectorNorm is decomposed before the quantize stage so that the
    # resulting ops get annotated with q/dq nodes. For MI the operator is
    # decomposed by the global decompositions instead.
    aten_op = "torch.ops.aten.linalg_vector_norm.default"
    # This edge op should not be found.
    exir_op = "executorch_exir_dialects_edge__ops_aten_linalg_vector_norm_default"

    pipeline = TosaPipelineMI[input_t](module, example_inputs, aten_op, exir_op)
    pipeline.change_args(
        "run_method_and_compare_outputs", qtol=1, atol=1e-4, rtol=1e-4
    )
    pipeline.run()
77+
78+
79+
@common.parametrize("test_module", test_modules)
def test_vector_norm_tosa_BI(test_module):
    """Run each vector-norm module through the TOSA BI pipeline."""
    module, example_inputs = test_module

    # This edge op should not be found.
    exir_op = "executorch_exir_dialects_edge__ops_aten_linalg_vector_norm_default"

    pipeline = TosaPipelineBI[input_t](
        module,
        example_inputs,
        aten_op_q_decomposed_q,
        exir_op,
        symmetric_io_quantization=True,
    )
    # Output comparison uses the tolerances qtol=1, atol=1, rtol=1.
    pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
    pipeline.run()
95+
96+
97+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_vector_norm_u55_BI_fvp(test_module):
    """Run each vector-norm module through the Ethos-U55 BI pipeline on the FVP."""
    module, example_inputs = test_module

    pipeline = EthosU55PipelineBI[input_t](
        module,
        example_inputs,
        aten_op_q_decomposed_q,
        exir_op_q_decomposed,
        run_on_fvp=True,
        symmetric_io_quantization=True,
    )
    pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
    # The "check_not.exir" stage is removed before the pipeline is run.
    pipeline.pop_stage("check_not.exir")
    pipeline.run()
113+
114+
115+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_vector_norm_u85_BI_fvp(test_module):
    """Run each vector-norm module through the Ethos-U85 BI pipeline on the FVP.

    The test is gated on the Corstone-320 FVP, which is the platform the
    Ethos-U85 is simulated on (the Corstone-300 marker belongs to U55 tests).
    """
    model, input_tensor = test_module

    # The op should be decomposed and annotated by the
    # DecomposeLinearVectorNormPass before quantization.
    pipeline = EthosU85PipelineBI[input_t](
        model,
        input_tensor,
        aten_op_q_decomposed_q,
        exir_op_q_decomposed,
        run_on_fvp=True,
        symmetric_io_quantization=True,
    )
    pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
    # The "check_not.exir" stage is removed before the pipeline is run.
    pipeline.pop_stage("check_not.exir")
    pipeline.run()
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes.decompose_linalg_vector_norm_pass import (
11+
DecomposeLinearVectorNormPass,
12+
)
13+
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
15+
16+
input_t = Tuple[torch.Tensor]
17+
18+
19+
class VectorNormModel(torch.nn.Module):
    """
    Test module wrapping torch.linalg.vector_norm.
    https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html

    Only norm orders 1 and 2 are supported by the pass under test.
    `ord` and `dim` are forwarded only when they are not None.
    """

    def __init__(self, ord: float = None, dim=None, keepdim: bool = False):
        super().__init__()
        self.ord, self.dim, self.keepdim = ord, dim, keepdim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Leave out optional arguments that were not configured so that
        # torch.linalg.vector_norm falls back to its own defaults.
        optional = {"ord": self.ord, "dim": self.dim}
        call_kwargs = {key: val for key, val in optional.items() if val is not None}
        return torch.linalg.vector_norm(x, keepdim=self.keepdim, **call_kwargs)

    def get_inputs(self) -> input_t:
        # One random 4x4 example input, packed as a 1-tuple.
        example = torch.rand(4, 4)
        return (example,)
47+
48+
49+
# Test id -> module under test. Both modules reduce along dim=1.
modules = {
    # ord left unset: default p=2 (L2 vector norm) along dim=1
    "default_p2": VectorNormModel(dim=1),
    # p = 1: L1 norm along dim=1
    "p1": VectorNormModel(ord=1, dim=1),
}
55+
56+
57+
@common.parametrize("module", modules)
def test_decompose_vector_norm_tosa_BI(module):
    """
    Build a PassPipeline that applies DecomposeLinearVectorNormPass and check
    the ops present afterwards.

    The expected edge ops depend on the norm order of the module:
    - p == 1: one ABS and one SUM.
    - p == 2 (default): two POW (Tensor_Scalar) and one SUM.
      NOTE(review): these counts look like the result of the aten -> edge
      legalization rather than the pass's own MUL/SUM/SQRT decomposition;
      confirm which decomposition this test is meant to cover.

    In every case the edge linalg_vector_norm op must be absent afterwards.

    Raises:
        ValueError: if the module uses a norm order other than 1 or 2, for
            which no expectation is defined.
    """
    ord_val = module.ord if module.ord is not None else 2.0

    if ord_val == 1:
        ops_after_pass = {
            "executorch_exir_dialects_edge__ops_aten_abs_default": 1,
            "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1,
        }
    elif ord_val == 2:
        ops_after_pass = {
            "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2,
            "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1,
        }
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"No expected ops defined for norm order {ord_val}")

    pipeline = PassPipeline[input_t](
        module,
        module.get_inputs(),
        # The op is decomposed in legalization aten -> edge, so we are not
        # able to check ops before the pass.
        ops_before_pass=None,
        ops_not_before_pass=None,
        ops_after_pass=ops_after_pass,
        ops_not_after_pass=[
            # Was misspelled "linarg", which made this check vacuous.
            "executorch_exir_dialects_edge__ops_aten_linalg_vector_norm_default",
        ],
        pass_list=[DecomposeLinearVectorNormPass],
    )
    pipeline.run()

0 commit comments

Comments
 (0)