9
9
import tempfile
10
10
11
11
from dataclasses import dataclass , fields , is_dataclass
12
- from typing import ClassVar , Literal
12
+ from typing import ClassVar , List , Literal , Tuple
13
13
14
14
import pkg_resources
15
- from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import XNNGraph
15
+ from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import (
16
+ Buffer ,
17
+ ConstantDataOffset ,
18
+ XNNGraph ,
19
+ )
16
20
from executorch .exir ._serialize ._dataclass import _DataclassEncoder
17
21
18
22
from executorch .exir ._serialize ._flatbuffer import _flatc_compile
@@ -148,6 +152,72 @@ def to_bytes(self) -> bytes:
148
152
return data
149
153
150
154
155
+ def _padding_required (offset : int , alignment : int ) -> int :
156
+ """Returns the padding required to align `offset` to `alignment`."""
157
+ remainder : int = offset % alignment
158
+ if remainder != 0 :
159
+ return alignment - remainder
160
+ return 0
161
+
162
+
163
def _aligned_size(input_size: int, alignment: int) -> int:
    """Returns `input_size` rounded up to a whole multiple of `alignment`.

    Args:
        input_size: The size to round up.
        alignment: The alignment boundary.
    Returns:
        `input_size` itself when it is already a multiple of `alignment`,
        otherwise the next multiple above it.
    """
    padding = _padding_required(input_size, alignment)
    return input_size + padding
166
+
167
+
168
+ def _pad_to (data : bytes , length : int ) -> bytes :
169
+ """Returns the input followed by enough zero bytes to become the requested length.
170
+
171
+ Args:
172
+ data: The data to pad.
173
+ length: The length of the returned data.
174
+ Returns:
175
+ The padded data.
176
+ Raises:
177
+ ValueError: If the requested length is less than the input length.
178
+ """
179
+ if length < len (data ):
180
+ raise ValueError (f"Data length { len (data )} > padded length { length } " )
181
+ if length > len (data ):
182
+ data = data + b"\x00 " * (length - len (data ))
183
+ assert len (data ) == length
184
+ return data
185
+
186
+
187
def extract_constant_data(
    constant_buffer: List[Buffer],
    tensor_alignment: int,
) -> Tuple[bytes, List[int]]:
    """Copies the tensors from the provided list into a single buffer and tracks
    the offset of each tensor.

    Args:
        constant_buffer: List of Buffers to extract constants from. Not
            modified; only each buffer's `storage` is read.
        tensor_alignment: Alignment in bytes. Every tensor except the last is
            followed by zero padding up to the next multiple of this value, so
            the starting offset of each tensor in the constant segment is
            aligned to it.

    Returns:
        A tuple of (constant segment, list of offsets for each tensor in the
        segment). The last tensor is not padded, so the segment length is not
        necessarily a multiple of `tensor_alignment`.
    """
    constant_segment_data: bytearray = bytearray()
    constant_segment_offsets: List[int] = []
    last_index = len(constant_buffer) - 1
    for index, buffer in enumerate(constant_buffer):
        storage = buffer.storage

        # Each tensor starts at the current end of the segment. That position
        # is aligned because every earlier tensor was padded out.
        constant_segment_offsets.append(len(constant_segment_data))
        constant_segment_data += storage

        # Add padding for all but the last tensor; nothing follows the last
        # one inside the segment. The pad unit must be a single zero byte so
        # the recorded offsets match the real positions.
        if index < last_index:
            pad_length = _padding_required(len(storage), tensor_alignment)
            constant_segment_data += b"\x00" * pad_length
    return bytes(constant_segment_data), constant_segment_offsets
219
+
220
+
151
221
def convert_to_flatbuffer (xnnpack_graph : XNNGraph ) -> bytes :
152
222
sanity_check_xnngraph_dataclass (xnnpack_graph )
153
223
xnnpack_graph_json = json .dumps (xnnpack_graph , cls = _DataclassEncoder )
@@ -163,3 +233,67 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
163
233
output_path = os .path .join (d , "schema.bin" )
164
234
with open (output_path , "rb" ) as output_file :
165
235
return output_file .read ()
236
+
237
+
238
def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes:
    """Returns the runtime binary representation of the given XNNGraph.

    The output is the concatenation of three sections:
      1. the XNNHeader, zero-padded to a multiple of 16 bytes,
      2. the flatbuffer-serialized graph, zero-padded to a multiple of 16 bytes,
      3. the constant data segment (not padded, since nothing follows it).

    Note that `xnnpack_graph` is modified in place: one ConstantDataOffset per
    constant buffer is appended to `constant_data`, and `constant_buffer` and
    `mem_buffer_sizes` are emptied, because the constants are carried by the
    constant data segment instead of the flatbuffer.

    Args:
        xnnpack_graph: XNNGraph object to serialize.

    Returns:
        The serialized form of the XNNGraph, ready for execution by XNNPACK Backend

    Raises:
        ValueError: If `constant_buffer` and `mem_buffer_sizes` do not have the
            same number of entries.
    """
    constant_tensor_alignment = 16

    # Extract constant data from the graph
    constant_data, constant_data_offsets = extract_constant_data(
        xnnpack_graph.constant_buffer, constant_tensor_alignment
    )

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if len(constant_data_offsets) != len(xnnpack_graph.mem_buffer_sizes):
        raise ValueError(
            f"Got {len(constant_data_offsets)} constant buffers but "
            f"{len(xnnpack_graph.mem_buffer_sizes)} mem_buffer_sizes entries"
        )

    # Record where each constant lives in the segment, and its size.
    for constant_data_offset, constant_data_size in zip(
        constant_data_offsets, xnnpack_graph.mem_buffer_sizes
    ):
        xnnpack_graph.constant_data.append(
            ConstantDataOffset(constant_data_offset, constant_data_size)
        )

    # We are moving all constant data from the graph to the constant data
    # section, so every constant buffer (and its size entry) is removed from
    # the graph before it is converted to a flatbuffer.
    xnnpack_graph.constant_buffer = []
    xnnpack_graph.mem_buffer_sizes = []

    # Convert the XNNGraph to a flatbuffer
    flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph)

    # size of flatbuffer data, padded to be 16 byte aligned
    padded_flatbuffer_length: int = _aligned_size(
        input_size=len(flatbuffer_payload),
        alignment=constant_tensor_alignment,
    )
    # size of header to insert, padded to be 16 byte aligned
    padded_header_length: int = _aligned_size(
        input_size=XNNHeader.EXPECTED_LENGTH,
        alignment=constant_tensor_alignment,
    )

    # Create the XNNPACK Header. Offsets are relative to the start of the
    # returned bytes; sizes are the unpadded lengths of each section.
    header: bytes = XNNHeader(
        flatbuffer_offset=padded_header_length,
        flatbuffer_size=len(flatbuffer_payload),
        constant_data_offset=padded_header_length + padded_flatbuffer_length,
        constant_data_size=len(constant_data),
    ).to_bytes()

    # Concatenate the header, flatbuffer data, and constant data
    # Constant data does not need to be padded to alignment because nothing follows it
    return b"".join(
        [
            _pad_to(header, padded_header_length),
            _pad_to(flatbuffer_payload, padded_flatbuffer_length),
            constant_data,
        ]
    )
0 commit comments