
Commit d1913ae

Create NeutronAtenPassManager with initial BatchNorm fusing passes
1 parent 12ed924 commit d1913ae

4 files changed: +546 -0 lines changed
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights


class FuseBatchNormWithConvPass(PassBase):
    """The executorch batch normalization carries out the following computation [1].

        (x - mean) / sqrt(var + eps) * W + B

    Which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
    and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and
    bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be
    completely removed.


                      ┌─────────────▼─────────────┐
                      │ aten.conv1d | aten.conv2d │
                      └─────────────┬─────────────┘
                                    │                                                        │
              ┌─────────────────────▼─────────────────────┐    replace with    ┌─────────────▼─────────────┐
              │              aten.batch_norm              │  ──────────────►   │ aten.conv1d | aten.conv2d │
              └─────────────────────┬─────────────────────┘                    └─────────────┬─────────────┘
                                    │                                                        ▼

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_conv(node_: Node):
            return node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_conv(bn_node.args[0]):
                continue  # Something other than a Conv node comes before the BatchNorm.

            conv_node = bn_node.args[0]
            conv_weight_node = conv_node.args[1]
            conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None

            # conv args: input, weight, bias, stride, padding, dilation, ...
            conv_w = self._get_tensor_constant_from_node(graph_module, conv_weight_node)
            conv_b = self._get_tensor_constant_from_node(graph_module, conv_bias_node)

            # batch norm args: input, weight, bias, running mean, running var, training, momentum, eps
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (conv_w, bn_rm, bn_rv)
            ):  # The other inputs can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).

            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Conv.
            conv_args = list(conv_node.args)
            if len(conv_args) == 2:
                # Fill in the default bias argument.
                conv_args.append(None)

            weight_attr_name = conv_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

            if conv_bias_node is not None:
                bias_attr_name = conv_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Conv doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(conv_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                conv_args[2] = get_bias_node

            conv_node.args = tuple(conv_args)

            # Replace the uses of the BatchNorm with the Conv.
            bn_node.replace_all_uses_with(conv_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
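
The fusion relies on torch.nn.utils.fuse_conv_bn_weights to fold the batch-norm scale and shift into the convolution parameters. Below is a minimal numerical sketch of that identity (not part of the commit), using only public PyTorch APIs on small random tensors:

    # Sketch, not part of the commit: check the conv / batch-norm folding identity
    # that the pass relies on.
    import torch
    from torch.nn.utils import fuse_conv_bn_weights

    torch.manual_seed(0)
    x = torch.randn(1, 3, 8, 8)

    conv_w = torch.randn(4, 3, 3, 3)
    conv_b = torch.randn(4)
    bn_w, bn_b = torch.randn(4), torch.randn(4)
    bn_rm = torch.randn(4)       # running mean
    bn_rv = torch.rand(4) + 0.1  # running var (kept positive)
    bn_eps = 1e-5

    # Reference: Conv2d followed by inference-mode batch norm.
    y_ref = torch.nn.functional.batch_norm(
        torch.nn.functional.conv2d(x, conv_w, conv_b, padding=1),
        bn_rm, bn_rv, bn_w, bn_b, training=False, eps=bn_eps,
    )

    # Fold the batch-norm scale and shift into the conv parameters, as the pass does.
    fused_w, fused_b = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)
    y_fused = torch.nn.functional.conv2d(x, fused_w, fused_b, padding=1)

    assert torch.allclose(y_ref, y_fused, atol=1e-5)

This equivalence is what allows the pass to rewrite the Conv weights and then drop the aten.batch_norm node entirely.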
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_linear_bn_weights


class FuseBatchNormWithLinearPass(PassBase):
    """The executorch batch normalization carries out the following computation [1].

        (x - mean) / sqrt(var + eps) * W + B

    Which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
    and the terms can be precomputed. If there is a `Linear` operator before the batch normalization, this scale
    and bias can be statically integrated into the weights and bias of the `Linear`, which allows the batch norm
    to be completely removed.


                             ┌──────▼──────┐
                             │ aten.linear │
                             └──────┬──────┘
                                    │                                                 │
              ┌─────────────────────▼─────────────────────┐    replace with    ┌──────▼──────┐
              │              aten.batch_norm              │  ──────────────►   │ aten.linear │
              └─────────────────────┬─────────────────────┘                    └──────┬──────┘

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_linear(node_: Node):
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.linear.default
            )

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_linear(bn_node.args[0]):
                continue  # Something other than a Linear node comes before the BatchNorm.

            linear_node = bn_node.args[0]
            linear_weight_node = linear_node.args[1]
            linear_bias_node = (
                linear_node.args[2] if len(linear_node.args) > 2 else None
            )

            linear_w = self._get_tensor_constant_from_node(
                graph_module, linear_weight_node
            )
            linear_b = self._get_tensor_constant_from_node(
                graph_module, linear_bias_node
            )

            # batch norm args: input, weight, bias, running mean, running var, training, momentum, eps
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (linear_w, bn_w, bn_b, bn_rm, bn_rv)
            ):  # The Linear bias can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).

            fused_weight, fused_bias = fuse_linear_bn_weights(
                linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Linear.
            linear_args = list(linear_node.args)
            if len(linear_args) == 2:
                # Fill in the default bias argument.
                linear_args.append(None)

            weight_attr_name = linear_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

            if linear_bias_node is not None:
                bias_attr_name = linear_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Linear doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(linear_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                linear_args[2] = get_bias_node

            linear_node.args = tuple(linear_args)

            # Replace the uses of the BatchNorm with the Linear.
            bn_node.replace_all_uses_with(linear_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
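
The Linear variant uses the same algebra via torch.nn.utils.fuse_linear_bn_weights; a corresponding minimal sketch (again, not part of the commit):

    # Sketch, not part of the commit: the analogous check for Linear + BatchNorm1d.
    import torch
    from torch.nn.utils import fuse_linear_bn_weights

    torch.manual_seed(0)
    x = torch.randn(2, 6)

    lin_w, lin_b = torch.randn(4, 6), torch.randn(4)
    bn_w, bn_b = torch.randn(4), torch.randn(4)
    bn_rm, bn_rv, bn_eps = torch.randn(4), torch.rand(4) + 0.1, 1e-5

    # Reference: Linear followed by inference-mode batch norm.
    y_ref = torch.nn.functional.batch_norm(
        torch.nn.functional.linear(x, lin_w, lin_b),
        bn_rm, bn_rv, bn_w, bn_b, training=False, eps=bn_eps,
    )

    # Fold the batch-norm scale and shift into the linear parameters, as the pass does.
    fused_w, fused_b = fuse_linear_bn_weights(lin_w, lin_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)
    y_fused = torch.nn.functional.linear(x, fused_w, fused_b)

    assert torch.allclose(y_ref, y_fused, atol=1e-5)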
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch

from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.exir.pass_manager import PassManager
from torch import nn
from torch.fx.passes.infra.pass_base import PassResult

PassType = list[type[Callable[[torch.fx.GraphModule], PassResult]]]


class NeutronAtenPassManager(PassManager):

    def __init__(self, passes: list[PassType] = None):
        passes: list[PassType] = passes or [
            FuseBatchNormWithConvPass(),
            FuseBatchNormWithLinearPass(),
        ]

        super().__init__(passes)

    def __call__(self, module: nn.Module) -> PassResult:
        pass_result: PassResult = super().__call__(module)

        graph_module = pass_result.graph_module
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        return pass_result
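
A possible end-to-end usage sketch for the new pass manager (not shown in this commit). It assumes the file is importable as executorch.backends.nxp.aten_passes.neutron_aten_pass_manager (hypothetical path for this new module) and that the model is exported pre-dispatch via torch.export.export_for_training, so Conv/Linear + BatchNorm still appear as the aten.conv2d / aten.linear and aten.batch_norm nodes these passes match:

    # Sketch, not shown in this commit. Assumptions: hypothetical module path for the
    # new pass manager, and a pre-dispatch export that preserves aten.batch_norm nodes.
    import torch
    from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
        NeutronAtenPassManager,
    )


    class ConvBN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))


    model = ConvBN().eval()
    example_inputs = (torch.randn(1, 3, 16, 16),)

    # Pre-dispatch export keeps the aten.batch_norm node the fusion passes look for.
    exported = torch.export.export_for_training(model, example_inputs)
    graph_module = exported.module()

    result = NeutronAtenPassManager()(graph_module)  # returns a PassResult
    fused = result.graph_module
    fused.graph.print_tabular()  # batch_norm should be folded into the conv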
