Add constraint to not partition standalone batch norm #1501

Closed · wants to merge 1 commit
20 changes: 20 additions & 0 deletions backends/xnnpack/partition/xnnpack_partitioner.py
@@ -22,6 +22,9 @@
UNSUPPORTED_QUANT_MODULES,
)
from executorch.backends.xnnpack.partition.graphs.bilinear_2d import bilinear2d_graphs
from executorch.backends.xnnpack.passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend

@@ -374,6 +377,23 @@ def amax(node: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa
dim_arg_val = cast(int, node.args[1])
return is_keep_dim and (dim_arg_val == 2 or dim_arg_val == 3)

@_constraint(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
def batch_norm(node: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa
"""
Only support batch norms that can be fused with convolutions.
This will be removed once standalone batch norm is supported.
"""

# TODO(gjcomer) Remove after standalone batch norm (T171796544).

conv_node = node.args[0]
assert isinstance(conv_node, torch.fx.Node)

if conv_node.target != exir_ops.edge.aten.convolution.default:
return False

return FuseBatchNormWithConvPass.can_fuse(conv_node, node, ep)


class XnnpackFloatingPointPartitioner(Partitioner):
"""
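For context, the two graph shapes the new constraint distinguishes look like this (a minimal sketch; the module names are illustrative, not from this PR):

```python
import torch


class ConvBN(torch.nn.Module):
    """Batch norm fed directly by a conv: can_fuse can fold the batch
    norm into the conv weights, so the constraint accepts the node."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 2, 3)
        self.bn = torch.nn.BatchNorm2d(2)

    def forward(self, x):
        return self.bn(self.conv(x))


class StandaloneBN(torch.nn.Module):
    """Batch norm whose input is not a conv: node.args[0] fails the
    convolution check, so the node is kept out of the partition."""

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(2)

    def forward(self, x):
        return self.bn(x)
```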
51 changes: 35 additions & 16 deletions backends/xnnpack/passes/fuse_batch_norm_with_conv.py
@@ -11,6 +11,7 @@
from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass

from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult

@@ -21,9 +22,8 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
"""
Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase
memory usage since we serialize new weights to represent the convolution. In most cases,
-    Batch norm is used after convoluution. The 1x1 depthwise convolution can then be fused
+    Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused
with the previous convolution

"""

def call(self, graph_module: torch.fx.GraphModule):
@@ -48,20 +48,7 @@ def call(self, graph_module):
):
continue

-            # All the users of batchnorm node must be getitem ops. batchnorm
-            # returns a 3-element tuple. Each user must only access the first
-            # element of the tuple.
-            if [
-                (user.target == operator.getitem and user.args[1] == 0)
-                for user in bn.users
-            ].count(False):
-                continue
-
-            # Check that the weights for conv and batchnorm are both params
-            if [
-                is_param_node(self.exported_program, node)
-                for node in {conv.args[1], bn.args[1]}
-            ].count(False):
+            if not self.can_fuse(conv, bn, self.exported_program):
continue

# Get the parameters from conv op
@@ -138,3 +125,35 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)

@staticmethod
def can_fuse(
conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
) -> bool:
"""
Determine whether a batch norm node can be fused with a preceding conv node.
"""

# All the users of batchnorm node must be getitem ops. batchnorm
# returns a 3-element tuple. Each user must only access the first
# element of the tuple.
if [
(user.target == operator.getitem and user.args[1] == 0) for user in bn.users
].count(False):
return False

conv_weights = conv.args[1]
bn_weights = bn.args[1]

# Check that the weights for conv and batchnorm are both params
if not isinstance(conv_weights, torch.fx.Node) or not isinstance(
bn_weights, torch.fx.Node
):
return False

if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count(
False
):
return False

return True
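For background on what `can_fuse` is guarding: in inference mode a batch norm is a per-channel affine transform, so it can be folded into the preceding convolution's weights and bias. A minimal standalone sketch of that folding math (standard conv/BN folding, not the pass's exact code):

```python
import torch

# Fold an eval-mode batch norm into the preceding conv:
#   W' = W * gamma / sqrt(var + eps)
#   b' = (b - mean) * gamma / sqrt(var + eps) + beta
torch.manual_seed(0)
conv = torch.nn.Conv2d(2, 2, 3)
bn = torch.nn.BatchNorm2d(2).eval()

# Give the batch norm non-trivial statistics and affine parameters.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)

scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = torch.nn.Conv2d(2, 2, 3)
with torch.no_grad():
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    fused.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(1, 2, 8, 8)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```

The getitem check above exists because the edge batch-norm op returns a 3-element tuple; the rewrite is only safe when every consumer reads element 0, the normalized output.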
24 changes: 24 additions & 0 deletions backends/xnnpack/test/passes/test_batch_norm_fusion.py
@@ -55,3 +55,27 @@ def test_q8_batch_norm_fusion(self):
.run_method()
.compare_outputs()
)

def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
"""
We do not currently support standalone batch norms (i.e. batch norms that are
not fused with a conv). This is planned, but until implemented, this test ensures
that we do not partition the standalone batch norm and then fail to lower.
"""

class BN(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm2d(2)

def forward(self, x):
return self.bn(x)

(
Tester(BN(), (torch.randn(2, 2, 4, 4),))
.export()
.to_edge()
.check_count({self.bn_name: 1})
.partition()
.check_count({self.bn_name: 1})
)
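A complementary positive-path test (hypothetical, not part of this diff) could reuse the same Tester flow to assert that a conv-fed batch norm is still partitioned away; the `check_not` call assumes the harness exposes such a negative check:

```python
def test_fp32_batch_norm_fusion_partitions(self):
    """
    Hypothetical counterpart: a batch norm that directly follows a conv
    should be fused and consumed by the partitioner.
    """

    class ConvBN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 3)
            self.bn = torch.nn.BatchNorm2d(2)

        def forward(self, x):
            return self.bn(self.conv(x))

    (
        Tester(ConvBN(), (torch.randn(2, 2, 4, 4),))
        .export()
        .to_edge()
        .check_count({self.bn_name: 1})
        .partition()
        .check_not([self.bn_name])
    )
```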