# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import hashlib
import math
from dataclasses import dataclass

from typing import Dict, List, Optional


@dataclass
class BufferEntry:
    """A class to hold a single buffer entry for serialization.

    Attributes:
        buffer: The buffer bytes.
        alignment: The alignment of the buffer.
    """

    buffer: bytes
    alignment: int


@dataclass
class NamedDataStoreOutput:
    """
    Holds named data for serialization.

    Attributes:
        buffers: A list of unique buffer entries.
        pte_data: Contains data that is stored inside the PTE file. A mapping from
            {key: buffer_index}.
        external_data: Contains data that is stored external to the PTE. A mapping
            from {filename: {key: buffer_index}}.
    """

    buffers: List[BufferEntry]
    pte_data: Dict[str, int]
    external_data: Dict[str, Dict[str, int]]
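    # For illustration only (key names, filename, and indices are made up):
    # pte_data might look like {"linear.weight": 0} and external_data like
    # {"model.ptd": {"linear.bias": 1}}, where each index points into `buffers`.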

class NamedDataStore:
    """
    NamedDataStore manages the data that delegates want to share. Backends add
    bytes to the store under a unique key. These bytes can be retrieved at
    runtime using the same key with the NamedDataMap.

    Note:
    - Keys are unique in the data store, regardless of whether they are stored
        in the PTE or externally.
    - Multiple keys can point to the same buffer entry.
    - The same data can be added multiple times and all keys will point to one
        buffer. If a duplicate blob is added with a different alignment, the
        lcm of the current and new alignment is taken for that blob.
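
    Example (an illustrative sketch; the keys and bytes below are made up):
        >>> store = NamedDataStore()
        >>> store.add_named_data("w", b"0123" * 4, alignment=4)
        >>> store.add_named_data("w_copy", b"0123" * 4, alignment=6)
        >>> out = store.get_named_data_store_output()
        >>> out.pte_data
        {'w': 0, 'w_copy': 0}
        >>> out.buffers[0].alignment
        12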
| 61 | + """ |
| 62 | + |
| 63 | + # List of unique blobs. |
| 64 | + buffers: List[BufferEntry] |
| 65 | + # Named data stored inside the PTE file. Map of {key: buffer_index}. |
| 66 | + pte_data: Dict[str, int] |
| 67 | + # Named data stored outside of the PTE file. |
| 68 | + # Map of {filename: {key: buffer_index}}. |
| 69 | + external_data: Dict[str, Dict[str, int]] |
| 70 | + |
| 71 | + # Cache of the data hash for deduplication. |
| 72 | + # Use a hash instead of the data as a key because a sha256 collision is |
| 73 | + # unlikely, and the data may be large. |
| 74 | + data_hash_to_buffer_idx: Dict[bytes, int] |
| 75 | + # Cache of the key to buffer idx to ensure uniqueness. |
| 76 | + # If a key is added multiple times, check the buffer idx to ensure that the |
| 77 | + # data is identical too. |
| 78 | + key_to_buffer_idx: Dict[str, int] |
| 79 | + |
    def __init__(self) -> None:
        """
        Initializes a new NamedDataStore.
        """
        self.buffers = []
        self.pte_data = {}
        self.external_data = {}

        self.data_hash_to_buffer_idx = {}
        self.key_to_buffer_idx = {}

    def _add_named_data_to_map(
        self,
        key: str,
        data: bytes,
        alignment: int,
        local_key_to_buffer_idx: Dict[str, int],
    ) -> None:
        """
        Add data to a map and update the alignment. Ensure that the key-data
        pair is unique.
        - If the key exists, the data must be identical.
        - If multiple unique keys exist for the same data, those keys should
            point to the same buffer.

        Args:
            key (str): key associated with the data.
            data (bytes): the bytes to be serialized.
            alignment (int): alignment for bytes to be serialized with.
            local_key_to_buffer_idx (Dict[str, int]): map to add the data to.
        Raises:
            ValueError: if the key already exists in the store and the
                corresponding data is different.
        """
        # Get data hash.
        hashed = hashlib.sha256(data).digest()

        # Check if the key exists.
        buffer_idx = self.key_to_buffer_idx.get(key, -1)
        if buffer_idx != -1:
            # If the key exists, the corresponding data must be identical.
            if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
                raise ValueError(
                    f"Duplicate key {key} with different data. "
                    f"Existing data: {self.buffers[buffer_idx].buffer}. "
                    f"New data: {data}."
                )
            self.buffers[buffer_idx].alignment = math.lcm(
                self.buffers[buffer_idx].alignment, alignment
            )
        else:
            # Key doesn't exist; check if the data exists.
            buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
            if buffer_idx != -1:
                # The data exists; update the alignment.
                self.buffers[buffer_idx].alignment = math.lcm(
                    self.buffers[buffer_idx].alignment, alignment
                )
            else:
                # The data doesn't exist; add it to the data store.
                buffer_idx = len(self.buffers)
                self.buffers.append(BufferEntry(data, alignment))
                self.data_hash_to_buffer_idx[hashed] = buffer_idx

        # Add key to the map and the key cache.
        local_key_to_buffer_idx[key] = buffer_idx
        self.key_to_buffer_idx[key] = buffer_idx

    def add_named_data(
        self,
        key: str,
        data: bytes,
        alignment: Optional[int] = 1,
        external_tag: Optional[str] = None,
    ) -> None:
        """
        Adds a named blob to the NamedDataStore.
        Args:
            key (str): key associated with the data.
            data (bytes): the bytes to be serialized.
            alignment (Optional[int]): alignment for bytes to be serialized with.
                Defaults to 1; None is treated as 1.
            external_tag (Optional[str]): the external filename that this data is
                saved to. If None, the data is stored inside the PTE file.
        Raises:
            ValueError: if the key already exists in the store and the
                corresponding data is different, or if the alignment is not
                positive.
        """

        # Set default alignment.
        if alignment is None:
            alignment = 1
        if alignment <= 0:
            raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

        if external_tag is None:
            self._add_named_data_to_map(key, data, alignment, self.pte_data)
        else:
            self._add_named_data_to_map(
                key, data, alignment, self.external_data.setdefault(external_tag, {})
            )

    def get_named_data_store_output(self) -> NamedDataStoreOutput:
        # Clean up empty maps inside self.external_data.
        self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
        return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
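

# Illustrative usage sketch (not part of the original module): shows how
# identical blobs are deduplicated with lcm alignment promotion, and how data
# tagged with an external filename is grouped separately from PTE data. The
# key names, filename, and byte values below are made up.
if __name__ == "__main__":
    store = NamedDataStore()

    # Identical bytes under two keys share one buffer; alignments 4 and 6 are
    # promoted to lcm(4, 6) == 12 on that shared buffer.
    store.add_named_data("linear.weight", b"\x01" * 16, alignment=4)
    store.add_named_data("linear.weight_copy", b"\x01" * 16, alignment=6)

    # Data tagged with an external filename lands in external_data, keyed by
    # that filename, rather than in pte_data.
    store.add_named_data("linear.bias", b"\x02" * 4, external_tag="model.ptd")

    out = store.get_named_data_store_output()
    assert len(out.buffers) == 2
    assert out.buffers[0].alignment == 12
    assert out.pte_data == {"linear.weight": 0, "linear.weight_copy": 0}
    assert out.external_data == {"model.ptd": {"linear.bias": 1}}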