9
9
import tempfile
10
10
11
11
from dataclasses import dataclass , fields , is_dataclass
12
- from typing import ClassVar , Literal
12
+ from typing import ClassVar , List , Literal , Tuple
13
13
14
14
import pkg_resources
15
- from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import XNNGraph
15
+ from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import (
16
+ Buffer ,
17
+ ConstantDataOffset ,
18
+ XNNGraph ,
19
+ )
16
20
from executorch .exir ._serialize ._dataclass import _DataclassEncoder
17
21
18
22
from executorch .exir ._serialize ._flatbuffer import _flatc_compile
@@ -236,6 +240,72 @@ def to_bytes(self) -> bytes:
236
240
return data
237
241
238
242
243
+ def _padding_required (offset : int , alignment : int ) -> int :
244
+ """Returns the padding required to align `offset` to `alignment`."""
245
+ remainder : int = offset % alignment
246
+ if remainder != 0 :
247
+ return alignment - remainder
248
+ return 0
249
+
250
+
251
def _aligned_size(input_size: int, alignment: int) -> int:
    """Round `input_size` up to the nearest whole multiple of `alignment`.

    Args:
        input_size: The unpadded size in bytes.
        alignment: The boundary, in bytes, to round up to.

    Returns:
        `input_size` itself when it is already aligned, otherwise the next
        multiple of `alignment` above it.
    """
    padding = _padding_required(input_size, alignment)
    return input_size + padding
254
+
255
+
256
+ def _pad_to (data : bytes , length : int ) -> bytes :
257
+ """Returns the input followed by enough zero bytes to become the requested length.
258
+
259
+ Args:
260
+ data: The data to pad.
261
+ length: The length of the returned data.
262
+ Returns:
263
+ The padded data.
264
+ Raises:
265
+ ValueError: If the requested length is less than the input length.
266
+ """
267
+ if length < len (data ):
268
+ raise ValueError (f"Data length { len (data )} > padded length { length } " )
269
+ if length > len (data ):
270
+ data = data + b"\x00 " * (length - len (data ))
271
+ assert len (data ) == length
272
+ return data
273
+
274
+
275
def _extract_constant_data(
    constant_buffer: List[Buffer],
    tensor_alignment: int,
) -> Tuple[bytes, List[int]]:
    """Copies the tensors from the provided list into a single buffer and tracks the
    offset of each tensor.

    Args:
        constant_buffer: list of Buffers from which to extract constants. Not modified.
        tensor_alignment: Alignment in bytes. Every tensor except the last is
            followed by enough zero bytes to make its padded length a multiple
            of this value, so each tensor starts at an aligned offset.

    Returns:
        A tuple of (constant segment, list of offsets for each tensor in the segment)
    """
    constant_segment_data = bytearray()
    constant_segment_offsets: List[int] = []
    current_offset = 0
    last_index = len(constant_buffer) - 1
    for index, buffer in enumerate(constant_buffer):
        buffer_length = len(buffer.storage)
        pad_length = _padding_required(buffer_length, tensor_alignment)

        # Append each constant buffer to the constant segment.
        constant_segment_data += buffer.storage
        # Add padding for all but the last tensor; nothing follows the last
        # one, so it does not need to end on an aligned boundary. The filler
        # is one NUL byte per padding byte so that the data written stays in
        # step with the offsets recorded below.
        if index < last_index:
            constant_segment_data += b"\x00" * pad_length

        # Record where this tensor starts, then advance past it and its padding.
        constant_segment_offsets.append(current_offset)
        current_offset += buffer_length + pad_length
    return bytes(constant_segment_data), constant_segment_offsets
307
+
308
+
239
309
def convert_to_flatbuffer (xnnpack_graph : XNNGraph ) -> bytes :
240
310
sanity_check_xnngraph_dataclass (xnnpack_graph )
241
311
xnnpack_graph_json = json .dumps (xnnpack_graph , cls = _DataclassEncoder )
@@ -251,3 +321,67 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
251
321
output_path = os .path .join (d , "schema.bin" )
252
322
with open (output_path , "rb" ) as output_file :
253
323
return output_file .read ()
324
+
325
+
326
def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes:
    """Returns the runtime binary representation of the given XNNGraph.

    The result is the concatenation of three sections: an XNNHeader, the
    flatbuffer-encoded graph, and the raw constant data. The header and the
    flatbuffer sections are each zero-padded to a 16-byte boundary; the
    constant data is last and is left unpadded.

    Note that `xnnpack_graph` is modified in place: one ConstantDataOffset
    entry per constant buffer is appended to `constant_data`, and
    `constant_buffer` and `mem_buffer_sizes` are reset to empty lists.

    Args:
        xnnpack_graph: XNNGraph object to serialize.

    Returns:
        The serialized form of the XNNGraph, ready for execution by XNNPACK Backend
    """
    alignment = 16

    # Pull every constant tensor out of the graph into one contiguous segment.
    segment, offsets = _extract_constant_data(
        xnnpack_graph.constant_buffer, alignment
    )

    sizes = xnnpack_graph.mem_buffer_sizes
    assert len(offsets) == len(sizes)

    # Describe each tensor's location in the segment as an (offset, size) record.
    for idx, offset in enumerate(offsets):
        xnnpack_graph.constant_data.append(ConstantDataOffset(offset, sizes[idx]))

    # The constants now live in the constant data section, so the graph no
    # longer carries them itself.
    xnnpack_graph.constant_buffer = []
    xnnpack_graph.mem_buffer_sizes = []

    flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph)

    # Lengths of the header and flatbuffer sections once padded to `alignment`.
    header_span: int = _aligned_size(
        input_size=XNNHeader.EXPECTED_LENGTH,
        alignment=alignment,
    )
    flatbuffer_span: int = _aligned_size(
        input_size=len(flatbuffer_payload),
        alignment=alignment,
    )

    # The header records where the other two sections start and how long
    # their unpadded contents are.
    header: bytes = XNNHeader(
        flatbuffer_offset=header_span,
        flatbuffer_size=len(flatbuffer_payload),
        constant_data_offset=header_span + flatbuffer_span,
        constant_data_size=len(segment),
    ).to_bytes()

    # Nothing follows the constant data, so it needs no trailing padding.
    sections = (
        _pad_to(header, header_span),
        _pad_to(flatbuffer_payload, flatbuffer_span),
        segment,
    )
    return b"".join(sections)
0 commit comments