
Commit 30732fe

GregoryComer authored and facebook-github-bot committed
Add constraint to not partition standalone batch norm (#1501)
Summary:
Pull Request resolved: #1501

The XNNPACK backend does not currently support lowering standalone (non-fused) batch norms. Support is planned for the near future, but in the meantime, models with standalone batch norms fail to lower: the op is partitioned but cannot be lowered. This change adds an op-level constraint for batch norm to the XNNPACK partitioner so that only batch norms that can be fused are partitioned. The constraint will be relaxed once standalone batch norm is fully supported.

Reviewed By: mcr229

Differential Revision: D52491544

fbshipit-source-id: 861744d836b0cbfc07700bc411e9677ba80367df
1 parent 5318baa commit 30732fe
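
To make the mechanism concrete, here is a minimal sketch of an op-level constraint registry in the spirit of the partitioner's _constraint decorator (seen in the first diff below): a predicate is registered per op target, and a node is only partitioned if every predicate for its target passes. All names here (op_constraint, _OP_CONSTRAINTS, node_is_partitionable) are hypothetical, not the actual partitioner API.

# Hypothetical sketch of an op-level constraint registry; the real
# partitioner's _constraint decorator works in this spirit.
from typing import Callable, Dict, List

import torch

ConstraintFn = Callable[[torch.fx.Node], bool]
_OP_CONSTRAINTS: Dict[object, List[ConstraintFn]] = {}

def op_constraint(target: object):
    """Register a predicate that must pass before `target` nodes are partitioned."""
    def register(fn: ConstraintFn) -> ConstraintFn:
        _OP_CONSTRAINTS.setdefault(target, []).append(fn)
        return fn
    return register

def node_is_partitionable(node: torch.fx.Node) -> bool:
    # Nodes with no registered constraints are partitionable by default;
    # otherwise every constraint for the node's target must hold.
    return all(fn(node) for fn in _OP_CONSTRAINTS.get(node.target, []))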

3 files changed: +79 −16 lines

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 20 additions & 0 deletions
@@ -22,6 +22,9 @@
     UNSUPPORTED_QUANT_MODULES,
 )
 from executorch.backends.xnnpack.partition.graphs.bilinear_2d import bilinear2d_graphs
+from executorch.backends.xnnpack.passes.fuse_batch_norm_with_conv import (
+    FuseBatchNormWithConvPass,
+)
 from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
 from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend

@@ -374,6 +377,23 @@ def amax(node: torch.fx.Node, ep: ExportedProgram) -> bool:  # noqa
         dim_arg_val = cast(int, node.args[1])
         return is_keep_dim and (dim_arg_val == 2 or dim_arg_val == 3)

+    @_constraint(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
+    def batch_norm(node: torch.fx.Node, ep: ExportedProgram) -> bool:  # noqa
+        """
+        Only support batch norms that can be fused with convolutions.
+        This will be removed once standalone batch norm is supported.
+        """
+
+        # TODO(gjcomer) Remove after standalone batch norm (T171796544).
+
+        conv_node = node.args[0]
+        assert isinstance(conv_node, torch.fx.Node)
+
+        if conv_node.target != exir_ops.edge.aten.convolution.default:
+            return False
+
+        return FuseBatchNormWithConvPass.can_fuse(conv_node, node, ep)
+

 class XnnpackFloatingPointPartitioner(Partitioner):
     """

backends/xnnpack/passes/fuse_batch_norm_with_conv.py

Lines changed: 35 additions & 16 deletions
@@ -11,6 +11,7 @@
 from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass

 from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult

@@ -21,9 +22,8 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
     """
     Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase
     memory usage since we serialize new weights to represent the convolution. In most cases,
-    Batch norm is used after convoluution. The 1x1 depthwise convolution can then be fused
+    Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused
     with the previous convolution
-
     """

     def call(self, graph_module: torch.fx.GraphModule):

@@ -48,20 +48,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             ):
                 continue

-            # All the users of batchnorm node must be getitem ops. batchnorm
-            # returns a 3-element tuple. Each user must only access the first
-            # element of the tuple.
-            if [
-                (user.target == operator.getitem and user.args[1] == 0)
-                for user in bn.users
-            ].count(False):
-                continue
-
-            # Check that the weights for conv and batchnorm are both params
-            if [
-                is_param_node(self.exported_program, node)
-                for node in {conv.args[1], bn.args[1]}
-            ].count(False):
+            if not self.can_fuse(conv, bn, self.exported_program):
                 continue

             # Get the parameters from conv op

@@ -138,3 +125,35 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph_module = super().call(graph_module).graph_module

         return PassResult(graph_module, True)
+
+    @staticmethod
+    def can_fuse(
+        conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
+    ) -> bool:
+        """
+        Determine whether a batch norm node can be fused with a preceding conv node.
+        """
+
+        # All the users of batchnorm node must be getitem ops. batchnorm
+        # returns a 3-element tuple. Each user must only access the first
+        # element of the tuple.
+        if [
+            (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
+        ].count(False):
+            return False
+
+        conv_weights = conv.args[1]
+        bn_weights = bn.args[1]
+
+        # Check that the weights for conv and batchnorm are both params
+        if not isinstance(conv_weights, torch.fx.Node) or not isinstance(
+            bn_weights, torch.fx.Node
+        ):
+            return False
+
+        if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count(
+            False
+        ):
+            return False
+
+        return True
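
The fusion itself relies on the standard conv + batch-norm folding identity: inference-mode batch norm is an affine, per-channel transform, so it can be absorbed into the preceding convolution's weight and bias. A minimal sketch of that arithmetic (not this pass's actual implementation):

import torch

def fold_bn_into_conv(W, b, gamma, beta, running_mean, running_var, eps=1e-5):
    # Inference-mode BN computes y = (x - mean) / sqrt(var + eps) * gamma + beta,
    # a per-channel scale and shift, so it folds into the conv parameters:
    #   W' = W * scale (per output channel),  b' = (b - mean) * scale + beta
    scale = gamma / torch.sqrt(running_var + eps)
    W_fused = W * scale.reshape(-1, 1, 1, 1)  # scale each output filter
    b_fused = (b - running_mean) * scale + beta
    return W_fused, b_fused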

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 24 additions & 0 deletions
@@ -55,3 +55,27 @@ def test_q8_batch_norm_fusion(self):
             .run_method()
             .compare_outputs()
         )
+
+    def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
+        """
+        We do not currently support standalone batch norms (i.e. batch norms that are
+        not fused with a conv). This is planned, but until implemented, this test ensures
+        that we do not partition the standalone batch norm and then fail to lower.
+        """
+
+        class BN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bn = torch.nn.BatchNorm2d(2)
+
+            def forward(self, x):
+                return self.bn(x)
+
+        (
+            Tester(BN(), (torch.randn(2, 2, 4, 4),))
+            .export()
+            .to_edge()
+            .check_count({self.bn_name: 1})
+            .partition()
+            .check_count({self.bn_name: 1})
+        )
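
For contrast, a hedged sketch of the complementary fused case, written against the same Tester fluent API used above: with a conv feeding the batch norm, the pair should be fused and delegated, leaving no batch norm node after partitioning. This companion test is not part of the commit, and the check_not stage is an assumption about the Tester API:

    def test_fp32_conv_batch_norm_partitions(self):  # hypothetical companion test
        class ConvBN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, kernel_size=3)
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                return self.bn(self.conv(x))

        (
            Tester(ConvBN(), (torch.randn(2, 2, 4, 4),))
            .export()
            .to_edge()
            .check_count({self.bn_name: 1})
            .partition()
            .check_not([self.bn_name])  # assumed stage; BN was fused and delegated
        )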
