add duplicate constant node pass #2570

Closed
1 change: 1 addition & 0 deletions exir/backend/TARGETS
@@ -26,6 +26,7 @@ runtime.python_library(
":compile_spec_schema",
"//caffe2:torch",
"//executorch/exir/backend:utils",
"//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
],
)

10 changes: 9 additions & 1 deletion exir/backend/backend_api.py
@@ -16,7 +16,10 @@
from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.utils import is_identical_graph
from executorch.exir.backend.utils import (
_maybe_duplicate_constant_nodes,
is_identical_graph,
)

from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name

@@ -160,6 +163,7 @@ def _get_node_list_with_same_tag(
Return a list of nodes with the same tag.
"""
node_list = []

for node in tagged_graph_module.graph.nodes:
if node.meta.get("delegation_tag", "") == tag:
if node.op == "output":
@@ -373,6 +377,10 @@ def to_backend(
), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

update_to_real_program(tagged_exported_program, edge_program)

for tag, _ in partitioner_result.partition_tags.items():
_maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)

tagged_graph_module = _partition_and_lower(
tagged_exported_program.graph_module, partitioner_result, edge_program
)
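
For context on the duplication loop above: when a constant is consumed both inside and outside a partition, its users would otherwise be split across the delegate boundary. A minimal sketch of a module exhibiting this, mirroring ReuseConstData from the tests in this PR:

import torch

class ReuseConstData(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.const = torch.ones(2, 2)

    def forward(self, x):
        # If only the add is delegated, `const` has users on both sides of
        # the partition boundary; _maybe_duplicate_constant_nodes gives each
        # side its own placeholder (unless the node is tagged "no_copy").
        y = x + self.const
        z = x - self.const
        return y, z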
17 changes: 17 additions & 0 deletions exir/backend/canonical_partitioners/TARGETS
@@ -19,3 +19,20 @@ runtime.python_library(
"//executorch/exir/backend:partitioner",
],
)

runtime.python_library(
name = "duplicate_constant_node_pass",
srcs = [
"duplicate_constant_node_pass.py",
],
visibility = [
"//executorch/...",
"//executorch/exir/backend/...",
"//executorch/test/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
"//executorch/exir/backend:partitioner",
],
)
152 changes: 152 additions & 0 deletions exir/backend/canonical_partitioners/duplicate_constant_node_pass.py
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from typing import Optional, Set

import torch
from torch._export.utils import get_buffer, get_lifted_tensor_constant, get_param

from torch.export import ExportedProgram
from torch.export.exported_program import InputSpec, TensorArgument
from torch.export.graph_signature import InputKind


def _get_attribute_or_constants(
exported_program: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
# Get either an attribute (param/buffer) or a lifted constant tensor for this node.
maybe_param = get_param(exported_program, node)
maybe_buffer = get_buffer(exported_program, node)
maybe_lifted_tensor = get_lifted_tensor_constant(exported_program, node)

constant_or_attribute = None
if maybe_param is not None:
constant_or_attribute = maybe_param
elif maybe_buffer is not None:
constant_or_attribute = maybe_buffer
elif maybe_lifted_tensor is not None:
constant_or_attribute = maybe_lifted_tensor
return constant_or_attribute


# TODO: add other passes to duplicate call_function nodes
def duplicate_constant_node(
exported_program: ExportedProgram, candidate_node: str
) -> Set[str]:
"""
A pass to duplicate an attribute/constant node (the candidate_node) in the graph. Mostly used for duplicating lightweight data.
If the data is too large, tag it with "no_copy" to prevent high memory usage; the constant then stays in the owning program instead of being copied.

Args:
exported_program: the exported program to be modified. If constant nodes are copied, they will be added as new
placeholders and the state_dict (or constants table) will be updated
candidate_node: the name of the constant node to be duplicated

Returns:
The set of the names of the new constant nodes
"""
to_be_copied = [
node
for node in exported_program.graph.nodes
if node.name == candidate_node and node.op == "placeholder"
]
if len(to_be_copied) == 0:
logging.info("no constant node to be copied")
return set()
new_input_specs = []
old_signature = exported_program.graph_signature
copied_nodes = set()
for idx, node in enumerate(exported_program.graph.nodes):
if node.op != "placeholder":
continue
old_input_spec = old_signature.input_specs[idx]
old_input_spec_copy = copy.deepcopy(old_input_spec)
if node == to_be_copied[0]:
constant_or_attribute_node = node
constant_or_attribute = _get_attribute_or_constants(exported_program, node)
if constant_or_attribute is None:
raise RuntimeError(
f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
)
users = list(node.users.keys())
for ith in range(1, len(node.users)):
copy_constant_or_attribute_fqn = node.name + f"_copy_{ith - 1}"
with exported_program.graph.inserting_before(
constant_or_attribute_node
):
copied_constant_or_attribute_node = (
exported_program.graph.placeholder(
copy_constant_or_attribute_fqn
)
)
copied_nodes.add(copy_constant_or_attribute_fqn)
logging.info(
f"Copying constant nodes {node.name} and creating {copy_constant_or_attribute_fqn}"
)
for k, v in node.meta.items():
copied_constant_or_attribute_node.meta[k] = v
copied_constant_or_attribute_node.meta["val"] = (
constant_or_attribute_node.meta["val"]
)
new_args = tuple(
[
(
arg
if arg != constant_or_attribute_node
else copied_constant_or_attribute_node
)
for arg in users[ith].args
]
)
new_kwargs = {
    key: (
        value
        if value != constant_or_attribute_node
        else copied_constant_or_attribute_node
    )
    for key, value in users[ith].kwargs.items()
}
users[ith].args = new_args
users[ith].kwargs = new_kwargs
if old_input_spec.kind == InputKind.CONSTANT_TENSOR:
exported_program.constants[copy_constant_or_attribute_fqn] = (
copy.deepcopy(constant_or_attribute)
)
elif (
old_input_spec.kind == InputKind.BUFFER
and old_input_spec.persistent is False
):
# Non-persistent buffers live in .constants rather than the state_dict.
exported_program.constants[copy_constant_or_attribute_fqn] = (
copy.deepcopy(constant_or_attribute)
)
else:
exported_program.state_dict[copy_constant_or_attribute_fqn] = (
copy.deepcopy(constant_or_attribute)
)
new_input_specs.append(
InputSpec(
kind=old_input_spec.kind,
arg=TensorArgument(name=copy_constant_or_attribute_fqn),
target=old_input_spec.target,
persistent=old_input_spec.persistent,
)
)
# Append the original input spec last, because all the copied nodes
# are inserted before the candidate node.
new_input_specs.append(old_input_spec_copy)

exported_program.graph_signature.input_specs = new_input_specs
exported_program.graph_module.recompile()
exported_program._validate()
return copied_nodes
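
For constants too heavy to duplicate, a partitioner can use the "no_copy" tag mentioned in the docstring above. A hedged sketch of opting out, mirroring test_partitioner_alert_split_constant_data later in this diff:

# Inside a Partitioner.partition() implementation (sketch):
for node in edge_exported_program.graph.nodes:
    if node.op == "placeholder" and is_buffer(edge_exported_program, node):
        node.meta["delegation_tag"] = "tag0"
        # Opt out of duplication: if this constant's users end up split
        # across the delegate boundary, to_backend raises a RuntimeError
        # instead of copying the data.
        node.meta["no_copy"] = True
        partition_tags["tag0"] = self.delegation_spec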
17 changes: 17 additions & 0 deletions exir/backend/test/TARGETS
@@ -67,6 +67,7 @@ python_library(
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend:utils",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
],
)
@@ -283,6 +284,9 @@ python_unittest(
srcs = [
"test_partitioner.py",
],
preload_deps = [
"//executorch/exir/backend/test/demos/rpc:executor_backend_register",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
@@ -295,6 +299,19 @@ python_unittest(
"//executorch/exir/dialects:lib",
"//executorch/exir/tests:models",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
"//executorch/runtime/executor/test:test_backend_compiler_lib",
],
)

python_unittest(
name = "test_passes",
srcs = [
"test_passes.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/backend/canonical_partitioners:duplicate_constant_node_pass",
],
)
3 changes: 3 additions & 0 deletions exir/backend/test/demos/rpc/targets.bzl
@@ -39,6 +39,9 @@ def define_common_targets():
srcs = [
"ExecutorBackendRegister.cpp",
],
visibility = [
"//executorch/exir/backend/test/...",
],
deps = [
":executor_backend",
"//executorch/runtime/backend:interface",
64 changes: 64 additions & 0 deletions exir/backend/test/test_partitioner.py
@@ -31,6 +31,10 @@
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.tests.models import MLP
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
_load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten
from torch._export import capture_pre_autograd_graph
from torch._export.utils import is_buffer, is_param
from torch.export import export
@@ -446,6 +450,66 @@ def partition(
partition_tags=partition_tags,
)

inputs = (torch.ones(2, 2),)
model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
inputs_flattened, _ = tree_flatten(inputs)

# Send the input from server executor to client executor, and receive the result from client executor
_ = executorch_module.run_method("forward", tuple(inputs_flattened))

def test_partitioner_alert_split_constant_data(self):
"""
We test that an error is thrown when the users of constant data are
split between different delegated payloads or the owning program.
"""

class ReuseConstData(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.ones(2, 2)

def forward(self, x):
y = x + self.const
z = x - self.const
return y, z

class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)

def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec

if node.op == "placeholder" and (
is_param(edge_exported_program, node)
or is_buffer(edge_exported_program, node)
):
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
node.meta["no_copy"] = True
partition_tags[delegation_tag] = self.delegation_spec

return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)

model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
with self.assertRaises(RuntimeError) as error:
47 changes: 47 additions & 0 deletions exir/backend/test/test_passes.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch import exir
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
duplicate_constant_node,
)
from torch._export import capture_pre_autograd_graph
from torch._export.utils import is_buffer
from torch.testing import FileCheck


class TestPasses(unittest.TestCase):
def test_duplicate_constant_node_pass(self):

class ReuseConstData(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("const", torch.ones(2, 2))

def forward(self, x):
y = x + self.const
z = x - self.const
return y, z

model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),))
edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),)))

const_nodes = [
node.name
for node in edge.exported_program().graph.nodes
if node.op == "placeholder" and is_buffer(edge.exported_program(), node)
]

copied_nodes = duplicate_constant_node(edge.exported_program(), const_nodes[0])
self.assertEqual(len(copied_nodes), 1)

# Check that the new constant node is in the graph
FileCheck().check("arg0_1_copy_0").run(
edge.exported_program().graph_module.code
)