[XNNPACK][Partitioner] enable src based partitioner #4795


Merged · 3 commits · Aug 20, 2024

2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -11,6 +11,7 @@
AddmmConfig,
ConvolutionConfig,
LinearConfig,
MMConfig,
)

from executorch.backends.xnnpack.partition.config.generic_node_configs import (
@@ -79,6 +80,7 @@
MaxPool2dConfig,
MeanDimConfig,
MinimumConfig,
MMConfig,
MulConfig,
NegConfig,
PermuteConfig,
153 changes: 135 additions & 18 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -29,6 +29,10 @@
format_target_name,
)
from torch.export import ExportedProgram
from torch.fx.passes.utils.source_matcher_utils import (
get_source_partitions,
SourcePartition,
)


class GEMMConfig(XNNPartitionerConfig):
@@ -52,20 +56,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
# short circuit if we don't pass common constraints
return False

precision = self._detect_precision(node)
if precision not in self.enabled_precision_types:
# detected precision but it is either disabled or not supported
return False

is_valid, _ = self.get_deps(node, ep, precision)
is_valid, _ = self.get_deps(node, ep)
return is_valid

def get_node_and_deps(
self, node: torch.fx.Node, ep: ExportedProgram
) -> List[torch.fx.Node]:
partition = [node]
precision = self._detect_precision(node)
_, deps = self.get_deps(node, ep, precision)
_, deps = self.get_deps(node, ep)
partition.extend(deps)

return partition
@@ -86,13 +84,20 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
return ConfigPrecisionType.STATIC_QUANT

def get_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
self,
node: torch.fx.Node,
ep: ExportedProgram,
) -> Tuple[bool, List[torch.fx.Node]]:
"""
Gets all dependencies for this gemm partition. Returns a tuple of
a bool indicating if the deps are valid and a list of all the
dep nodes
"""
precision = self._detect_precision(node)
if precision not in self.supported_precision_types():
# detected precision but it is either disabled or not supported
return (False, [])

valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -178,7 +183,7 @@ def _get_bias_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
gemm_deps = []
if len(node.all_input_nodes) > 2:
if len(node.all_input_nodes) > 2 and self.bias_idx:
bias_node = get_input_node(node, self.bias_idx)
if bias_node:
if not is_param_node(ep, bias_node):
@@ -251,11 +256,16 @@ def supported_precision_types(self):
]


class AddmmConfig(GEMMConfig):
target_name = "addmm.default"
class ConvolutionConfig(GEMMConfig):
target_name = "convolution.default"

def __init__(self):
super().__init__(weight_idx=2, bias_idx=0, act_idx=1, fused_acts=[])
super().__init__(
weight_idx=1,
bias_idx=2,
act_idx=0,
fused_acts=["relu.default", "hardtanh.default"],
)

def supported_precision_types(self):
return [
@@ -264,19 +274,126 @@ def supported_precision_types(self):
]


class ConvolutionConfig(GEMMConfig):
target_name = "convolution.default"
class AddmmConfig(GEMMConfig):
"""
Handles the legacy form of addmm partitioning, which includes partitioning
addmm nodes that belong to a linear source partition.
"""

target_name = "addmm.default"

def __init__(self):
super().__init__(
weight_idx=1,
bias_idx=2,
act_idx=0,
weight_idx=2,
bias_idx=0,
act_idx=1,
fused_acts=["relu.default", "hardtanh.default"],
)
self.src_partitions = None
self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

def get_deps(
self,
node: torch.fx.Node,
ep: ExportedProgram,
) -> Tuple[bool, List[torch.fx.Node]]:
"""
Gets all dependencies for this gemm partition. Returns a tuple of
a bool indicating if the deps are valid and a list of all the
dep nodes. This override also handles the case where the addmm node
belongs to a linear source partition.
"""
if self.src_partitions is None:
# Cache src partitions so we don't have to recompute them every time
self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)

# src_partition is None if node is not in source partition,
# otherwise gives us the linear source partition it belongs to
src_partition = None
for partition_list in self.src_partitions.values():
for partition in partition_list:
if node in partition.nodes:
src_partition = partition

if src_partition:
# if addmm belongs to linear src partition, then partition the
# src partition and get its deps
return self.get_deps_from_src_partition(node, ep, src_partition)

return super().get_deps(node, ep)

def get_deps_from_src_partition(
self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
):
"""
Gets all the dependencies for the src partition. This is done by simulating the
linear node from the src partition. We find the associated weight, activation, and
bias in the linear src partition and plug those in as the addmm node's args. We also
take the users of the src partition's output node as the addmm node's users. Finally
we run GEMMConfig's get_deps method on this faked linear node. After getting the
deps, we restore the addmm node's original args and users.
"""

def find_partition_args(input_node):
while (
len(input_node.all_input_nodes) != 0
and input_node not in src_partition.input_nodes
):
input_node = input_node.all_input_nodes[0]
return input_node

old_args, old_users = node.args, node.users

fake_args = []
for arg in node.args:
# map addmm's args to the source partition's inputs
# basically simulating what the args of the linear node would be
fake_args.append(find_partition_args(arg))

# validate source partition
if (
# bias must be in source partition
(self.bias_idx and fake_args[self.bias_idx] not in src_partition.nodes)
# activation input must be an input node to partition
or fake_args[self.act_idx] not in src_partition.input_nodes
# weight can either be in the nodes or input_nodes
or fake_args[self.weight_idx]
not in (src_partition.nodes + src_partition.input_nodes)
# there can only be a single output node in partition
or len(src_partition.output_nodes) != 1
):
return (False, [])

# map addmm's args to the source partition linear's inputs and users
node.args = tuple(fake_args)
node.users = src_partition.output_nodes[0].users
valid_deps, deps = super().get_deps(node, ep)

# Reset addmm node back to old args and users
node.args = old_args
node.users = old_users

return valid_deps, list(set(deps) | set(src_partition.nodes))

def supported_precision_types(self):
return [
ConfigPrecisionType.FP32,
ConfigPrecisionType.STATIC_QUANT,
ConfigPrecisionType.DYNAMIC_QUANT,
]
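
For reference, the source-partition machinery AddmmConfig leans on here lives in torch.fx.passes.utils.source_matcher_utils. The sketch below is not part of this diff; it is a minimal illustration (the module M and the printed fields are illustrative only) of the SourcePartition data that get_deps_from_src_partition works with.

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(M(), (torch.randn(2, 8),))
# Keyed by source (torch.nn.Linear / torch.nn.functional.linear); each value is a
# list of SourcePartition objects recovered from the nodes' source_fn_stack metadata.
partitions = get_source_partitions(ep.graph, [torch.nn.functional.linear, torch.nn.Linear])
for source, partition_list in partitions.items():
    for p in partition_list:
        # p.nodes        -> everything the nn.Linear call lowered to (a single linear op,
        #                   or permute + addmm/mm after decomposition)
        # p.input_nodes  -> activation inputs feeding the partition
        # p.output_nodes -> node(s) whose users live outside the partition
        print(source, [n.name for n in p.nodes], p.input_nodes, p.output_nodes)

get_deps_from_src_partition above essentially swaps the addmm node's args and users for the corresponding entries of such a partition, runs the regular dependency check, then restores them.
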


class MMConfig(AddmmConfig):
target_name = "mm.default"

def __init__(self):
super().__init__()
self.bias_idx = None
self.weight_idx = 1
self.act_idx = 0

def supported_precision_types(self):
return [
ConfigPrecisionType.FP32,
ConfigPrecisionType.STATIC_QUANT,
ConfigPrecisionType.DYNAMIC_QUANT,
]
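
The index choices in AddmmConfig and MMConfig mirror the operator signatures: addmm(bias, act, weight) carries the bias as its first argument, while mm(act, weight) takes no bias at all, which is why MMConfig clears bias_idx. A quick check of the two signatures in plain PyTorch (not part of this diff):

import torch

act, weight_t, bias = torch.randn(2, 8), torch.randn(8, 4), torch.randn(4)

# addmm(bias, act, weight): bias_idx=0, act_idx=1, weight_idx=2 as in AddmmConfig
assert torch.allclose(torch.addmm(bias, act, weight_t), act @ weight_t + bias)

# mm(act, weight): act_idx=0, weight_idx=1 and no bias argument, as in MMConfig
assert torch.allclose(torch.mm(act, weight_t), act @ weight_t)
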
11 changes: 11 additions & 0 deletions backends/xnnpack/partition/config/node_configs.py
@@ -16,6 +16,9 @@
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
format_target_name,
)
from torch.export import ExportedProgram


@@ -29,6 +32,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
bn = node
conv = node.all_input_nodes[0]

if conv.op != "call_function":
return False

conv_name = format_target_name(conv.target.__name__) # pyre-ignore

if conv_name not in ["convolution.default"]:
return False

return FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)

def get_node_and_deps(
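
The new guard in BatchNormConfig only admits a batch norm whose producer is a call_function node that formats to convolution.default; a batch norm fed by anything else (a graph placeholder, a linear, etc.) is rejected before the fusion check runs. A minimal sketch of the two cases this separates (module names are illustrative, not from the PR):

import torch

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        # bn's input node is a convolution -> candidate for conv + bn fusion
        return self.bn(self.conv(x))

class BNOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        # bn's input node is the graph placeholder -> rejected by the new check
        return self.bn(x)
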
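
For end-to-end context, these configs are consulted whenever a model is lowered to XNNPACK. A hedged sketch of that flow, assuming the ExecuTorch entry points available around the time of this PR (torch.export.export, to_edge_transform_and_lower, XnnpackPartitioner); exact APIs may differ across releases:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Linear lowers to linear / addmm / mm depending on decomposition,
        # exercising the source-partition-aware AddmmConfig and the new MMConfig
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(TinyLinear(), (torch.randn(2, 8),))
# The partitioner walks the registered configs (see __init__.py above) to decide
# which nodes to delegate to XNNPACK.
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()
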