9
9
import tempfile
10
10
11
11
from dataclasses import dataclass , fields , is_dataclass
12
- from typing import ClassVar , Literal
12
+ from typing import ClassVar , List , Literal , Tuple
13
13
14
14
import pkg_resources
15
- from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import XNNGraph
15
+ from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import (
16
+ Buffer ,
17
+ ConstantDataOffset ,
18
+ XNNGraph ,
19
+ )
16
20
from executorch .exir ._serialize ._dataclass import _DataclassEncoder
17
21
18
22
from executorch .exir ._serialize ._flatbuffer import _flatc_compile
@@ -236,6 +240,74 @@ def to_bytes(self) -> bytes:
236
240
return data
237
241
238
242
243
+ def _padding_required (offset : int , alignment : int ) -> int :
244
+ """Returns the padding required to align `offset` to `alignment`."""
245
+ remainder : int = offset % alignment
246
+ if remainder != 0 :
247
+ return alignment - remainder
248
+ return 0
249
+
250
+
251
+ def _aligned_size (input_size : int , alignment : int ) -> int :
252
+ """Returns input_size padded up to the next whole multiple of alignment."""
253
+ aligned_size = input_size + _padding_required (input_size , alignment )
254
+ assert aligned_size % alignment == 0
255
+ return aligned_size
256
+
257
+
258
+ def _pad_to (data : bytes , length : int ) -> bytes :
259
+ """Returns the input followed by enough zero bytes to become the requested length.
260
+
261
+ Args:
262
+ data: The data to pad.
263
+ length: The length of the returned data.
264
+ Returns:
265
+ The padded data.
266
+ Raises:
267
+ ValueError: If the requested length is less than the input length.
268
+ """
269
+ if length < len (data ):
270
+ raise ValueError (f"Data length { len (data )} > padded length { length } " )
271
+ if length > len (data ):
272
+ data = data + b"\x00 " * (length - len (data ))
273
+ assert len (data ) == length
274
+ return data
275
+
276
+
277
+ def _extract_constant_data (
278
+ constant_buffer : List [Buffer ],
279
+ tensor_alignment : int = 16 ,
280
+ ) -> Tuple [bytes , List [int ]]:
281
+ """Copies the tensors from the provided list into a single buffer and tracks the offsets
282
+ of each tensor.
283
+
284
+ constant_buffer: list of Buffers from which to extract constants from. Not modified.
285
+ tensor_alignment: Alignment in bytes. The starting offset of each tensor in the
286
+ constant segment will be aligned to this value. Default to 16.
287
+
288
+ Returns:
289
+ A tuple of (constant segment, list of offsets for each tensor in the segment)
290
+ """
291
+ constant_segment_data : bytearray = bytearray ()
292
+ constant_segment_offsets : List [int ] = []
293
+ current_offset : int = 0
294
+ for i in range (len (constant_buffer )):
295
+ buffer = constant_buffer [i ]
296
+ buffer_length = len (buffer .storage )
297
+ pad_length = _padding_required (buffer_length , tensor_alignment )
298
+
299
+ # Append each constant buffer to the constant segment.
300
+ constant_segment_data += buffer .storage
301
+ # Add padding for all but the last tensor.
302
+ if i < len (constant_buffer ) - 1 :
303
+ constant_segment_data += b"\x00 " * pad_length
304
+
305
+ # Append constant data offset.
306
+ constant_segment_offsets .append (current_offset )
307
+ current_offset += buffer_length + pad_length
308
+ return bytes (constant_segment_data ), constant_segment_offsets
309
+
310
+
239
311
def convert_to_flatbuffer (xnnpack_graph : XNNGraph ) -> bytes :
240
312
sanity_check_xnngraph_dataclass (xnnpack_graph )
241
313
xnnpack_graph_json = json .dumps (xnnpack_graph , cls = _DataclassEncoder )
@@ -251,3 +323,67 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
251
323
output_path = os .path .join (d , "schema.bin" )
252
324
with open (output_path , "rb" ) as output_file :
253
325
return output_file .read ()
326
+
327
+
328
+ def serialize_xnnpack_binary (xnnpack_graph : XNNGraph ) -> bytes :
329
+ """Returns the runtime binary representation of the given XNNGraph.
330
+
331
+ Args:
332
+ xnnpack_graph: XNNGraph object to serialize.
333
+
334
+ Returns:
335
+ The serialized form of the XNNGraph, ready for execution by XNNPACK Backend
336
+ """
337
+ constant_tensor_alignment = 16
338
+
339
+ # Extract constant data from the graph
340
+ constant_data , constant_data_offsets = _extract_constant_data (
341
+ xnnpack_graph .constant_buffer , constant_tensor_alignment
342
+ )
343
+
344
+ assert len (constant_data_offsets ) == len (xnnpack_graph .mem_buffer_sizes )
345
+
346
+ for offset_idx in range (len (constant_data_offsets )):
347
+ constant_data_offset = constant_data_offsets [offset_idx ]
348
+ constant_data_size = xnnpack_graph .mem_buffer_sizes [offset_idx ]
349
+ xnnpack_graph .constant_data .append (
350
+ ConstantDataOffset (constant_data_offset , constant_data_size )
351
+ )
352
+
353
+ # We are moving all constant data from the graph to the constant data section.
354
+ # So we remove all constant buffers
355
+ xnnpack_graph .constant_buffer = []
356
+ xnnpack_graph .mem_buffer_sizes = []
357
+
358
+ # Convert the XNNGraph to a flatbuffer
359
+ flatbuffer_payload = convert_to_flatbuffer (xnnpack_graph )
360
+
361
+ # size of flatbuffer data, padded to be `constant_tensor_alignment` byte aligned
362
+ padded_flatbuffer_length : int = _aligned_size (
363
+ input_size = len (flatbuffer_payload ),
364
+ alignment = constant_tensor_alignment ,
365
+ )
366
+ # size of header to insert, padded to be `constant_tensor_alignment` byte aligned
367
+ padded_header_length : int = _aligned_size (
368
+ input_size = XNNHeader .EXPECTED_LENGTH ,
369
+ alignment = constant_tensor_alignment ,
370
+ )
371
+
372
+ # Create the XNNPACK Header
373
+ header : bytes = XNNHeader (
374
+ flatbuffer_offset = padded_header_length ,
375
+ flatbuffer_size = len (flatbuffer_payload ),
376
+ constant_data_offset = padded_header_length + padded_flatbuffer_length ,
377
+ constant_data_size = len (constant_data ),
378
+ ).to_bytes ()
379
+
380
+ # Concatenate the header, flatbuffer data, and constant data
381
+ # Constant data does not need to be padded to alignment because nothing follows it
382
+
383
+ return b"" .join (
384
+ [
385
+ _pad_to (header , padded_header_length ),
386
+ _pad_to (flatbuffer_payload , padded_flatbuffer_length ),
387
+ constant_data ,
388
+ ]
389
+ )
0 commit comments