Skip to content

Commit 0268d46

Browse files
committed
Introduce NamedDataStore
Introduce NamedDataStore for weight sharing. Rename 'NamedBlobStore' --> 'NamedDataStore' to mirror 'NamedDataMap' in the runtime. The NamedDataStore exposes two methods: - add_named_data: add a blob to the store - get_named_data_store_output: return the contents of the store, to pass to serialization. Invariants on the NamedDataStore - Keys are unique regardless of whether they are in PTE or external file. - Different keys can point to the same data. NamedDataStore is used in D69764150. It's owned by the EdgeProgramManager. Differential Revision: [D69764094](https://our.internmc.facebook.com/intern/diff/D69764094/) [ghstack-poisoned]
1 parent cc3974f commit 0268d46

File tree

4 files changed

+274
-0
lines changed

4 files changed

+274
-0
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ runtime.python_library(
3232
"_cord.py",
3333
"_dataclass.py",
3434
"_flatbuffer.py",
35+
"_named_data_store.py",
3536
"_program.py",
3637
"_serialize.py",
3738
"data_serializer.py",

exir/_serialize/_named_data_store.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import hashlib
10+
from dataclasses import dataclass
11+
12+
# from dataclasses import dataclass
13+
from typing import Dict, List, Optional
14+
15+
16+
def gcd(a: int, b: int) -> int:
17+
while b:
18+
a, b = b, a % b
19+
return a
20+
21+
22+
def lcm(a: int, b: int) -> int:
23+
return (a * b) // gcd(a, b)
24+
25+
26+
@dataclass
27+
class BufferEntry:
28+
"""A class to hold the buffer entries for serialization.
29+
30+
Attributes:
31+
buffer: The buffer bytes.
32+
alignment: The alignment of the buffer.
33+
"""
34+
35+
buffer: bytes
36+
alignment: int
37+
38+
39+
@dataclass
40+
class NamedDataStoreOutput:
41+
"""
42+
A class to hold the named data for serialization.
43+
44+
Attributes:
45+
buffer: A list of unique buffer entries.
46+
pte_data: Contains data that is stored inside the PTE file. A mapping from
47+
{key: buffer_index}.
48+
external_data: Contains data that is stored external to the PTE. A mapping
49+
from {filename: {key: buffer_index}}.
50+
"""
51+
52+
buffers: List[BufferEntry]
53+
pte_data: Dict[str, int]
54+
external_data: Dict[str, Dict[str, int]]
55+
56+
57+
class NamedDataStore:
58+
"""
59+
NamedDataStore manages the data that delegates want to share. Backends add
60+
bytes to the store under a unique key. These bytes can be retrieved at
61+
runtime using the same key with the NamedDataMap.
62+
63+
Note:
64+
- Keys are unique in the data store, regardless of whether they are stored
65+
in the PTE or externally.
66+
- Multiple keys can point to the same buffer entry.
67+
- The same data can be added multiple times; all keys will point to one
68+
buffer. If a duplicate blob is added with a different alignment, the
69+
lcm of the current and new alignment is taken for that blob.
70+
"""
71+
72+
# List of unique blobs.
73+
buffers: List[BufferEntry]
74+
# Named data stored inside the PTE file. Map of {key: buffer_index}.
75+
pte_data: Dict[str, int]
76+
# Named data stored outside of the PTE file.
77+
# Map of {filename: {key: buffer_index}}.
78+
external_data: Dict[str, Dict[str, int]]
79+
80+
# Cache of the data hash for deduplication.
81+
data_cache: Dict[str, int]
82+
# Cache of the keys to ensure uniqueness.
83+
key_cache: Dict[str, int]
84+
85+
def __init__(self) -> None:
86+
"""
87+
Initializes a new NamedDataStore.
88+
"""
89+
self.buffers = []
90+
self.pte_data = {}
91+
self.external_data = {}
92+
93+
self.data_cache = {}
94+
self.key_cache = {}
95+
96+
def _add_named_data_to_map(
97+
self, key: str, data: bytes, alignment: int, map: Dict[str, int]
98+
) -> None:
99+
"""
100+
Add data to a map and update the alignment. Ensure that the key-data
101+
pair is unique.
102+
- If the key exists, the data must be identical.
103+
- If multiple unique keys exist for the same data, those keys should
104+
point to the same buffer.
105+
106+
Args:
107+
key (str): key associated with the data.
108+
data (bytes): Bytes being requested to be serialized.
109+
alignment (int): alignment for bytes to be serialized with.
110+
map (Dict[str, int]): map to add the data to.
111+
Raises:
112+
ValueError: when the key exists in the store, and corresponding data
113+
is different.
114+
"""
115+
# Check if the key exists.
116+
buffer_idx = self.key_cache.get(key, -1)
117+
if buffer_idx != -1:
118+
# If the key exists, the corresponding data must be identical.
119+
if self.buffers[buffer_idx].buffer != data:
120+
raise ValueError(f"Duplicate key {key} with different data.")
121+
self.buffers[buffer_idx].alignment = lcm(
122+
self.buffers[buffer_idx].alignment, alignment
123+
)
124+
else:
125+
# Key doesn't exist; check if the data exists.
126+
hashed = hashlib.sha256(data).hexdigest()
127+
buffer_idx = self.data_cache.get(hashed, -1)
128+
if buffer_idx != -1:
129+
# The data exists; update the alignment.
130+
self.buffers[buffer_idx].alignment = lcm(
131+
self.buffers[buffer_idx].alignment, alignment
132+
)
133+
else:
134+
# The data doesn't exist; add it to the data store.
135+
buffer_idx = len(self.buffers)
136+
self.buffers.append(BufferEntry(data, alignment))
137+
self.data_cache[hashed] = buffer_idx
138+
139+
# Add key to the map and the key cache.
140+
map[key] = buffer_idx
141+
self.key_cache[key] = buffer_idx
142+
143+
def add_named_data(
144+
self,
145+
key: str,
146+
data: bytes,
147+
alignment: Optional[int] = 1,
148+
external_tag: Optional[str] = None,
149+
) -> None:
150+
"""
151+
Adds a named blob to the NamedDataStore.
152+
Args:
153+
key (str): key associated with the data.
154+
data (bytes): Bytes being requested to be serialized.
155+
alignment (int): alignment for bytes to be serialized with.
156+
external (Optional[str]): the external filename that this data is saved to.
157+
Raises:
158+
ValueError: when the key exists in the store, and corresponding data
159+
is different.
160+
"""
161+
162+
# Set default alignment.
163+
if alignment is None:
164+
alignment = 1
165+
166+
if external_tag is None:
167+
self._add_named_data_to_map(key, data, alignment, self.pte_data)
168+
else:
169+
if self.external_data.get(external_tag, None) is None:
170+
self.external_data[external_tag] = {}
171+
self._add_named_data_to_map(
172+
key, data, alignment, self.external_data[external_tag]
173+
)
174+
175+
def get_named_data_store_output(self) -> NamedDataStoreOutput:
176+
# Clean up empty maps inside self.external_data
177+
self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
178+
return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)

exir/_serialize/test/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,13 @@ python_unittest(
3333
"//executorch/exir/_serialize:lib",
3434
],
3535
)
36+
37+
python_unittest(
38+
name = "named_data_store",
39+
srcs = [
40+
"test_named_data_store.py",
41+
],
42+
deps = [
43+
"//executorch/exir/_serialize:lib",
44+
],
45+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore
12+
13+
14+
class TestNamedDataStore(unittest.TestCase):
15+
def test_add(self) -> None:
16+
store = NamedDataStore()
17+
store.add_named_data("key1", b"data1", None, None)
18+
store.add_named_data("key2", b"data2", 16, "file1")
19+
store.add_named_data("key3", b"data3", 16, "file1")
20+
21+
output = store.get_named_data_store_output()
22+
23+
self.assertEqual(len(output.buffers), 3)
24+
self.assertEqual(output.buffers[0], BufferEntry(b"data1", 1))
25+
self.assertEqual(output.buffers[1], BufferEntry(b"data2", 16))
26+
self.assertEqual(output.buffers[2], BufferEntry(b"data3", 16))
27+
28+
self.assertEqual(len(output.pte_data), 1)
29+
self.assertEqual(output.pte_data["key1"], 0)
30+
31+
self.assertEqual(len(output.external_data), 1)
32+
self.assertEqual(len(output.external_data["file1"]), 2)
33+
self.assertEqual(output.external_data["file1"]["key2"], 1)
34+
self.assertEqual(output.external_data["file1"]["key3"], 2)
35+
36+
def test_add_duplicate_name_and_data(self) -> None:
37+
store = NamedDataStore()
38+
store.add_named_data("key", b"data", None, None)
39+
store.add_named_data("key", b"data", None, None)
40+
41+
output = store.get_named_data_store_output()
42+
43+
self.assertEqual(len(output.buffers), 1)
44+
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
45+
46+
self.assertEqual(len(output.pte_data), 1)
47+
self.assertEqual(output.pte_data["key"], 0)
48+
49+
self.assertEqual(len(output.external_data), 0)
50+
51+
def test_add_same_data_with_different_alignment(self) -> None:
52+
store = NamedDataStore()
53+
store.add_named_data("key", b"data", 3, None)
54+
store.add_named_data("key1", b"data", 4, None)
55+
56+
output = store.get_named_data_store_output()
57+
58+
self.assertEqual(len(output.buffers), 1)
59+
# Check that we take the LCM of the two alignments (3, 4) = 12
60+
self.assertEqual(output.buffers[0], BufferEntry(b"data", 12))
61+
62+
self.assertEqual(len(output.pte_data), 2)
63+
self.assertEqual(output.pte_data["key"], 0)
64+
self.assertEqual(output.pte_data["key1"], 0)
65+
66+
self.assertEqual(len(output.external_data), 0)
67+
68+
def test_add_duplicate_key_fail(self) -> None:
69+
store = NamedDataStore()
70+
store.add_named_data("key", b"data", None, None)
71+
72+
# Cannot add item with the same key and different data.
73+
self.assertRaises(ValueError, store.add_named_data, "key", b"data1", None, None)
74+
self.assertRaises(
75+
ValueError, store.add_named_data, "key", b"data1", 16, "file1"
76+
)
77+
78+
output = store.get_named_data_store_output()
79+
80+
self.assertEqual(len(output.buffers), 1)
81+
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
82+
83+
self.assertEqual(len(output.pte_data), 1)
84+
self.assertEqual(output.pte_data["key"], 0)
85+
self.assertEqual(len(output.external_data), 0)

0 commit comments

Comments
 (0)