Skip to content

Commit f93270a

Browse files
Arm backend: Add layer_norm decomposition (#6288)
Summary: Add decomposition for layer_norm, mean_dim and var_dim. Pull Request resolved: #6288 Reviewed By: cccclai Differential Revision: D64473056 Pulled By: digantdesai fbshipit-source-id: a3d4b0fbb0ea1d56b372b617f96dcbab41f72bb7
1 parent 46ea1a4 commit f93270a

15 files changed

+913
-94
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
ConvertSplitToSlicePass,
2020
)
2121
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
22+
from executorch.backends.arm._passes.decompose_layernorm_pass import (
23+
DecomposeLayerNormPass,
24+
)
25+
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
26+
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
2227
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
2328
InsertSqueezeAfterSumPass,
2429
)
@@ -53,7 +58,10 @@ def transform_to_backend_pipeline(
5358
self.add_pass(SizeAdjustConv2DPass())
5459
self.add_pass(RemoveClonePass())
5560
self.add_pass(ConvertExpandCopyToRepeatPass())
61+
self.add_pass(DecomposeLayerNormPass())
62+
self.add_pass(DecomposeVarPass())
5663
self.add_pass(ConvertMeanDimToAveragePool())
64+
self.add_pass(DecomposeMeanDimPass())
5765
self.add_pass(MatchArgRanksPass(exported_program))
5866
self.add_pass(DecomposeDivPass())
5967
self.add_pass(InsertSqueezeAfterSumPass())
@@ -67,6 +75,9 @@ def transform_to_backend_pipeline(
6775
return self._transform(exported_program.graph_module)
6876

6977
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
78+
self.add_pass(DecomposeLayerNormPass())
79+
self.add_pass(DecomposeVarPass())
80+
self.add_pass(DecomposeMeanDimPass())
7081
self.add_pass(ScalarsToAttributePass())
7182
self.add_pass(DecomposeDivPass())
7283
return self._transform(graph_module)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import operator
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
14+
15+
def get_layer_norm_decomposition(op) -> tuple:
    """
    Return the ops used to decompose the given layer norm operator.

    The ops are taken from the same dialect as ``op``: the edge dialect for
    ``native_layer_norm`` and the aten dialect for ``layer_norm``. They are
    returned in the order
    (mean, sub, var, full, add, rsqrt, mul, view_copy).

    Raises:
        RuntimeError: if ``op`` is neither of the two supported operators.
    """
    if op == exir_ops.edge.aten.native_layer_norm.default:
        aten = exir_ops.edge.aten
    elif op == torch.ops.aten.layer_norm.default:
        aten = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get layer_norm composition for op {op}")

    return (
        aten.mean.dim,
        aten.sub.Tensor,
        aten.var.correction,
        aten.full.default,
        aten.add.Tensor,
        aten.rsqrt.default,
        aten.mul.Tensor,
        aten.view_copy.default,
    )
39+
40+
41+
class DecomposeLayerNormPass(ExportPass):
    """
    layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
    Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
    mean      = op_mean(x, dims)          # E[x]
    numerator = op_sub(x, mean)           # (x - E[x])
    var       = op_var(x, dims)           # Var[x]
    add       = op_add(var, eps)          # Var[x] + eps
    rsqrt     = op_rsqrt(add)             # 1 / sqrt(Var[x] + eps)
    mul       = op_mul(numerator, rsqrt)  # (x - E[x]) / sqrt(Var[x] + eps)
    weighted  = op_mul(mul, weights)      # ... * weights (only when weights are given)
    output    = op_add(weighted, bias)    # ... + bias    (only when bias is given)

    Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
    """

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Replace every layer norm node in ``graph_module`` with its decomposition.

        Handles both ``exir_ops.edge.aten.native_layer_norm.default`` and
        ``torch.ops.aten.layer_norm.default`` nodes. Always returns a
        ``PassResult`` marked as modified.
        """
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in (
                exir_ops.edge.aten.native_layer_norm.default,
                torch.ops.aten.layer_norm.default,
            ):
                continue

            args = node.args
            meta = node.meta

            # Positional layout: (x, normalized_shape, weight, bias, eps, ...).
            # Trailing arguments may be omitted; index explicitly so that an
            # extra trailing argument (e.g. layer_norm's cudnn_enable) does not
            # break the unpacking.
            x, normalized_shape = args[0], args[1]
            weights = args[2] if len(args) > 2 else None
            bias = args[3] if len(args) > 3 else None
            # 1e-5 is the documented default eps of layer_norm; it applies
            # whenever eps is not given explicitly on the node.
            epsilon = args[4] if len(args) > 4 else 1e-5

            # The node value is either a tuple (first element is the
            # normalized output) or a single tensor; read shape and dtype from
            # the same tensor in both cases.
            val = meta["val"]
            if isinstance(val, tuple):
                val = val[0]
            shape = val.size()
            dtype = val.dtype

            n_dims = len(normalized_shape)
            rank = len(shape)
            # Normalization runs over the last n_dims dimensions; convert the
            # negative indices to non-negative ones.
            dims = [dim % rank for dim in range(-1, -1 * (n_dims + 1), -1)]
            # weights/bias are viewed to the full rank so they broadcast
            # against x: normalized dims keep their size, all others become 1.
            weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)]
            epsilon_reshaped_shape = [1] * rank

            (
                mean_op,
                sub_op,
                var_op,
                full_op,
                add_op,
                rsqrt_op,
                mul_op,
                view_op,
            ) = get_layer_norm_decomposition(node.target)
            with graph_module.graph.inserting_before(node):
                keepdim = True
                mean = create_node(graph_module.graph, mean_op, args=(x, dims, keepdim))
                sub = create_node(graph_module.graph, sub_op, args=(x, mean))
                var = create_node(
                    graph_module.graph,
                    var_op,
                    args=(x, dims),
                    kwargs={"correction": 0, "keepdim": keepdim},
                )
                full = create_node(
                    graph_module.graph,
                    full_op,
                    args=(epsilon_reshaped_shape, epsilon),
                    kwargs={"dtype": dtype},
                )
                add0 = create_node(graph_module.graph, add_op, args=(var, full))
                rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
                mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
                if weights is not None:
                    weights_reshaped = create_node(
                        graph_module.graph,
                        view_op,
                        args=(weights, weights_reshaped_shape),
                    )
                    mul1 = create_node(
                        graph_module.graph, mul_op, args=(mul0, weights_reshaped)
                    )
                else:
                    mul1 = mul0
                output = mul1
                if bias is not None:
                    bias_reshaped_shape = weights_reshaped_shape
                    bias_reshaped = create_node(
                        graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
                    )
                    output = create_node(
                        graph_module.graph, add_op, args=(mul1, bias_reshaped)
                    )

            # Capture the users before rewiring. Consumers that read the result
            # through getitem are pointed straight at the decomposed output;
            # the then-unused getitem nodes are removed by dead code elimination.
            users = [user for user in node.users if node != user]
            node.replace_all_uses_with(output)
            for user in users:
                if user.target == operator.getitem:
                    user.replace_all_uses_with(output)
            graph_module.graph.erase_node(node)
            graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        # Re-run the base ExportPass over the rewritten graph
        # (presumably to regenerate node metadata -- TODO confirm).
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
def get_meandim_decomposition(op) -> tuple:
    """
    Return the (sum, full, mul) ops used to decompose the given mean.dim op.

    The ops come from the same dialect (edge or aten) as ``op``.

    Raises:
        RuntimeError: if ``op`` is not a supported mean.dim operator.
    """
    if op == exir_ops.edge.aten.mean.dim:
        aten = exir_ops.edge.aten
    elif op == torch.ops.aten.mean.dim:
        aten = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get meandim decomposition for op {op}")

    return (aten.sum.dim_IntList, aten.full.default, aten.mul.Tensor)
26+
27+
28+
class DecomposeMeanDimPass(ExportPass):
    """
    Rewrite mean.dim as a sum followed by a multiplication with 1/N.

    Example:
        y = mean_dim(x, dim, keepdim)
    Becomes:
        sum = sum.dim_IntList(x, dim, keepdim)
        y = mul(sum, 1/N)

    where N is the product of the sizes of the reduced dimensions of x.
    """

    def call_operator(self, op, args, kwargs, meta):
        handled_ops = (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim)
        if op not in handled_ops:
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        dim = args[1]
        keepdim = args[2] if len(args) > 2 else False

        # Two cases are passed through untouched:
        #  * keepdim is false;
        #  * keepdim is true and dim == [-1, -2]: that form can be decomposed
        #    to avg_pool2d, which is handled by ConvertMeanDimToAveragePool.
        if not keepdim or dim == [-1, -2]:
            return super().call_operator(op, args, kwargs, meta)

        out_val = meta["val"]
        out_rank = len(out_val.size())
        in_sizes = x.data.size()

        # Number of elements averaged over.
        count = 1
        for axis in dim:
            count *= in_sizes[axis]

        sum_op, full_op, mul_op = get_meandim_decomposition(op)

        summed = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
        # All-ones shape of the output rank so the constant broadcasts.
        scale = super().call_operator(
            full_op, ([1] * out_rank, 1 / count), {"dtype": out_val.dtype}, meta
        )
        return super().call_operator(mul_op, (summed, scale), {}, meta)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import torch
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
def get_var_decomposition(op) -> tuple:
    """
    Return the (mean, sub, mul, sum, full) ops used to decompose a var op.

    The ops come from the same dialect as ``op``: the edge dialect for
    ``exir_ops.edge.aten.var.correction``, the aten dialect for
    ``torch.ops.aten.var.correction`` and ``torch.ops.aten.var.dim``.

    Raises:
        RuntimeError: if ``op`` is not a supported var operator.
    """
    if op == exir_ops.edge.aten.var.correction:
        return (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.full.default,
        )
    if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim):
        return (
            torch.ops.aten.mean.dim,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sum.dim_IntList,
            # Use the concrete overload, like every other entry here, rather
            # than the bare `torch.ops.aten.full` overload packet.
            torch.ops.aten.full.default,
        )
    raise RuntimeError(f"Can't get var decomposition for op {op}")
31+
32+
33+
class DecomposeVarPass(ExportPass):
    """
    This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)

    Example:
        y = var_correction(x, dim, keepdim, correction)
    Becomes:
        mean = mean(x, dim)
        diff = sub(x, mean)
        squared_diff = mul(diff, diff)
        sum = sum(squared_diff, dim)
        y = mul(sum, 1 / max(0, N-correction))

    where N is the product of the sizes of the reduced dimensions of x.
    Only calls with keepdim=True are decomposed; others are passed through.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            exir_ops.edge.aten.var.correction,
            torch.ops.aten.var.correction,
            torch.ops.aten.var.dim,
        ):
            return super().call_operator(op, args, kwargs, meta)
        shape = meta["val"].size()
        dtype = meta["val"].dtype

        if op == torch.ops.aten.var.dim:
            # var.dim(self, dim, unbiased=True, keepdim=False): read the
            # trailing arguments by position so that omitted defaults are not
            # mistaken for the input tensor or dim. `unbiased` acts as a
            # correction of 1 (True) or 0 (False).
            correction = args[2] if len(args) > 2 else kwargs.get("unbiased", True)
            keepdim = args[3] if len(args) > 3 else kwargs.get("keepdim", False)
        else:
            # var.correction(self, dim=None, *, correction=None, keepdim=False):
            # a missing or None correction means the default of 1.
            correction = kwargs.get("correction")
            if correction is None:
                correction = 1
            keepdim = kwargs.get("keepdim", False)
        if not keepdim:
            return super().call_operator(op, args, kwargs, meta)

        # With keepdim=True the output rank equals the input rank, so a
        # missing or None dim expands to every dimension.
        dim = args[1] if len(args) > 1 else None
        if dim is None:
            dim = list(range(len(shape)))

        x = args[0]
        input_shape = x.data.size()
        # Number of elements reduced over.
        N = 1
        for d in dim:
            N *= input_shape[d]

        mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
        diff = super().call_operator(diff_op, (x, mean), {}, meta)
        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
        summed = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
        # NOTE: raises ZeroDivisionError when N <= correction.
        full = super().call_operator(
            full_op,
            ([1 for _ in shape], 1 / max(0, N - correction)),
            {"dtype": dtype},
            meta,
        )
        return super().call_operator(mul_op, (summed, full), {}, meta)

backends/arm/arm_partitioner.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import operator
1010
import os
11-
from typing import final, List
11+
from typing import cast, final, List
1212

1313
import torch
1414
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
@@ -53,6 +53,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5353
exir_ops.edge.aten.full.default,
5454
exir_ops.edge.aten.mul.Tensor,
5555
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
56+
exir_ops.edge.aten.native_layer_norm.default,
5657
exir_ops.edge.aten.avg_pool2d.default,
5758
exir_ops.edge.aten.sigmoid.default,
5859
exir_ops.edge.aten.mm.default,
@@ -67,6 +68,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6768
exir_ops.edge.aten.view_copy.default,
6869
exir_ops.edge.aten.clone.default,
6970
exir_ops.edge.aten.mean.dim,
71+
exir_ops.edge.aten.var.correction,
7072
exir_ops.edge.aten.unsqueeze_copy.default,
7173
exir_ops.edge.aten.squeeze_copy.dims,
7274
operator.getitem,
@@ -85,10 +87,11 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
8587

8688
def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
8789
if node.target == exir_ops.edge.aten.mean.dim:
88-
dim = node.args[1]
89-
keep_dim = node.args[2]
90-
if dim != [-1, -2] or keep_dim is False:
91-
return False
90+
keep_dim = node.args[2] if len(node.args) > 2 else False
91+
return cast(bool, keep_dim)
92+
if node.target == exir_ops.edge.aten.var.correction:
93+
keep_dim = node.kwargs.get("keepdim", False)
94+
return cast(bool, keep_dim)
9295
return True
9396

9497

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
op_get_item,
2121
op_hardtanh,
2222
op_log,
23-
op_mean_dim,
2423
op_mm,
2524
op_mul,
2625
op_permute,

backends/arm/operators/op_full.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,7 @@ def define_node(
5252
dtype = ts.DType.FP32
5353
data = np.full(shape, value, dtype=np.float32)
5454

55-
tosa_graph.addConst(shape, dtype, data, "full-const")
56-
tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, ["full-const"], [output.name])
55+
tosa_graph.addConst(shape, dtype, data, node.name + "full-const")
56+
tosa_graph.addOperator(
57+
ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name]
58+
)

0 commit comments

Comments
 (0)