
Commit 76b2170

cccclai authored and facebook-github-bot committed
add duplicate constant node pass (#2570)
Summary: This diff is the follow-up for #2424. In a case like

```
constant_0 (tag_10) ----> op_b (tag_10)
          |-------------> op_a (tag_11)
```

`op_b` and `op_a` land in two different delegated payloads, and `constant_0` has two options: it can be copied into each payload, or it can stay with one payload and be routed to the other as an output. In this diff, we make copying the default behavior, meaning the graph becomes

```
constant_0 (tag_10) ------------------> op_b (tag_10)
constant_0_copy (tag_11) -------------> op_a (tag_11)
```

The backend can tag the node with `no_copy` to alert users when copying is undesirable, for example when the constants are too large. In that case, a better approach can be

```
constant_0 (tag_10) ----> op_b (tag_10)
          |-----(output constant_0) --------> op_a (tag_11)
```

Reviewed By: kirklandsign

Differential Revision: D55113232
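The opt-out is driven by node metadata. Below is a minimal sketch, distilled from the `test_partitioner_alert_split_constant_data` test added in this commit, of how a partitioner tags a constant placeholder with `no_copy`; the helper name `tag_constants_no_copy` is hypothetical and only illustrates the tagging step:

```python
# Hypothetical helper illustrating the "no_copy" opt-out; the metadata keys
# ("delegation_tag", "no_copy") are the ones used by this commit's tests.
from torch._export.utils import is_buffer, is_param
from torch.export import ExportedProgram


def tag_constants_no_copy(
    edge_exported_program: ExportedProgram, delegation_spec, partition_tags: dict
) -> None:
    for node in edge_exported_program.graph.nodes:
        if node.op == "placeholder" and (
            is_param(edge_exported_program, node)
            or is_buffer(edge_exported_program, node)
        ):
            delegation_tag = "tag0"
            node.meta["delegation_tag"] = delegation_tag
            # Ask to_backend to raise instead of duplicating this constant,
            # e.g. because it is too large to hold twice in memory.
            node.meta["no_copy"] = True
            partition_tags[delegation_tag] = delegation_spec
```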
1 parent ec6b88a commit 76b2170

File tree

9 files changed: +331 -1 lines changed

exir/backend/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ runtime.python_library(
         ":compile_spec_schema",
         "//caffe2:torch",
         "//executorch/exir/backend:utils",
+        "//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
     ],
 )

```

exir/backend/backend_api.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -16,7 +16,10 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec

 from executorch.exir.backend.partitioner import Partitioner, PartitionResult
-from executorch.exir.backend.utils import is_identical_graph
+from executorch.exir.backend.utils import (
+    _maybe_duplicate_constant_nodes,
+    is_identical_graph,
+)

 from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name

@@ -160,6 +163,7 @@ def _get_node_list_with_same_tag(
     Return a list of nodes with the same tag.
     """
     node_list = []
+
     for node in tagged_graph_module.graph.nodes:
         if node.meta.get("delegation_tag", "") == tag:
             if node.op == "output":
@@ -373,6 +377,10 @@ def to_backend(
     ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

     update_to_real_program(tagged_exported_program, edge_program)
+
+    for tag, _ in partitioner_result.partition_tags.items():
+        _maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
+
     tagged_graph_module = _partition_and_lower(
         tagged_exported_program.graph_module, partitioner_result, edge_program
     )
```
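The body of `_maybe_duplicate_constant_nodes` lives in `exir/backend/utils.py` and is not part of this diff. As a rough sketch of the semantics implied by the summary and the tests below, under the assumption that it walks constant placeholders whose users span more than one tag, it plausibly looks like this (everything here other than `duplicate_constant_node` and the `no_copy` key is an assumption):

```python
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)


def _maybe_duplicate_constant_nodes_sketch(tagged_exported_program, tag, edge_program):
    # Approximation only; the real helper may differ, e.g. it also keeps
    # edge_program in sync, which this sketch omits.
    for node in list(tagged_exported_program.graph.nodes):
        if node.op != "placeholder":
            continue
        user_tags = {user.meta.get("delegation_tag") for user in node.users}
        if len(user_tags) <= 1 or tag not in user_tags:
            continue  # constant is not shared across partition boundaries
        if node.meta.get("no_copy", False):
            # The backend opted out of copying, e.g. the constant is too large.
            raise RuntimeError(
                f"constant data {node.name} is used across different partitions"
            )
        duplicate_constant_node(tagged_exported_program, node.name)
```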

exir/backend/canonical_partitioners/TARGETS

Lines changed: 17 additions & 0 deletions
```diff
@@ -19,3 +19,20 @@ runtime.python_library(
         "//executorch/exir/backend:partitioner",
     ],
 )
+
+runtime.python_library(
+    name = "duplicate_constant_node_pass",
+    srcs = [
+        "duplicate_constant_node_pass.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "//executorch/exir/backend/...",
+        "//executorch/test/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/backend:partitioner",
+    ],
+)
```
exir/backend/canonical_partitioners/duplicate_constant_node_pass.py (new file)

Lines changed: 104 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set

from torch.export import ExportedProgram
from torch.export.exported_program import InputSpec, TensorArgument


def duplicate_constant_node(
    exported_program: ExportedProgram, candidate_node: str
) -> Set[str]:
    """
    A pass to duplicate a constant node (the candidate_node) in the graph.
    Mostly used for duplicating light-weight constants. If the constants are
    too expensive, try tagging them with "no_copy" to prevent high memory
    usage and make them part of the output instead.

    Args:
        exported_program: the exported program to be modified. If constant
            nodes are copied, they will be added as new placeholders and the
            state_dict will be updated.
        candidate_node: the name of the constant node to be duplicated

    Returns:
        The set of the names of the new constant nodes
    """
    to_be_copied = [
        node
        for node in exported_program.graph.nodes
        if node.name == candidate_node and node.op == "placeholder"
    ]
    if len(to_be_copied) == 0:
        logging.info("no constant node to be copied")
        return set()
    new_input_specs = []
    old_signature = exported_program.graph_signature
    copied_nodes = set()
    # Placeholders come first in an FX graph, so idx lines up with input_specs.
    for idx, node in enumerate(exported_program.graph.nodes):
        if node.op != "placeholder":
            continue
        old_input_spec = old_signature.input_specs[idx]
        new_input_specs.append(
            InputSpec(
                old_input_spec.kind,
                old_input_spec.arg,
                old_input_spec.target,
                persistent=old_input_spec.persistent,
            )
        )
        if node == to_be_copied[0]:
            constant_tensor = node
            users = list(node.users.keys())
            # Create one copy per extra user; the first user keeps the original.
            for ith in range(len(node.users) - 1):
                copy_constant_tensor_fqn = node.name + f"_copy_{ith}"
                with exported_program.graph.inserting_before(constant_tensor):
                    copied_constant_tensor = exported_program.graph.placeholder(
                        copy_constant_tensor_fqn
                    )
                    copied_nodes.add(copy_constant_tensor_fqn)
                    logging.info(
                        f"Copying constant node {node.name} and creating {copy_constant_tensor_fqn}"
                    )
                    for k, v in node.meta.items():
                        copied_constant_tensor.meta[k] = v
                    copied_constant_tensor.meta["val"] = constant_tensor.meta["val"]
                    # Rewire the (ith + 1)-th user to consume the copy instead
                    # of the original constant.
                    new_args = tuple(
                        arg if arg != constant_tensor else copied_constant_tensor
                        for arg in users[ith + 1].args
                    )
                    new_kwargs = {
                        key: (
                            value
                            if value != constant_tensor
                            else copied_constant_tensor
                        )
                        for key, value in users[ith + 1].kwargs.items()
                    }
                    users[ith + 1].args = new_args
                    users[ith + 1].kwargs = new_kwargs
                    exported_program.state_dict[copy_constant_tensor_fqn] = (
                        copied_constant_tensor
                    )
                    new_input_specs.append(
                        InputSpec(
                            kind=old_input_spec.kind,
                            arg=TensorArgument(name=copy_constant_tensor_fqn),
                            target=old_input_spec.target,
                            persistent=old_input_spec.persistent,
                        )
                    )

    exported_program.graph_signature.input_specs = new_input_specs
    exported_program.graph_module.recompile()

    return copied_nodes
```
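For reference, here is a minimal way to drive the pass directly, mirroring the `test_passes.py` unit test added later in this commit; the placeholder name `arg0_1` is simply what export happens to produce for the buffer in that test:

```python
import torch
from executorch import exir
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)
from torch._export import capture_pre_autograd_graph


class ReuseConstData(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("const", torch.ones(2, 2))

    def forward(self, x):
        # Two users of the same constant buffer.
        return x + self.const, x - self.const


model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),)))

# One copy is created per extra user; the original keeps the first user.
copied = duplicate_constant_node(edge.exported_program(), "arg0_1")
print(copied)  # {'arg0_1_copy_0'}
```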

exir/backend/test/TARGETS

Lines changed: 17 additions & 0 deletions
```diff
@@ -67,6 +67,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend:utils",
         "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
     ],
 )
@@ -283,6 +284,9 @@ python_unittest(
     srcs = [
         "test_partitioner.py",
     ],
+    preload_deps = [
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_register",
+    ],
     deps = [
         "//caffe2:torch",
         "//executorch/exir:lib",
@@ -295,6 +299,19 @@ python_unittest(
         "//executorch/exir/dialects:lib",
         "//executorch/exir/tests:models",
         "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/extension/pytree:pylib",
         "//executorch/runtime/executor/test:test_backend_compiler_lib",
     ],
 )
+
+python_unittest(
+    name = "test_passes",
+    srcs = [
+        "test_passes.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
+    ],
+)
```

exir/backend/test/demos/rpc/targets.bzl

Lines changed: 3 additions & 0 deletions
```diff
@@ -39,6 +39,9 @@ def define_common_targets():
         srcs = [
             "ExecutorBackendRegister.cpp",
         ],
+        visibility = [
+            "//executorch/exir/backend/test/...",
+        ],
         deps = [
             ":executor_backend",
             "//executorch/runtime/backend:interface",
```

exir/backend/test/test_partitioner.py

Lines changed: 64 additions & 0 deletions
```diff
@@ -31,6 +31,10 @@
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.tests.models import MLP
+from executorch.extension.pybindings.portable_lib import (  # @manual=//executorch/extension/pybindings:portable_lib
+    _load_for_executorch_from_buffer,
+)
+from executorch.extension.pytree import tree_flatten
 from torch._export import capture_pre_autograd_graph
 from torch._export.utils import is_buffer, is_param
 from torch.export import export
@@ -446,6 +450,66 @@ def partition(
                 partition_tags=partition_tags,
             )

+        inputs = (torch.ones(2, 2),)
+        model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
+        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
+        exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
+        executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
+        inputs_flattened, _ = tree_flatten(inputs)
+
+        # Send the input from server executor to client executor, and receive the result from client executor
+        _ = executorch_module.run_method("forward", inputs)
+
+    def test_partitioner_alert_split_constant_data(self):
+        """
+        We test that we throw an error when constant data users are split
+        between different delegated payloads or owning program.
+        """
+
+        class ReuseConstData(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.ones(2, 2)
+
+            def forward(self, x):
+                y = x + self.const
+                z = x - self.const
+                return y, z
+
+        class PartitionerTagData(Partitioner):
+            def __init__(self):
+                super().__init__()
+                self.delegation_spec = DelegationSpec(
+                    ExecutorBackend.__name__,
+                    [CompileSpec(key, value) for key, value in self.spec.items()],
+                )
+
+            def partition(
+                self, edge_exported_program: ExportedProgram
+            ) -> PartitionResult:
+                partition_tags = {}
+                for node in edge_exported_program.graph.nodes:
+                    if node.op == "call_function" and node.target in [
+                        exir_ops.edge.aten.add.Tensor
+                    ]:
+                        delegation_tag = "tag0"
+                        node.meta["delegation_tag"] = delegation_tag
+                        partition_tags[delegation_tag] = self.delegation_spec
+
+                    if node.op == "placeholder" and (
+                        is_param(edge_exported_program, node)
+                        or is_buffer(edge_exported_program, node)
+                    ):
+                        delegation_tag = "tag0"
+                        node.meta["delegation_tag"] = delegation_tag
+                        node.meta["no_copy"] = True
+                        partition_tags[delegation_tag] = self.delegation_spec
+
+                return PartitionResult(
+                    tagged_exported_program=edge_exported_program,
+                    partition_tags=partition_tags,
+                )
+
         model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
         edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
         with self.assertRaises(RuntimeError) as error:
```

exir/backend/test/test_passes.py

Lines changed: 47 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch import exir
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)
from torch._export import capture_pre_autograd_graph
from torch._export.utils import is_buffer
from torch.testing import FileCheck


class TestPasses(unittest.TestCase):
    def test_duplicate_constant_node_pass(self):
        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("const", torch.ones(2, 2))

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
        edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),)))

        const_nodes = [
            node.name
            for node in edge.exported_program().graph.nodes
            if node.op == "placeholder" and is_buffer(edge.exported_program(), node)
        ]

        copied_nodes = duplicate_constant_node(edge.exported_program(), const_nodes[0])
        self.assertEqual(len(copied_nodes), 1)

        # Check that the new constant node is in the graph
        FileCheck().check("arg0_1_copy_0").run(
            edge.exported_program().graph_module.code
        )
```
