Commit 4a27a53

[XNNPACK][Partitioner] enable src based partitioner (#4795)
In order to maintain parity with the current to_edge and to_backend lowering flow, we need to support source-based partitioning. We apply source-based partitioning in AddmmConfig to partition all the nodes surrounding addmm so that it can be recomposed internally. While this is fine for the to_edge and to_backend flow, the more robust to_edge_transform_and_lower flow will not need it.

Co-authored-by: Max Ren <[email protected]>

Pull Request resolved: #4762
1 parent 48d664c commit 4a27a53
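
For context, a minimal sketch of the two flows the message contrasts. The toy model, its inputs, and the exact entry-point signatures are illustrative assumptions and may differ across ExecuTorch versions.

```python
# Minimal sketch (assumed entry points; exact signatures may vary by version).
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge, to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 32)

    def forward(self, x):
        return self.linear(x)


model, inputs = TinyLinear().eval(), (torch.randn(1, 16),)

# Legacy flow: to_edge decomposes nn.Linear into addmm plus surrounding nodes,
# so AddmmConfig must use source-based partitioning to recompose them.
legacy = to_edge(export(model, inputs)).to_backend(XnnpackPartitioner())

# Newer flow: to_edge_transform_and_lower partitions before decomposing the ops
# the partitioner claims, so the source-partition path is not needed there.
lowered = to_edge_transform_and_lower(
    export(model, inputs), partitioner=[XnnpackPartitioner()]
)
```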

5 files changed (+199 −79 lines)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,7 @@
     AddmmConfig,
     ConvolutionConfig,
     LinearConfig,
+    MMConfig,
 )
 
 from executorch.backends.xnnpack.partition.config.generic_node_configs import (
@@ -79,6 +80,7 @@
     MaxPool2dConfig,
     MeanDimConfig,
     MinimumConfig,
+    MMConfig,
     MulConfig,
     NegConfig,
     PermuteConfig,
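
As a rough illustration of why MMConfig is registered alongside AddmmConfig: after decomposition, a linear without a bias typically lowers to mm rather than addmm. The sketch below is under that assumption; the exact decomposition depends on the torch version and decomposition table.

```python
# Sketch: see which GEMM op a linear decomposes into with and without a bias.
import torch
from torch.export import export


class Lin(torch.nn.Module):  # hypothetical toy module
    def __init__(self, bias: bool):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4, bias=bias)

    def forward(self, x):
        return self.linear(x)


for bias in (True, False):
    ep = export(Lin(bias), (torch.randn(2, 8),)).run_decompositions()
    gemms = [
        str(n.target)
        for n in ep.graph.nodes
        if n.op == "call_function" and "mm" in str(n.target)
    ]
    # Expect addmm.default when bias=True and mm.default when bias=False.
    print(f"bias={bias}: {gemms}")
```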

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 135 additions & 18 deletions

@@ -29,6 +29,10 @@
     format_target_name,
 )
 from torch.export import ExportedProgram
+from torch.fx.passes.utils.source_matcher_utils import (
+    get_source_partitions,
+    SourcePartition,
+)
 
 
 class GEMMConfig(XNNPartitionerConfig):
@@ -52,20 +56,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             # short circuit if we don't pass common constraints
             return False
 
-        precision = self._detect_precision(node)
-        if precision not in self.enabled_precision_types:
-            # detected precision but it is either disabled or not supported
-            return False
-
-        is_valid, _ = self.get_deps(node, ep, precision)
+        is_valid, _ = self.get_deps(node, ep)
         return is_valid
 
     def get_node_and_deps(
         self, node: torch.fx.Node, ep: ExportedProgram
     ) -> List[torch.fx.Node]:
         partition = [node]
-        precision = self._detect_precision(node)
-        _, deps = self.get_deps(node, ep, precision)
+        _, deps = self.get_deps(node, ep)
         partition.extend(deps)
 
         return partition
@@ -86,13 +84,20 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
         return ConfigPrecisionType.STATIC_QUANT
 
     def get_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+        self,
+        node: torch.fx.Node,
+        ep: ExportedProgram,
     ) -> Tuple[bool, List[torch.fx.Node]]:
         """
         Gets all dependencies for this gemm partition. Returns a tuple of
         a bool indicating if the deps are valid and a list of all the
         dep nodes
         """
+        precision = self._detect_precision(node)
+        if precision not in self.supported_precision_types():
+            # detected precision but it is either disabled or not supported
+            return (False, [])
+
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -178,7 +183,7 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if len(node.all_input_nodes) > 2:
+        if len(node.all_input_nodes) > 2 and self.bias_idx:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
@@ -251,11 +256,16 @@ def supported_precision_types(self):
         ]
 
 
-class AddmmConfig(GEMMConfig):
-    target_name = "addmm.default"
+class ConvolutionConfig(GEMMConfig):
+    target_name = "convolution.default"
 
     def __init__(self):
-        super().__init__(weight_idx=2, bias_idx=0, act_idx=1, fused_acts=[])
+        super().__init__(
+            weight_idx=1,
+            bias_idx=2,
+            act_idx=0,
+            fused_acts=["relu.default", "hardtanh.default"],
+        )
 
     def supported_precision_types(self):
         return [
@@ -264,19 +274,126 @@ def supported_precision_types(self):
         ]
 
 
-class ConvolutionConfig(GEMMConfig):
-    target_name = "convolution.default"
+class AddmmConfig(GEMMConfig):
+    """
+    We will handle the legacy form of addmm partitioning, which includes
+    partitioning using source partitions.
+    """
+
+    target_name = "addmm.default"
 
     def __init__(self):
         super().__init__(
-            weight_idx=1,
-            bias_idx=2,
-            act_idx=0,
+            weight_idx=2,
+            bias_idx=0,
+            act_idx=1,
             fused_acts=["relu.default", "hardtanh.default"],
         )
+        self.src_partitions = None
+        self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
+
+    def get_deps(
+        self,
+        node: torch.fx.Node,
+        ep: ExportedProgram,
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        """
+        Gets all dependencies for this gemm partition. Returns a tuple of
+        a bool indicating if the deps are valid and a list of all the
+        dep nodes. This also handles the source-partition case for addmm.
+        """
+        if self.src_partitions is None:
+            # Cache src partitions so we don't have to recompute them every time
+            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
+
+        # src_partition is None if node is not in source partition,
+        # otherwise gives us the linear source partition it belongs to
+        src_partition = None
+        for partition_list in self.src_partitions.values():
+            for partition in partition_list:
+                if node in partition.nodes:
+                    src_partition = partition
+
+        if src_partition:
+            # if addmm belongs to linear src partition, then partition the
+            # src partition and get its deps
+            return self.get_deps_from_src_partition(node, ep, src_partition)
+
+        return super().get_deps(node, ep)
+
+    def get_deps_from_src_partition(
+        self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
+    ):
+        """
+        Gets all the dependencies for the src partition. This is done by simulating the
+        linear node from the src partition. We find the associated weights, act, bias
+        from the linear src partition, and plug those in as the addmm node's args. We also
+        take the users of the src partition's output node as the addmm node's users. Finally
+        we just run the GEMMConfig's get_deps method on this faked linear node. After
+        getting the deps, we put the addmm node's original users and args back.
+        """
+
+        def find_partition_args(input_node):
+            while (
+                len(input_node.all_input_nodes) != 0
+                and input_node not in src_partition.input_nodes
+            ):
+                input_node = input_node.all_input_nodes[0]
+            return input_node
+
+        old_args, old_users = node.args, node.users
+
+        fake_args = []
+        for arg in node.args:
+            # map addmm's args to the source partition's inputs
+            # basically simulating what the args of the linear node would be
+            fake_args.append(find_partition_args(arg))
+
+        # validate source partition
+        if (
+            # bias must be in source partition
+            (self.bias_idx and fake_args[self.bias_idx] not in src_partition.nodes)
+            # activation input must be an input node to partition
+            or fake_args[self.act_idx] not in src_partition.input_nodes
+            # weight can either be in the nodes or input_nodes
+            or fake_args[self.weight_idx]
+            not in (src_partition.nodes + src_partition.input_nodes)
+            # there can only be a single output node in partition
+            or len(src_partition.output_nodes) != 1
+        ):
+            return (False, [])
+
+        # map addmm's args to the source partition linear's inputs and users
+        node.args = tuple(fake_args)
+        node.users = src_partition.output_nodes[0].users
+        valid_deps, deps = super().get_deps(node, ep)
+
+        # Reset addmm node back to old args and users
+        node.args = old_args
+        node.users = old_users
+
+        return valid_deps, list(set(deps) | set(src_partition.nodes))
 
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
             ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
+        ]
+
+
+class MMConfig(AddmmConfig):
+    target_name = "mm.default"
+
+    def __init__(self):
+        super().__init__()
+        self.bias_idx = None
+        self.weight_idx = 1
+        self.act_idx = 0
+
+    def supported_precision_types(self):
+        return [
+            ConfigPrecisionType.FP32,
+            ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
         ]
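
To make the source-partition lookup in AddmmConfig.get_deps concrete, here is a small standalone sketch using the same torch.fx utility on a toy exported linear; the model and printing are illustrative only.

```python
# Sketch of the lookup AddmmConfig.get_deps performs on an exported graph.
import torch
from torch.export import export
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class TinyLinear(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


ep = export(TinyLinear(), (torch.randn(2, 8),)).run_decompositions()

# Group nodes by the nn module / functional they were traced from.
src_partitions = get_source_partitions(
    ep.graph, [torch.nn.functional.linear, torch.nn.Linear]
)

for partition_list in src_partitions.values():
    for partition in partition_list:
        # Each SourcePartition records the nodes that came from one linear call,
        # plus its external input nodes and its output node(s) -- exactly what
        # get_deps_from_src_partition uses to fake the linear node's args/users.
        print("nodes:  ", [n.name for n in partition.nodes])
        print("inputs: ", [n.name for n in partition.input_nodes])
        print("outputs:", [n.name for n in partition.output_nodes])
```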

backends/xnnpack/partition/config/node_configs.py

Lines changed: 11 additions & 0 deletions

@@ -16,6 +16,9 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack.utils.utils import is_param_node
+from executorch.exir.backend.canonical_partitioners.config_partitioner import (
+    format_target_name,
+)
 from torch.export import ExportedProgram
 
 
@@ -29,6 +32,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         bn = node
         conv = node.all_input_nodes[0]
 
+        if conv.op != "call_function":
+            return False
+
+        conv_name = format_target_name(conv.target.__name__)  # pyre-ignore
+
+        if conv_name not in ["convolution.default"]:
+            return False
+
         return FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
 
     def get_node_and_deps(
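
The added check simply refuses to partition a batch norm whose producer is not a convolution call_function. Below is a standalone sketch of the same guard, using a plain string check in place of format_target_name, whose normalization of target names is assumed here.

```python
# Sketch of the producer guard added above. format_target_name is assumed to
# normalize targets such as "aten.convolution.default" to "convolution.default";
# a string suffix check stands in for it here.
import torch


def producer_is_convolution(bn_node: torch.fx.Node) -> bool:
    if not bn_node.all_input_nodes:
        return False
    conv = bn_node.all_input_nodes[0]

    # Only call_function producers carry a target we can inspect by name;
    # placeholder or get_attr producers should never be fused.
    if conv.op != "call_function":
        return False

    target_name = getattr(conv.target, "__name__", "")
    return target_name.endswith("convolution.default")
```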
