Skip to content

Commit c81ddb0

Browse files
committed
Introduce NamedDataStore
Pull Request resolved: #8587

Introduce NamedDataStore for weight sharing. See 'NamedBlobStore' in [RFC]. Rename 'NamedBlobStore' --> 'NamedDataStore' to mirror 'NamedDataMap' in the runtime.

The NamedDataStore exposes two methods:
- add_named_data: add a blob to the store.
- get_named_data_store_output: return the contents of the store, to pass to serialization.

Invariants on the NamedDataStore:
- Keys are unique regardless of whether they are in the PTE or an external file.
- Different keys can point to the same data.

NamedDataStore is used in D69764150. It's owned by the EdgeProgramManager.

ghstack-source-id: 268328940
@exported-using-ghexport
Differential Revision: [D69764094](https://our.internmc.facebook.com/intern/diff/D69764094/)
1 parent 6cb5c1a commit c81ddb0

File tree

4 files changed

+282
-3
lines changed

4 files changed

+282
-3
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ runtime.python_library(
3232
"_cord.py",
3333
"_dataclass.py",
3434
"_flatbuffer.py",
35+
"_named_data_store.py",
3536
"_program.py",
3637
"_serialize.py",
3738
"data_serializer.py",

exir/_serialize/_named_data_store.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import hashlib
10+
import math
11+
from dataclasses import dataclass
12+
13+
# from dataclasses import dataclass
14+
from typing import Dict, List, Optional
15+
16+
17+
@dataclass
class BufferEntry:
    """A single buffer queued for serialization, paired with its alignment.

    Attributes:
        buffer: The raw bytes of the buffer.
        alignment: The alignment the buffer is to be serialized with.
    """

    buffer: bytes
    alignment: int
28+
29+
30+
@dataclass
class NamedDataStoreOutput:
    """
    Snapshot of the named data handed over to serialization.

    Attributes:
        buffers: The unique buffer entries. The maps below refer to entries
            in this list by index.
        pte_data: Named data stored inside the PTE file, as a mapping of
            {key: buffer_index}.
        external_data: Named data stored outside of the PTE file, as a mapping
            of {filename: {key: buffer_index}}.
    """

    buffers: List[BufferEntry]
    pte_data: Dict[str, int]
    external_data: Dict[str, Dict[str, int]]
46+
47+
48+
class NamedDataStore:
    """
    NamedDataStore manages the data that delegates want to share. Backends add
    bytes to the store under a unique key. These bytes can be retrieved at
    runtime using the same key with the NamedDataMap.

    Note:
    - Keys are unique in the data store, regardless of whether they are stored
        in the PTE or externally. Re-adding a key is only accepted with
        identical data and the same location.
    - Multiple keys can point to the same buffer entry.
    - The same data can be added multiple times and all keys will point to one
        buffer. If a duplicate blob is added with a different alignment, the
        lcm of the current and new alignment is taken for that blob.
    """

    # Maximum number of bytes of a blob echoed back in an error message. Blobs
    # may be large, so error messages only show a bounded prefix.
    _MAX_PREVIEW_BYTES: int = 64

    # List of unique blobs.
    buffers: List[BufferEntry]
    # Named data stored inside the PTE file. Map of {key: buffer_index}.
    pte_data: Dict[str, int]
    # Named data stored outside of the PTE file.
    # Map of {filename: {key: buffer_index}}.
    external_data: Dict[str, Dict[str, int]]

    # Cache of the data hash for deduplication.
    # Use a hash instead of the data as a key because a sha256 collision is
    # unlikely, and the data may be large.
    data_hash_to_buffer_idx: Dict[bytes, int]
    # Cache of the key to buffer idx to ensure uniqueness.
    # If a key is added multiple times, check the buffer idx to ensure that the
    # data is identical too.
    key_to_buffer_idx: Dict[str, int]

    def __init__(self) -> None:
        """
        Initializes a new, empty NamedDataStore.
        """
        self.buffers = []
        self.pte_data = {}
        self.external_data = {}

        self.data_hash_to_buffer_idx = {}
        self.key_to_buffer_idx = {}

    def _preview(self, data: bytes) -> str:
        """
        Return a bounded, human-readable rendering of `data` for error messages.

        Blobs of at most `_MAX_PREVIEW_BYTES` bytes are rendered in full;
        larger blobs are cut to that prefix and annotated with their length.
        """
        if len(data) <= self._MAX_PREVIEW_BYTES:
            return repr(data)
        return f"{data[: self._MAX_PREVIEW_BYTES]!r}... ({len(data)} bytes total)"

    def _add_named_data_to_map(
        self,
        key: str,
        data: bytes,
        alignment: int,
        local_key_to_buffer_idx: Dict[str, int],
    ) -> None:
        """
        Add data to a map and update the alignment. Ensure that the key-data
        pair is unique.
        - If the key exists, the data must be identical and the key must
            already live in `local_key_to_buffer_idx`.
        - If multiple unique keys exist for the same data, those keys should
            point to the same buffer.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (int): alignment for bytes to be serialized with.
            local_key_to_buffer_idx (Dict[str, int]): map to add the data to.
        Raises:
            ValueError: when the key exists in the store and the corresponding
                data is different, or when the key exists in the store under
                a different map than `local_key_to_buffer_idx`.
        """
        # Deduplicate on the sha256 digest rather than on the bytes themselves.
        digest = hashlib.sha256(data).digest()
        data_idx = self.data_hash_to_buffer_idx.get(digest)
        buffer_idx = self.key_to_buffer_idx.get(key)

        if buffer_idx is not None:
            # The key exists; the corresponding data must be identical.
            if data_idx != buffer_idx:
                existing = self.buffers[buffer_idx].buffer
                raise ValueError(
                    f"Duplicate key {key} with different data. "
                    f"Existing data: {self._preview(existing)}. "
                    f"New data: {self._preview(data)}."
                )
            # The key exists somewhere in the store; if it is not in this map
            # it was registered for another location, and adding it here would
            # break the uniqueness of keys across the PTE and external files.
            if key not in local_key_to_buffer_idx:
                raise ValueError(
                    f"Duplicate key {key} already exists in a different "
                    "location. Keys must be unique across the PTE and "
                    "external files."
                )
        elif data_idx is not None:
            # New key for data that is already stored; share the buffer.
            buffer_idx = data_idx
        else:
            # Neither the key nor the data exist; add a new buffer entry.
            buffer_idx = len(self.buffers)
            self.buffers.append(BufferEntry(data, alignment))
            self.data_hash_to_buffer_idx[digest] = buffer_idx

        # Merge the requested alignment into the buffer's alignment. For a
        # freshly added entry this leaves the alignment unchanged.
        entry = self.buffers[buffer_idx]
        entry.alignment = math.lcm(entry.alignment, alignment)

        # Add key to the map and the key cache.
        local_key_to_buffer_idx[key] = buffer_idx
        self.key_to_buffer_idx[key] = buffer_idx

    def add_named_data(
        self,
        key: str,
        data: bytes,
        alignment: Optional[int] = 1,
        external_tag: Optional[str] = None,
    ) -> None:
        """
        Adds a named blob to the NamedDataStore.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (Optional[int]): alignment for bytes to be serialized
                with. None is treated as 1.
            external_tag (Optional[str]): the external filename that this data
                is saved to. When None, the data is stored in the PTE file.
        Raises:
            ValueError: when the alignment is not greater than 0, when the key
                exists in the store and the corresponding data is different,
                or when the key exists in the store under a different
                location.
        """
        # Set default alignment.
        if alignment is None:
            alignment = 1
        if alignment <= 0:
            raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

        if external_tag is None:
            self._add_named_data_to_map(key, data, alignment, self.pte_data)
        else:
            # setdefault may leave an empty map behind if the add below raises;
            # get_named_data_store_output prunes such maps.
            self._add_named_data_to_map(
                key, data, alignment, self.external_data.setdefault(external_tag, {})
            )

    def get_named_data_store_output(self) -> NamedDataStoreOutput:
        """
        Return the contents of the store, to pass to serialization.

        Empty per-file maps are dropped from `self.external_data` first. The
        returned object references the store's own buffer list and maps; they
        are not copied.
        """
        # Clean up empty maps inside self.external_data
        self.external_data = {k: v for k, v in self.external_data.items() if v}
        return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)

exir/_serialize/test/TARGETS

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
33
oncall("executorch")
44

55
python_unittest(
6-
name = "program",
6+
name = "test_program",
77
srcs = [
88
"test_program.py",
99
],
@@ -15,7 +15,7 @@ python_unittest(
1515
)
1616

1717
python_unittest(
18-
name = "flatbuffer",
18+
name = "test_flatbuffer",
1919
srcs = [
2020
"test_flatbuffer.py",
2121
],
@@ -25,11 +25,21 @@ python_unittest(
2525
)
2626

2727
python_unittest(
28-
name = "cord",
28+
name = "test_cord",
2929
srcs = [
3030
"test_cord.py",
3131
],
3232
deps = [
3333
"//executorch/exir/_serialize:lib",
3434
],
3535
)
36+
37+
python_unittest(
38+
name = "test_named_data_store",
39+
srcs = [
40+
"test_named_data_store.py",
41+
],
42+
deps = [
43+
"//executorch/exir/_serialize:lib",
44+
],
45+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore
12+
13+
14+
class TestNamedDataStore(unittest.TestCase):
    """Tests for adding blobs to a NamedDataStore and reading back its output."""

    def test_add(self) -> None:
        # One blob stored in the PTE, two blobs stored in the same external file.
        store = NamedDataStore()
        store.add_named_data("key1", b"data1", None, None)
        store.add_named_data("key2", b"data2", 16, "file1")
        store.add_named_data("key3", b"data3", 16, "file1")

        result = store.get_named_data_store_output()

        expected_buffers = [
            BufferEntry(b"data1", 1),
            BufferEntry(b"data2", 16),
            BufferEntry(b"data3", 16),
        ]
        self.assertEqual(result.buffers, expected_buffers)
        self.assertEqual(result.pte_data, {"key1": 0})
        self.assertEqual(result.external_data, {"file1": {"key2": 1, "key3": 2}})

    def test_add_duplicate_name_and_data(self) -> None:
        # Adding the identical key/data pair twice keeps a single entry.
        store = NamedDataStore()
        for _ in range(2):
            store.add_named_data("key", b"data", None, None)

        result = store.get_named_data_store_output()

        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})

    def test_add_same_data_with_different_alignment(self) -> None:
        # Two keys share one buffer whose alignment is the LCM of both requests.
        store = NamedDataStore()
        store.add_named_data("key", b"data", 3, None)
        store.add_named_data("key1", b"data", 4, None)

        result = store.get_named_data_store_output()

        # Check that we take the LCM of the two alignments (3, 4) = 12
        self.assertEqual(result.buffers, [BufferEntry(b"data", 12)])
        self.assertEqual(result.pte_data, {"key": 0, "key1": 0})
        self.assertEqual(result.external_data, {})

    def test_add_duplicate_key_fail(self) -> None:
        store = NamedDataStore()
        store.add_named_data("key", b"data", None, None)

        # Cannot add item with the same key and different data.
        with self.assertRaises(ValueError):
            store.add_named_data("key", b"data1", None, None)
        with self.assertRaises(ValueError):
            store.add_named_data("key", b"data1", 16, "file1")

        result = store.get_named_data_store_output()

        # The failed adds must leave the store exactly as it was.
        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})

0 commit comments

Comments
 (0)