Skip to content

Commit c13983c

Browse files
cccclai authored and facebook-github-bot committed
fix delegate cache duplicate bug
Summary: Issue #7175 reported that delegates are not deduplicated when they are exactly the same. Differential Revision: D67067997
1 parent 8fc3f8c commit c13983c

File tree

5 files changed

+70
-7
lines changed

5 files changed

+70
-7
lines changed

exir/_serialize/_program.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def _extract_delegate_segments(
254254
"""
255255
remaining_inline: List[BackendDelegateInlineData] = []
256256
inline_indices_seen: set[int] = set()
257+
segment_index_map: dict[bytes, int] = {}
257258
for plan in program.execution_plan:
258259
for delegate in plan.delegates:
259260
if delegate.processed.location != DataLocation.INLINE:
@@ -279,8 +280,11 @@ def _extract_delegate_segments(
279280
inline_indices_seen.add(delegate.processed.index)
280281
if inline.data:
281282
# Move the delegate data out of the program.
282-
segment_index = len(segments)
283-
segments.append(Cord(inline.data))
283+
segment_index = segment_index_map.get(inline.data)
284+
if segment_index is None:
285+
segment_index = len(segments)
286+
segments.append(Cord(inline.data))
287+
segment_index_map[inline.data] = segment_index
284288
delegate.processed = BackendDelegateDataReference(
285289
location=DataLocation.SEGMENT,
286290
index=segment_index,

exir/backend/test/demos/rpc/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ runtime.python_library(
2828
],
2929
visibility = [
3030
"//executorch/exir/backend/test/...",
31+
"//executorch/exir/emit/test/...",
3132
],
3233
deps = [
3334
":executor_backend_preprocess",

exir/emit/_emitter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class _ProgramState:
122122
# Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
123123
# and should be copied to Program.backend_delegate_data.
124124
backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
125+
# Delegate cache that is used across all entry points.
126+
backend_delegate_data_cache: Dict[bytes, int] = field(default_factory=dict)
125127

126128
# Constants are optionally stored in external files.
127129
# Aggregate unique external constants into one buffer.
@@ -1112,10 +1114,13 @@ def _emit_delegate(
11121114
if delegate_index is None:
11131115
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
11141116
# present.
1115-
data_index: int = len(self.program_state.backend_delegate_data)
1116-
self.program_state.backend_delegate_data.append(
1117-
BackendDelegateInlineData(data=processed_bytes)
1118-
)
1117+
data_index: Optional[int] = self.program_state.backend_delegate_data_cache.get(processed_bytes)
1118+
if data_index is None:
1119+
data_index = len(self.program_state.backend_delegate_data)
1120+
self.program_state.backend_delegate_data_cache[processed_bytes] = data_index
1121+
self.program_state.backend_delegate_data.append(
1122+
BackendDelegateInlineData(data=processed_bytes)
1123+
)
11191124

11201125
backend_delegate = BackendDelegate(
11211126
id=lowered_module.backend_id,

exir/emit/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ python_unittest(
1616
"//executorch/exir:lib",
1717
"//executorch/exir:print_program",
1818
"//executorch/exir:schema",
19+
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
1920
"//executorch/exir/backend:backend_api",
2021
"//executorch/exir/emit:lib",
2122
"//executorch/exir/passes:const_prop_pass",

exir/emit/test/test_emit.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from executorch.exir._serialize._program import deserialize_pte_binary
2727
from executorch.exir.backend.backend_api import to_backend
2828
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
29+
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
30+
ExecutorBackendPartitioner,
31+
)
2932
from executorch.exir.dialects._ops import ops as exir_ops
3033
from executorch.exir.emit import emit_program # noqa
3134
from executorch.exir.error import InternalError
@@ -60,7 +63,7 @@
6063
from functorch.experimental import control_flow
6164
from torch import nn
6265

63-
from torch.export import Dim, export
66+
from torch.export import Dim, export, export_for_training
6467

6568

6669
class WrapperModule(torch.nn.Module):
@@ -1660,3 +1663,52 @@ def forward(self, x):
16601663
]
16611664
self.assertEqual(external_map["linear.weight"], 0)
16621665
self.assertEqual(external_map["linear.bias"], 1)
1666+
1667+
def test_delegate_deduplicate(self) -> None:
1668+
class SharedModule(torch.nn.Module):
1669+
def __init__(self):
1670+
super().__init__()
1671+
self.linear = torch.nn.Linear(2, 2)
1672+
1673+
def forward(self, x):
1674+
return self.linear(x)
1675+
1676+
1677+
class Module1(torch.nn.Module):
1678+
def __init__(self, shared_module):
1679+
super().__init__()
1680+
self.shared_module = shared_module
1681+
1682+
def forward(self, x):
1683+
return self.shared_module(x)
1684+
1685+
1686+
class Module2(torch.nn.Module):
1687+
def __init__(self, shared_module):
1688+
super().__init__()
1689+
self.shared_module = shared_module
1690+
1691+
def forward(self, x):
1692+
return self.shared_module(x)
1693+
1694+
shared_module = SharedModule()
1695+
module_1 = Module1(shared_module)
1696+
module_2 = Module2(shared_module)
1697+
example_inputs = (torch.randn(2, 2),)
1698+
module_1(*example_inputs)
1699+
module_2(*example_inputs)
1700+
1701+
ep1 = export_for_training(module_1, example_inputs)
1702+
ep2 = export_for_training(module_2, example_inputs)
1703+
1704+
edge_program_manager = exir.to_edge(
1705+
{"forward1": ep1, "forward2": ep2},
1706+
compile_config=exir.EdgeCompileConfig(
1707+
_check_ir_validity=False, _use_edge_ops=True
1708+
),
1709+
)
1710+
1711+
edge_program_manager = edge_program_manager.to_backend(ExecutorBackendPartitioner()).to_executorch()
1712+
1713+
# Check that there is only one delegate because two methods are exactly the same
1714+
self.assertEqual(len(edge_program_manager.executorch_program.backend_delegate_data), 1)

0 commit comments

Comments
 (0)