Skip to content

Commit abeff35

Browse files
cccclaifacebook-github-bot
authored andcommitted
add duplicate constant node pass
Summary: This diff is the follow up for #2424 In the case like ``` consant_0 (tag_10) ----> op_b (tag_10) |-------------> op_a (tag_11) ``` `op_b` and `op_a` are in two delegated payload and `constant_0` have two options: In this diff, we're making the default behavior as allowing copying, meaning it will become ``` consant_0 (tag_10)------------------> op_b (tag_10) consant_0_copy (tag_11) -------------> op_a (tag_11) ``` The backend can tag the node with `no_copy` to allert users in cases like constants are too large or etc. In this case, a better approach can be ``` consant_0 (tag_10) ----> op_b (tag_10) |-----(output consant_0) --------> op_a (tag_11) ``` Differential Revision: D55113232
1 parent 1c2ed7b commit abeff35

File tree

9 files changed

+304
-1
lines changed

9 files changed

+304
-1
lines changed

exir/backend/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ runtime.python_library(
2626
":compile_spec_schema",
2727
"//caffe2:torch",
2828
"//executorch/exir/backend:utils",
29+
"//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
2930
],
3031
)
3132

exir/backend/backend_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
from executorch.exir.backend.compile_spec_schema import CompileSpec
1717

1818
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
19-
from executorch.exir.backend.utils import is_identical_graph
19+
from executorch.exir.backend.utils import (
20+
_maybe_duplicate_constant_nodes,
21+
is_identical_graph,
22+
)
2023

2124
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
2225

@@ -160,6 +163,7 @@ def _get_node_list_with_same_tag(
160163
Return a list of nodes with the same tag.
161164
"""
162165
node_list = []
166+
163167
for node in tagged_graph_module.graph.nodes:
164168
if node.meta.get("delegation_tag", "") == tag:
165169
if node.op == "output":
@@ -371,6 +375,10 @@ def to_backend(
371375
), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"
372376

373377
update_to_real_program(tagged_exported_program, edge_program)
378+
379+
for tag, _ in partitioner_result.partition_tags.items():
380+
_maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
381+
374382
tagged_graph_module = _partition_and_lower(
375383
tagged_exported_program.graph_module, partitioner_result, edge_program
376384
)

exir/backend/canonical_partitioners/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,20 @@ runtime.python_library(
1919
"//executorch/exir/backend:partitioner",
2020
],
2121
)
22+
23+
runtime.python_library(
24+
name = "duplicate_constant_node_pass",
25+
srcs = [
26+
"duplicate_constant_node_pass.py",
27+
],
28+
visibility = [
29+
"//executorch/...",
30+
"//executorch/exir/backend/...",
31+
"//executorch/test/...",
32+
"@EXECUTORCH_CLIENTS",
33+
],
34+
deps = [
35+
"//caffe2:torch",
36+
"//executorch/exir/backend:partitioner",
37+
],
38+
)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Set
8+
9+
from torch.export import ExportedProgram
10+
from torch.export.exported_program import InputSpec, TensorArgument
11+
12+
13+
def duplicate_constant_node(
14+
exported_program: ExportedProgram, candidate_node: str
15+
) -> Set[str]:
16+
to_be_copied = [
17+
node for node in exported_program.graph.nodes if node.name == candidate_node
18+
]
19+
new_input_specs = []
20+
old_signature = exported_program.graph_signature
21+
copied_nodes = set()
22+
for idx, node in enumerate(exported_program.graph.nodes):
23+
if node.op != "placeholder":
24+
continue
25+
old_input_spec = old_signature.input_specs[idx]
26+
new_input_specs.append(
27+
InputSpec(
28+
old_input_spec.kind,
29+
old_input_spec.arg,
30+
old_input_spec.target,
31+
persistent=old_input_spec.persistent,
32+
)
33+
)
34+
if node == to_be_copied[0]:
35+
constant_tensor = node
36+
users = list(node.users.keys())
37+
for ith in range(len(node.users) - 1):
38+
copy_constant_tensor_fqn = node.name + f"_copy_{ith}"
39+
with exported_program.graph.inserting_before(constant_tensor):
40+
copied_constant_tensor = exported_program.graph.placeholder(
41+
copy_constant_tensor_fqn
42+
)
43+
copied_nodes.add(copy_constant_tensor_fqn)
44+
for k, v in node.meta.items():
45+
copied_constant_tensor.meta[k] = v
46+
copied_constant_tensor.meta["val"] = constant_tensor.meta["val"]
47+
new_args = tuple(
48+
[
49+
arg if arg != constant_tensor else copied_constant_tensor
50+
for arg in users[ith + 1].args
51+
]
52+
)
53+
new_kwargs = dict(
54+
{
55+
(
56+
key,
57+
(
58+
value
59+
if value != constant_tensor
60+
else copied_constant_tensor
61+
),
62+
)
63+
for key, value in users[ith + 1].kwargs
64+
}
65+
)
66+
users[ith + 1].args = new_args
67+
users[ith + 1].kwargs = new_kwargs
68+
exported_program.state_dict[copy_constant_tensor_fqn] = (
69+
copied_constant_tensor
70+
)
71+
new_input_specs.append(
72+
InputSpec(
73+
kind=old_input_spec.kind,
74+
arg=TensorArgument(name=copy_constant_tensor_fqn),
75+
target=old_input_spec.target,
76+
persistent=old_input_spec.persistent,
77+
)
78+
)
79+
80+
exported_program.graph_signature.input_specs = new_input_specs
81+
exported_program.graph_module.recompile()
82+
83+
return copied_nodes

exir/backend/test/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ python_library(
6767
"//caffe2:torch",
6868
"//executorch/exir:lib",
6969
"//executorch/exir/backend:partitioner",
70+
"//executorch/exir/backend:utils",
7071
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
7172
],
7273
)
@@ -283,6 +284,9 @@ python_unittest(
283284
srcs = [
284285
"test_partitioner.py",
285286
],
287+
preload_deps = [
288+
"//executorch/exir/backend/test/demos/rpc:executor_backend_register",
289+
],
286290
deps = [
287291
"//caffe2:torch",
288292
"//executorch/exir:lib",
@@ -295,6 +299,19 @@ python_unittest(
295299
"//executorch/exir/dialects:lib",
296300
"//executorch/exir/tests:models",
297301
"//executorch/extension/pybindings:portable_lib", # @manual
302+
"//executorch/extension/pytree:pylib",
298303
"//executorch/runtime/executor/test:test_backend_compiler_lib",
299304
],
300305
)
306+
307+
python_unittest(
308+
name = "test_passes",
309+
srcs = [
310+
"test_passes.py",
311+
],
312+
deps = [
313+
"//caffe2:torch",
314+
"//executorch/exir:lib",
315+
"//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
316+
],
317+
)

exir/backend/test/demos/rpc/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def define_common_targets():
3939
srcs = [
4040
"ExecutorBackendRegister.cpp",
4141
],
42+
visibility = [
43+
"//executorch/exir/backend/test/...",
44+
],
4245
deps = [
4346
":executor_backend",
4447
"//executorch/runtime/backend:interface",

exir/backend/test/test_partitioner.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
from executorch.exir.dialects._ops import ops as exir_ops
3232

3333
from executorch.exir.tests.models import MLP
34+
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
35+
_load_for_executorch_from_buffer,
36+
)
37+
from executorch.extension.pytree import tree_flatten
3438
from torch._export import capture_pre_autograd_graph
3539
from torch._export.utils import is_buffer, is_param
3640
from torch.export import export
@@ -446,6 +450,66 @@ def partition(
446450
partition_tags=partition_tags,
447451
)
448452

453+
inputs = (torch.ones(2, 2),)
454+
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
455+
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
456+
exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
457+
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
458+
inputs_flattened, _ = tree_flatten(inputs)
459+
460+
# Send the input from server executor to client executor, and receive the result from client executor
461+
_ = executorch_module.run_method("forward", inputs)
462+
463+
def test_partitioner_alert_split_constant_data(self):
464+
"""
465+
We test that we throw an error when constant data users are split
466+
between different delegated payloads or owning program.
467+
"""
468+
469+
class ReuseConstData(torch.nn.Module):
470+
def __init__(self):
471+
super().__init__()
472+
self.const = torch.ones(2, 2)
473+
474+
def forward(self, x):
475+
y = x + self.const
476+
z = x - self.const
477+
return y, z
478+
479+
class PartitionerTagData(Partitioner):
480+
def __init__(self):
481+
super().__init__()
482+
self.delegation_spec = DelegationSpec(
483+
ExecutorBackend.__name__,
484+
[CompileSpec(key, value) for key, value in self.spec.items()],
485+
)
486+
487+
def partition(
488+
self, edge_exported_program: ExportedProgram
489+
) -> PartitionResult:
490+
partition_tags = {}
491+
for node in edge_exported_program.graph.nodes:
492+
if node.op == "call_function" and node.target in [
493+
exir_ops.edge.aten.add.Tensor
494+
]:
495+
delegation_tag = "tag0"
496+
node.meta["delegation_tag"] = delegation_tag
497+
partition_tags[delegation_tag] = self.delegation_spec
498+
499+
if node.op == "placeholder" and (
500+
is_param(edge_exported_program, node)
501+
or is_buffer(edge_exported_program, node)
502+
):
503+
delegation_tag = "tag0"
504+
node.meta["delegation_tag"] = delegation_tag
505+
node.meta["no_copy"] = True
506+
partition_tags[delegation_tag] = self.delegation_spec
507+
508+
return PartitionResult(
509+
tagged_exported_program=edge_exported_program,
510+
partition_tags=partition_tags,
511+
)
512+
449513
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
450514
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
451515
with self.assertRaises(RuntimeError) as error:

exir/backend/test/test_passes.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
3+
import torch
4+
from executorch import exir
5+
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
6+
duplicate_constant_node,
7+
)
8+
from torch._export import capture_pre_autograd_graph
9+
from torch._export.utils import is_buffer
10+
from torch.testing import FileCheck
11+
12+
13+
class TestPaases(unittest.TestCase):
14+
def test_duplicate_constant_node_pass(self):
15+
16+
class ReuseConstData(torch.nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
self.register_buffer("const", torch.ones(2, 2))
20+
21+
def forward(self, x):
22+
y = x + self.const
23+
z = x - self.const
24+
return y, z
25+
26+
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
27+
edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),)))
28+
29+
const_nodes = [
30+
node.name
31+
for node in edge.exported_program().graph.nodes
32+
if node.op == "placeholder" and is_buffer(edge.exported_program(), node)
33+
]
34+
35+
copied_nodes = duplicate_constant_node(edge.exported_program(), const_nodes[0])
36+
self.assertEqual(len(copied_nodes), 1)
37+
38+
# Check that the new constant node is in the graph
39+
FileCheck().check("arg0_1_copy_0").run(
40+
edge.exported_program().graph_module.code
41+
)

exir/backend/utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
import torch
1414
from executorch.exir.backend.backend_details import ExportedProgram
15+
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
16+
duplicate_constant_node,
17+
)
1518
from executorch.exir.common import setting_python_recursive_limit
1619
from executorch.exir.delegate import executorch_call_delegate
1720
from executorch.exir.dialects._ops import ops as exir_ops
@@ -174,6 +177,72 @@ def replace_quantized_partition_with_op(
174177
return (replaced_op, dequant_nodes, quant_nodes)
175178

176179

180+
def _assign_new_tag(
181+
tagged_exported_program: ExportedProgram,
182+
copied_nodes: Set[str],
183+
):
184+
"""
185+
Assign new tag to the copied nodes.
186+
187+
Before the pass
188+
consant_0 (tag_10) ------------------> op_b (tag_10)
189+
consant_0_copy (tag_10) -------------> op_a (tag_11)
190+
191+
After the pass
192+
consant_0 (tag_10) ------------------> op_b (tag_10)
193+
consant_0_copy (tag_11) -------------> op_a (tag_11)
194+
195+
"""
196+
for node in tagged_exported_program.graph.nodes:
197+
if node.op == "placeholder":
198+
if node.name in copied_nodes:
199+
users_tag = set()
200+
for user in node.users:
201+
users_tag.add(user.meta.get("delegation_tag", None))
202+
# Assign the tag to the copy constant node the same as their users.
203+
if len(users_tag) == 1:
204+
node.meta["delegation_tag"] = users_tag.pop()
205+
206+
207+
def _maybe_duplicate_constant_nodes(
208+
tagged_exported_program: ExportedProgram,
209+
tag: str,
210+
owning_program: ExportedProgram,
211+
) -> None:
212+
"""
213+
If the constants node is shared by different tagged nodes, like
214+
consant_0 ----> op_b (tag_10)
215+
|-------------> op_a (tag_11)
216+
217+
we make default as constant_0 is duplicated to constant_0_1, constant_0_2, unless the node is tagged with "no_copy"
218+
consant_0 ------------------> op_b (tag_10)
219+
consant_0_copy -------------> op_a (tag_11)
220+
221+
backend can estimate how much they want to duplicate the constant node, either error out or default to duplicate
222+
"""
223+
candidate_nodes = set()
224+
for node in tagged_exported_program.graph.nodes:
225+
if node.meta.get("delegation_tag", "") == tag:
226+
if node.op == "placeholder":
227+
for user in node.users:
228+
users_tag = user.meta.get("delegation_tag", None)
229+
if users_tag != tag:
230+
# If the node is tagged with "no_copy", we stop duplicating it and throw an error
231+
if node.meta.get("no_copy", False):
232+
raise RuntimeError(
233+
f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
234+
)
235+
else:
236+
candidate_nodes.add(node.name)
237+
copied_nodes = set()
238+
for candidate_node in candidate_nodes:
239+
# Both tagged exported program and the owning program need to go through the same duplication pass
240+
copied_nodes = duplicate_constant_node(tagged_exported_program, candidate_node)
241+
duplicate_constant_node(owning_program, candidate_node)
242+
243+
_assign_new_tag(tagged_exported_program, copied_nodes)
244+
245+
177246
def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
178247
"""
179248
Check if the node is the getitem followed by executorch_call_delegate node. These getitems node

0 commit comments

Comments
 (0)