Skip to content

Commit 69d8b43

Browse files
cccclai authored and facebook-github-bot committed
add duplicate constant node pass (#2570)
Summary: This diff is the follow-up for #2424. In a case like ``` constant_0 (tag_10) ----> op_b (tag_10) |-------------> op_a (tag_11) ``` `op_b` and `op_a` are in two delegated payloads and `constant_0` has two options. In this diff, we're making the default behavior allow copying, meaning it will become ``` constant_0 (tag_10)------------------> op_b (tag_10) constant_0_copy (tag_11) -------------> op_a (tag_11) ``` The backend can tag the node with `no_copy` to alert users in cases where constants are too large, etc. In that case, a better approach can be ``` constant_0 (tag_10) ----> op_b (tag_10) |-----(output constant_0) --------> op_a (tag_11) ``` Reviewed By: angelayi Differential Revision: D55113232
1 parent c8f2d8d commit 69d8b43

File tree

9 files changed

+379
-1
lines changed

9 files changed

+379
-1
lines changed

exir/backend/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ runtime.python_library(
2626
":compile_spec_schema",
2727
"//caffe2:torch",
2828
"//executorch/exir/backend:utils",
29+
"//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
2930
],
3031
)
3132

exir/backend/backend_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
from executorch.exir.backend.compile_spec_schema import CompileSpec
1717

1818
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
19-
from executorch.exir.backend.utils import is_identical_graph
19+
from executorch.exir.backend.utils import (
20+
_maybe_duplicate_constant_nodes,
21+
is_identical_graph,
22+
)
2023

2124
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
2225

@@ -160,6 +163,7 @@ def _get_node_list_with_same_tag(
160163
Return a list of nodes with the same tag.
161164
"""
162165
node_list = []
166+
163167
for node in tagged_graph_module.graph.nodes:
164168
if node.meta.get("delegation_tag", "") == tag:
165169
if node.op == "output":
@@ -373,6 +377,10 @@ def to_backend(
373377
), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"
374378

375379
update_to_real_program(tagged_exported_program, edge_program)
380+
381+
for tag, _ in partitioner_result.partition_tags.items():
382+
_maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
383+
376384
tagged_graph_module = _partition_and_lower(
377385
tagged_exported_program.graph_module, partitioner_result, edge_program
378386
)

exir/backend/canonical_partitioners/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,20 @@ runtime.python_library(
1919
"//executorch/exir/backend:partitioner",
2020
],
2121
)
22+
23+
runtime.python_library(
24+
name = "duplicate_constant_node_pass",
25+
srcs = [
26+
"duplicate_constant_node_pass.py",
27+
],
28+
visibility = [
29+
"//executorch/...",
30+
"//executorch/exir/backend/...",
31+
"//executorch/test/...",
32+
"@EXECUTORCH_CLIENTS",
33+
],
34+
deps = [
35+
"//caffe2:torch",
36+
"//executorch/exir/backend:partitioner",
37+
],
38+
)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
import logging
9+
from typing import Optional, Set
10+
11+
import torch
12+
from torch._export.utils import get_buffer, get_lifted_tensor_constant, get_param
13+
14+
from torch.export import ExportedProgram
15+
from torch.export.exported_program import InputSpec, TensorArgument
16+
from torch.export.graph_signature import InputKind
17+
18+
19+
def _get_attribute_or_constants(
20+
exported_program: ExportedProgram, node: torch.fx.Node
21+
) -> Optional[torch.Tensor]:
22+
# get either attribute node or constant constant
23+
maybe_param = get_param(exported_program, node)
24+
maybe_buffer = get_buffer(exported_program, node)
25+
maybe_lifted_tensor = get_lifted_tensor_constant(exported_program, node)
26+
27+
constant_or_attribute = None
28+
if maybe_param is not None:
29+
constant_or_attribute = maybe_param
30+
elif maybe_buffer is not None:
31+
constant_or_attribute = maybe_buffer
32+
elif maybe_lifted_tensor is not None:
33+
constant_or_attribute = maybe_lifted_tensor
34+
return constant_or_attribute
35+
36+
37+
# TODO: add other passes to duplicate call_function nodes
def duplicate_constant_node(
    exported_program: ExportedProgram, candidate_node: str
) -> Set[str]:
    """
    Duplicate an attribute/constant placeholder (``candidate_node``) so that
    every user past the first gets its own copy. Mostly used for duplicating
    light-weight data; if the data is too large, tag it with "no_copy" instead
    to prevent high memory usage and make it part of the output.

    Args:
        exported_program: the exported program, modified in place. For every
            copy made, a new placeholder is inserted and the backing tensor is
            deep-copied into ``constants`` or ``state_dict`` as appropriate.
        candidate_node: the name of the constant placeholder node to duplicate.

    Returns:
        The set of names of the newly created constant nodes (empty when the
        candidate is not found).
    """
    to_be_copied = [
        node
        for node in exported_program.graph.nodes
        if node.name == candidate_node and node.op == "placeholder"
    ]
    if len(to_be_copied) == 0:
        logging.info("no constant node to be copied")
        return set()
    new_input_specs = []
    old_signature = exported_program.graph_signature
    copied_nodes = set()
    # NOTE: relies on FX's invariant that placeholders precede all other nodes,
    # so `idx` lines up with `old_signature.input_specs` while we are still on
    # placeholder nodes.
    for idx, node in enumerate(exported_program.graph.nodes):
        if node.op != "placeholder":
            continue
        old_input_spec = old_signature.input_specs[idx]
        old_input_spec_copy = copy.deepcopy(old_input_spec)
        if node == to_be_copied[0]:
            constant_or_attribute_node = node
            constant_or_attribute = _get_attribute_or_constants(exported_program, node)
            if constant_or_attribute is None:
                raise RuntimeError(
                    f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
                )
            users = list(node.users.keys())
            # The first user keeps the original node; every later user is
            # rewired to its own freshly inserted copy.
            for ith in range(1, len(node.users)):
                copy_constant_or_attribute_fqn = node.name + f"_copy_{ith - 1}"
                with exported_program.graph.inserting_before(
                    constant_or_attribute_node
                ):
                    copied_constant_or_attribute_node = (
                        exported_program.graph.placeholder(
                            copy_constant_or_attribute_fqn
                        )
                    )
                copied_nodes.add(copy_constant_or_attribute_fqn)
                logging.info(
                    f"Copying constant nodes {node.name} and creating {copy_constant_or_attribute_fqn}"
                )
                for k, v in node.meta.items():
                    copied_constant_or_attribute_node.meta[k] = v
                copied_constant_or_attribute_node.meta["val"] = (
                    constant_or_attribute_node.meta["val"]
                )
                new_args = tuple(
                    (
                        copied_constant_or_attribute_node
                        if arg == constant_or_attribute_node
                        else arg
                    )
                    for arg in users[ith].args
                )
                # BUG FIX: `node.kwargs` is a mapping; iterating it directly
                # yields only keys, so `for key, value in users[ith].kwargs`
                # raised at runtime for any non-empty kwargs. Use .items().
                new_kwargs = {
                    key: (
                        copied_constant_or_attribute_node
                        if value == constant_or_attribute_node
                        else value
                    )
                    for key, value in users[ith].kwargs.items()
                }
                users[ith].args = new_args
                users[ith].kwargs = new_kwargs
                if old_input_spec.kind == InputKind.CONSTANT_TENSOR or (
                    old_input_spec.kind == InputKind.BUFFER
                    and old_input_spec.persistent is False
                ):
                    # Constant tensors and non-persistent buffers both live in
                    # .constants; everything else goes through the state_dict.
                    exported_program.constants[copy_constant_or_attribute_fqn] = (
                        copy.deepcopy(constant_or_attribute)
                    )
                else:
                    exported_program.state_dict[copy_constant_or_attribute_fqn] = (
                        copy.deepcopy(constant_or_attribute)
                    )
                new_input_specs.append(
                    InputSpec(
                        kind=old_input_spec.kind,
                        arg=TensorArgument(name=copy_constant_or_attribute_fqn),
                        target=old_input_spec.target,
                        persistent=old_input_spec.persistent,
                    )
                )
        # Ensure we add the original input spec last, because all the copied
        # nodes are inserted before the candidate node.
        new_input_specs.append(old_input_spec_copy)

    exported_program.graph_signature.input_specs = new_input_specs
    exported_program.graph_module.recompile()
    exported_program._validate()
    return copied_nodes

exir/backend/test/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ python_library(
6767
"//caffe2:torch",
6868
"//executorch/exir:lib",
6969
"//executorch/exir/backend:partitioner",
70+
"//executorch/exir/backend:utils",
7071
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
7172
],
7273
)
@@ -283,6 +284,9 @@ python_unittest(
283284
srcs = [
284285
"test_partitioner.py",
285286
],
287+
preload_deps = [
288+
"//executorch/exir/backend/test/demos/rpc:executor_backend_register",
289+
],
286290
deps = [
287291
"//caffe2:torch",
288292
"//executorch/exir:lib",
@@ -295,6 +299,19 @@ python_unittest(
295299
"//executorch/exir/dialects:lib",
296300
"//executorch/exir/tests:models",
297301
"//executorch/extension/pybindings:portable_lib", # @manual
302+
"//executorch/extension/pytree:pylib",
298303
"//executorch/runtime/executor/test:test_backend_compiler_lib",
299304
],
300305
)
306+
307+
python_unittest(
308+
name = "test_passes",
309+
srcs = [
310+
"test_passes.py",
311+
],
312+
deps = [
313+
"//caffe2:torch",
314+
"//executorch/exir:lib",
315+
"//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
316+
],
317+
)

exir/backend/test/demos/rpc/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def define_common_targets():
3939
srcs = [
4040
"ExecutorBackendRegister.cpp",
4141
],
42+
visibility = [
43+
"//executorch/exir/backend/test/...",
44+
],
4245
deps = [
4346
":executor_backend",
4447
"//executorch/runtime/backend:interface",

exir/backend/test/test_partitioner.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
from executorch.exir.dialects._ops import ops as exir_ops
3232

3333
from executorch.exir.tests.models import MLP
34+
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
35+
_load_for_executorch_from_buffer,
36+
)
37+
from executorch.extension.pytree import tree_flatten
3438
from torch._export import capture_pre_autograd_graph
3539
from torch._export.utils import is_buffer, is_param
3640
from torch.export import export
@@ -446,6 +450,66 @@ def partition(
446450
partition_tags=partition_tags,
447451
)
448452

453+
inputs = (torch.ones(2, 2),)
454+
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
455+
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
456+
exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
457+
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
458+
inputs_flattened, _ = tree_flatten(inputs)
459+
460+
# Send the input from server executor to client executor, and receive the result from client executor
461+
_ = executorch_module.run_method("forward", inputs)
462+
463+
def test_partitioner_alert_split_constant_data(self):
464+
"""
465+
We test that we throw an error when constant data users are split
466+
between different delegated payloads or owning program.
467+
"""
468+
469+
class ReuseConstData(torch.nn.Module):
470+
def __init__(self):
471+
super().__init__()
472+
self.const = torch.ones(2, 2)
473+
474+
def forward(self, x):
475+
y = x + self.const
476+
z = x - self.const
477+
return y, z
478+
479+
class PartitionerTagData(Partitioner):
480+
def __init__(self):
481+
super().__init__()
482+
self.delegation_spec = DelegationSpec(
483+
ExecutorBackend.__name__,
484+
[CompileSpec(key, value) for key, value in self.spec.items()],
485+
)
486+
487+
def partition(
488+
self, edge_exported_program: ExportedProgram
489+
) -> PartitionResult:
490+
partition_tags = {}
491+
for node in edge_exported_program.graph.nodes:
492+
if node.op == "call_function" and node.target in [
493+
exir_ops.edge.aten.add.Tensor
494+
]:
495+
delegation_tag = "tag0"
496+
node.meta["delegation_tag"] = delegation_tag
497+
partition_tags[delegation_tag] = self.delegation_spec
498+
499+
if node.op == "placeholder" and (
500+
is_param(edge_exported_program, node)
501+
or is_buffer(edge_exported_program, node)
502+
):
503+
delegation_tag = "tag0"
504+
node.meta["delegation_tag"] = delegation_tag
505+
node.meta["no_copy"] = True
506+
partition_tags[delegation_tag] = self.delegation_spec
507+
508+
return PartitionResult(
509+
tagged_exported_program=edge_exported_program,
510+
partition_tags=partition_tags,
511+
)
512+
449513
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
450514
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
451515
with self.assertRaises(RuntimeError) as error:

exir/backend/test/test_passes.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch import exir
11+
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
12+
duplicate_constant_node,
13+
)
14+
from torch._export import capture_pre_autograd_graph
15+
from torch._export.utils import is_buffer
16+
from torch.testing import FileCheck
17+
18+
19+
class TestPasses(unittest.TestCase):
    def test_duplicate_constant_node_pass(self):
        """A buffer shared by two ops is duplicated: one copy per extra user."""

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("const", torch.ones(2, 2))

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
        edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),)))

        # Collect the names of all buffer placeholders in the edge program.
        buffer_placeholders = []
        for node in edge.exported_program().graph.nodes:
            if node.op == "placeholder" and is_buffer(edge.exported_program(), node):
                buffer_placeholders.append(node.name)

        # The buffer has two users (add and sub), so exactly one copy is made.
        copied = duplicate_constant_node(
            edge.exported_program(), buffer_placeholders[0]
        )
        self.assertEqual(len(copied), 1)

        # The duplicated constant must show up as a new placeholder in the graph.
        FileCheck().check("arg0_1_copy_0").run(
            edge.exported_program().graph_module.code
        )

0 commit comments

Comments
 (0)