Skip to content

Commit 55f9975

Browse files
mcr229facebook-github-bot
authored andcommitted
Serialize constant Data outside of flatbuffer
Summary: We introduce the `serialize_xnnpack_binary` method which serializees the constant data outside of the flatbuffer. It leverages the xnnheader introduced in the previous diff to store offsets and sizes for both the flatbuffer payload as well as the constant data payload. Note here we have not yet switched the delegate to use the new `serialize_xnnpack_binary` function as this new serialization also requires changes on the runtime side. This will be tested in the diff which follows. Differential Revision: D52498367
1 parent 994536f commit 55f9975

File tree

4 files changed

+162
-3
lines changed

4 files changed

+162
-3
lines changed

backends/xnnpack/serialization/schema.fbs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,17 @@ table XNNLeakyReLU {
281281
flags: uint;
282282
}
283283

284+
// Describes data offsets for constant data
285+
table ConstantDataOffset {
286+
// Constant data offsets are relative to the constant data base offset provided
287+
// in the XNNPACKHeader.
288+
offset: uint32;
289+
290+
// The size in bytes of valid data starting at the offset. The constant data
291+
// may be followed by padding before the next piece of constant data
292+
size: uint32;
293+
}
294+
284295
table XNNGraph {
285296
// Schema version.
286297
version:string;
@@ -299,11 +310,16 @@ table XNNGraph {
299310
// Tables of constant data, used for constant Values (e.g.
300311
// data field of weight tensors). Each constant is assigned an index into the table
301312
// which are each individually aligned. 0 index is reserved to be pointed to by non-constant
302-
// Tensors
313+
// Tensors. Both constant_buffer and constant_data may not both be non-empty
303314
constant_buffer:[Buffer];
304315

305316
// the list index is memory buffer id, the value is the memory buffer size.
306317
mem_buffer_sizes: [uint];
318+
319+
// List of the constant data that follows the XNNGraph in this file. Each constant data is assigned an index into
320+
// the table. 0 index is reserved to be pointed to by non-constant Tensor. Both constant_buffer and constant_data
321+
// may not both be non-empty
322+
constant_data:[ConstantDataOffset];
307323
}
308324

309325
root_type XNNGraph;

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,12 @@ class Buffer:
417417
storage: bytes
418418

419419

420+
@dataclass
421+
class ConstantDataOffset:
422+
offset: int
423+
size: int
424+
425+
420426
@dataclass
421427
class XNNGraph:
422428
version: str
@@ -429,3 +435,5 @@ class XNNGraph:
429435

430436
constant_buffer: List[Buffer]
431437
mem_buffer_sizes: List[int]
438+
439+
constant_data: List[ConstantDataOffset]

backends/xnnpack/serialization/xnnpack_graph_serialize.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
import tempfile
1010

1111
from dataclasses import dataclass, fields, is_dataclass
12-
from typing import ClassVar, Literal
12+
from typing import ClassVar, List, Literal, Tuple
1313

1414
import pkg_resources
15-
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
15+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
16+
Buffer,
17+
ConstantDataOffset,
18+
XNNGraph,
19+
)
1620
from executorch.exir._serialize._dataclass import _DataclassEncoder
1721

1822
from executorch.exir._serialize._flatbuffer import _flatc_compile
@@ -148,6 +152,72 @@ def to_bytes(self) -> bytes:
148152
return data
149153

150154

155+
def _padding_required(offset: int, alignment: int) -> int:
156+
"""Returns the padding required to align `offset` to `alignment`."""
157+
remainder: int = offset % alignment
158+
if remainder != 0:
159+
return alignment - remainder
160+
return 0
161+
162+
163+
def _aligned_size(input_size: int, alignment: int) -> int:
164+
"""Returns input_size padded up to the next whole multiple of alignment."""
165+
return input_size + _padding_required(input_size, alignment)
166+
167+
168+
def _pad_to(data: bytes, length: int) -> bytes:
169+
"""Returns the input followed by enough zero bytes to become the requested length.
170+
171+
Args:
172+
data: The data to pad.
173+
length: The length of the returned data.
174+
Returns:
175+
The padded data.
176+
Raises:
177+
ValueError: If the requested length is less than the input length.
178+
"""
179+
if length < len(data):
180+
raise ValueError(f"Data length {len(data)} > padded length {length}")
181+
if length > len(data):
182+
data = data + b"\x00" * (length - len(data))
183+
assert len(data) == length
184+
return data
185+
186+
187+
def extract_constant_data(
188+
constant_buffer: List[Buffer],
189+
tensor_alignment: int,
190+
) -> Tuple[bytes, List[int]]:
191+
"""Copies the tensors from the provided list into a single buffer and tracks the offsets
192+
of each tensor.
193+
194+
constant_buffer: list of Buffers from which to extract constants from. Not modified.
195+
tensor_alignment: Alignment in bytes. The starting offset of each tensor in the
196+
constant segment will be aligned to this value. Default to 16.
197+
198+
Returns:
199+
A tuple of (constant segment, list of offsets for each tensor in the segment)
200+
"""
201+
constant_segment_data: bytearray = bytearray()
202+
constant_segment_offsets: List[int] = []
203+
current_offset: int = 0
204+
for i in range(len(constant_buffer)):
205+
buffer = constant_buffer[i]
206+
buffer_length = len(buffer.storage)
207+
pad_length = _padding_required(buffer_length, tensor_alignment)
208+
209+
# Append each constant buffer to the constant segment.
210+
constant_segment_data += buffer.storage
211+
# Add padding for all but the last tensor.
212+
if i < len(constant_buffer) - 1:
213+
constant_segment_data += b"\x00" * pad_length
214+
215+
# Append constant data offset.
216+
constant_segment_offsets.append(current_offset)
217+
current_offset += buffer_length + pad_length
218+
return bytes(constant_segment_data), constant_segment_offsets
219+
220+
151221
def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
152222
sanity_check_xnngraph_dataclass(xnnpack_graph)
153223
xnnpack_graph_json = json.dumps(xnnpack_graph, cls=_DataclassEncoder)
@@ -163,3 +233,67 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
163233
output_path = os.path.join(d, "schema.bin")
164234
with open(output_path, "rb") as output_file:
165235
return output_file.read()
236+
237+
238+
def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes:
239+
"""Returns the runtime binary representation of the given XNNGraph.
240+
241+
Args:
242+
xnnpack_graph: XNNGraph object to serialize.
243+
244+
Returns:
245+
The serialized form of the XNNGraph, ready for execution by XNNPACK Backend
246+
"""
247+
constant_tensor_alignment = 16
248+
249+
# Extract constant data from the graph
250+
constant_data, constant_data_offsets = extract_constant_data(
251+
xnnpack_graph.constant_buffer, constant_tensor_alignment
252+
)
253+
254+
assert len(constant_data_offsets) == len(xnnpack_graph.mem_buffer_sizes)
255+
256+
for offset_idx in range(len(constant_data_offsets)):
257+
constant_data_offset = constant_data_offsets[offset_idx]
258+
constant_data_size = xnnpack_graph.mem_buffer_sizes[offset_idx]
259+
xnnpack_graph.constant_data.append(
260+
ConstantDataOffset(constant_data_offset, constant_data_size)
261+
)
262+
263+
# We are moving all constant data from the graph to the constant data section.
264+
# So we remove all constant buffers except the first one
265+
xnnpack_graph.constant_buffer = []
266+
xnnpack_graph.mem_buffer_sizes = []
267+
268+
# Convert the XNNGraph to a flatbuffer
269+
flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph)
270+
271+
# size of flatbuffer data, padded to be 16 byte aligned
272+
padded_flatbuffer_length: int = _aligned_size(
273+
input_size=len(flatbuffer_payload),
274+
alignment=constant_tensor_alignment,
275+
)
276+
# size of header to insert, padded to be 16 byte aligned
277+
padded_header_length: int = _aligned_size(
278+
input_size=XNNHeader.EXPECTED_LENGTH,
279+
alignment=constant_tensor_alignment,
280+
)
281+
282+
# Create the XNNPACK Header
283+
header: bytes = XNNHeader(
284+
flatbuffer_offset=padded_header_length,
285+
flatbuffer_size=len(flatbuffer_payload),
286+
constant_data_offset=padded_header_length + padded_flatbuffer_length,
287+
constant_data_size=len(constant_data),
288+
).to_bytes()
289+
290+
# Concatenate the header, flatbuffer data, and constant data
291+
# Constant data does not need to be padded to alignment because nothing follows it
292+
293+
return b"".join(
294+
[
295+
_pad_to(header, padded_header_length),
296+
_pad_to(flatbuffer_payload, padded_flatbuffer_length),
297+
constant_data,
298+
]
299+
)

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def preprocess(
232232
output_ids=[],
233233
constant_buffer=[Buffer(storage=b"")],
234234
mem_buffer_sizes=[0],
235+
constant_data=[],
235236
)
236237

237238
node_visitors = get_node_visitors(ep, node_to_external_map)

0 commit comments

Comments
 (0)