Skip to content

Introduce NamedDataStore #8587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ runtime.python_library(
"_cord.py",
"_dataclass.py",
"_flatbuffer.py",
"_named_data_store.py",
"_program.py",
"_serialize.py",
"data_serializer.py",
Expand Down
183 changes: 183 additions & 0 deletions exir/_serialize/_named_data_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import hashlib
import math
from dataclasses import dataclass

# from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class BufferEntry:
    """A unique blob held by the store, together with its alignment.

    Attributes:
        buffer: The buffer bytes.
        alignment: The alignment the buffer must be serialized with.
    """

    buffer: bytes
    alignment: int


@dataclass
class NamedDataStoreOutput:
    """
    Holds named data for serialization.

    Attributes:
        buffers: A list of unique buffer entries.
        pte_data: Contains data that is stored inside the PTE file. A mapping
            from {key: buffer_index}.
        external_data: Contains data that is stored external to the PTE. A
            mapping from {filename: {key: buffer_index}}.
    """

    buffers: List[BufferEntry]
    pte_data: Dict[str, int]
    external_data: Dict[str, Dict[str, int]]


class NamedDataStore:
    """
    NamedDataStore manages the data that delegates want to share. Backends add
    bytes to the store under a unique key. These bytes can be retrieved at
    runtime using the same key with the NamedDataMap.

    Note:
    - Keys are unique in the data store, regardless of whether they are stored
        in the PTE or externally. Re-adding an existing key is only accepted
        when both the data and the storage location match the first add.
    - Multiple keys can point to the same buffer entry.
    - The same data can be added multiple times and all keys will point to one
        buffer. If a duplicate blob is added with a different alignment, the
        lcm of the current and new alignment is taken for that blob.
    - A rejected add (ValueError) leaves the store unchanged.
    """

    # List of unique blobs.
    buffers: List[BufferEntry]
    # Named data stored inside the PTE file. Map of {key: buffer_index}.
    pte_data: Dict[str, int]
    # Named data stored outside of the PTE file.
    # Map of {filename: {key: buffer_index}}.
    external_data: Dict[str, Dict[str, int]]

    # Cache of the data hash for deduplication.
    # Use a hash instead of the data as a key because a sha256 collision is
    # unlikely, and the data may be large.
    data_hash_to_buffer_idx: Dict[bytes, int]
    # Cache of the key to buffer idx to ensure uniqueness.
    # If a key is added multiple times, check the buffer idx to ensure that the
    # data is identical too.
    key_to_buffer_idx: Dict[str, int]

    # Upper bound on the number of payload bytes echoed in an error message.
    _MAX_PREVIEW_BYTES: int = 32

    def __init__(self) -> None:
        """
        Initializes a new, empty NamedDataStore.
        """
        self.buffers = []
        self.pte_data = {}
        self.external_data = {}

        self.data_hash_to_buffer_idx = {}
        self.key_to_buffer_idx = {}

    @classmethod
    def _preview(cls, data: bytes) -> str:
        """Return a bounded, human-readable description of `data`.

        Blobs may be very large, so only the leading bytes are shown, followed
        by the total length.
        """
        limit = cls._MAX_PREVIEW_BYTES
        suffix = "..." if len(data) > limit else ""
        return f"{data[:limit]!r}{suffix} ({len(data)} bytes)"

    def _add_named_data_to_map(
        self,
        key: str,
        data: bytes,
        alignment: int,
        local_key_to_buffer_idx: Dict[str, int],
    ) -> None:
        """
        Add data to a map and update the alignment. Ensure that the key-data
        pair is unique.
        - If the key exists, the data must be identical and the key must
            already live in `local_key_to_buffer_idx`.
        - If multiple unique keys exist for the same data, those keys should
            point to the same buffer.

        All validation happens before any state is modified, so a raised
        ValueError leaves the store and the passed-in map untouched.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (int): alignment for bytes to be serialized with.
            local_key_to_buffer_idx (Dict[str, int]): map to add the data to.
        Raises:
            ValueError: when the key exists in the store and the corresponding
                data is different, or when the key exists in the store but was
                added to a different map.
        """
        hashed = hashlib.sha256(data).digest()

        buffer_idx = self.key_to_buffer_idx.get(key)
        if buffer_idx is not None:
            # If the key exists, the corresponding data must be identical.
            if self.data_hash_to_buffer_idx.get(hashed) != buffer_idx:
                existing = self.buffers[buffer_idx].buffer
                raise ValueError(
                    f"Duplicate key {key} with different data. "
                    f"Existing data: {self._preview(existing)}. "
                    f"New data: {self._preview(data)}."
                )
            # Every known key lives in exactly one map. If it is not in the
            # target map it was stored elsewhere, and adding it here would
            # make the key ambiguous.
            if key not in local_key_to_buffer_idx:
                raise ValueError(
                    f"Duplicate key {key} is already stored in a different "
                    "location (PTE or another external file)."
                )
        else:
            # Key doesn't exist; reuse the buffer if the data is already known.
            buffer_idx = self.data_hash_to_buffer_idx.get(hashed)
            if buffer_idx is None:
                buffer_idx = len(self.buffers)
                self.buffers.append(BufferEntry(data, alignment))
                self.data_hash_to_buffer_idx[hashed] = buffer_idx

        # Merge the alignment. For a freshly appended buffer this is a no-op
        # because lcm(a, a) == a.
        entry = self.buffers[buffer_idx]
        entry.alignment = math.lcm(entry.alignment, alignment)

        # Add key to the map and the key cache.
        local_key_to_buffer_idx[key] = buffer_idx
        self.key_to_buffer_idx[key] = buffer_idx

    def add_named_data(
        self,
        key: str,
        data: bytes,
        alignment: Optional[int] = 1,
        external_tag: Optional[str] = None,
    ) -> None:
        """
        Adds a named blob to the NamedDataStore.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (Optional[int]): alignment for bytes to be serialized
                with. None is treated as 1.
            external_tag (Optional[str]): the external filename that this data
                is saved to. None stores the data inside the PTE file.
        Raises:
            ValueError: when the alignment is not positive, when the key
                exists in the store with different data, or when the key
                exists in the store under a different location.
        """
        # Set default alignment.
        if alignment is None:
            alignment = 1
        if alignment <= 0:
            raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

        if external_tag is None:
            self._add_named_data_to_map(key, data, alignment, self.pte_data)
            return

        # Only register a new per-file map once the add has succeeded, so a
        # rejected add never leaves an empty map behind.
        file_map = self.external_data.get(external_tag)
        is_new_file = file_map is None
        if file_map is None:
            file_map = {}
        self._add_named_data_to_map(key, data, alignment, file_map)
        if is_new_file:
            self.external_data[external_tag] = file_map

    def get_named_data_store_output(self) -> NamedDataStoreOutput:
        """
        Package the current contents of the store for serialization.

        Empty per-file maps are dropped from `self.external_data` first. The
        returned object references the store's own containers rather than
        copies of them.

        Returns:
            NamedDataStoreOutput built from the buffers, PTE map and external
            maps.
        """
        # Clean up empty maps inside self.external_data
        self.external_data = {k: v for k, v in self.external_data.items() if v}
        return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
16 changes: 13 additions & 3 deletions exir/_serialize/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
oncall("executorch")

python_unittest(
name = "program",
name = "test_program",
srcs = [
"test_program.py",
],
Expand All @@ -15,7 +15,7 @@ python_unittest(
)

python_unittest(
name = "flatbuffer",
name = "test_flatbuffer",
srcs = [
"test_flatbuffer.py",
],
Expand All @@ -25,11 +25,21 @@ python_unittest(
)

python_unittest(
name = "cord",
name = "test_cord",
srcs = [
"test_cord.py",
],
deps = [
"//executorch/exir/_serialize:lib",
],
)

# Unit-test target that runs test_named_data_store.py against the
# //executorch/exir/_serialize:lib dependency.
python_unittest(
name = "test_named_data_store",
srcs = [
"test_named_data_store.py",
],
deps = [
"//executorch/exir/_serialize:lib",
],
)
85 changes: 85 additions & 0 deletions exir/_serialize/test/test_named_data_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore


class TestNamedDataStore(unittest.TestCase):
    """Exercises insertion, deduplication and key conflicts in NamedDataStore."""

    def test_add(self) -> None:
        """Distinct keys with distinct data each receive their own buffer."""
        nds = NamedDataStore()
        nds.add_named_data("key1", b"data1", None, None)
        nds.add_named_data("key2", b"data2", 16, "file1")
        nds.add_named_data("key3", b"data3", 16, "file1")

        result = nds.get_named_data_store_output()

        # Buffers are stored in insertion order; a None alignment becomes 1.
        self.assertEqual(
            result.buffers,
            [
                BufferEntry(b"data1", 1),
                BufferEntry(b"data2", 16),
                BufferEntry(b"data3", 16),
            ],
        )
        # Untagged data lands in the PTE map, tagged data under its file.
        self.assertEqual(result.pte_data, {"key1": 0})
        self.assertEqual(result.external_data, {"file1": {"key2": 1, "key3": 2}})

    def test_add_duplicate_name_and_data(self) -> None:
        """Adding the same key/data pair twice keeps a single entry."""
        nds = NamedDataStore()
        for _ in range(2):
            nds.add_named_data("key", b"data", None, None)

        result = nds.get_named_data_store_output()

        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})

    def test_add_same_data_with_different_alignment(self) -> None:
        """Two keys sharing one blob merge their alignments via the LCM."""
        nds = NamedDataStore()
        nds.add_named_data("key", b"data", 3, None)
        nds.add_named_data("key1", b"data", 4, None)

        result = nds.get_named_data_store_output()

        # Check that we take the LCM of the two alignments (3, 4) = 12
        self.assertEqual(result.buffers, [BufferEntry(b"data", 12)])
        self.assertEqual(result.pte_data, {"key": 0, "key1": 0})
        self.assertEqual(result.external_data, {})

    def test_add_duplicate_key_fail(self) -> None:
        """Re-using a key with different data is rejected and changes nothing."""
        nds = NamedDataStore()
        nds.add_named_data("key", b"data", None, None)

        # Cannot add item with the same key and different data.
        with self.assertRaises(ValueError):
            nds.add_named_data("key", b"data1", None, None)
        with self.assertRaises(ValueError):
            nds.add_named_data("key", b"data1", 16, "file1")

        result = nds.get_named_data_store_output()

        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})
Loading