Skip to content

Commit 77defc6

Browse files
mcr229 authored and facebook-github-bot committed
Handle XNNHeader in XNNPACK Runtime (#1543)
Summary: We introduce XNNHeader on the runtime side to handle the XNNHeader that was newly introduced ahead of time. XNNHeader manages the offsets and sizes of the flatbuffer payload and the constant data payload so that they are accessible by the XNNCompiler. It is important to note that on the serialization side, we have not yet switched our serialization method to `serialize_xnnpack_binary`, so this does not yet use the new serialization format. However, passing tests on this illustrates BC, as old models will still be able to run on this new runtime. Passing tests here also show that the header magic correctly discerns between the XNNHeader and the Flatbuffer header. Reviewed By: digantdesai. Differential Revision: D52556131
1 parent c14ab50 commit 77defc6

File tree

3 files changed

+256
-8
lines changed

3 files changed

+256
-8
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/backends/xnnpack/runtime/XNNCompiler.h>
10+
#include <executorch/backends/xnnpack/runtime/XNNHeader.h>
1011
#include <executorch/backends/xnnpack/schema_generated.h>
1112
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
1213
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -103,6 +104,34 @@ std::vector<T> flatbufferDimsToVector(
103104
return dims_data;
104105
}
105106

107+
/**
108+
Gets the constant data pointer associated with the given tensor value.
109+
Obtaining the constant data pointer can either be from within the flatbuffer
110+
payload (deprecated) or via offsets to the constant_data_ptr. If no constant
111+
data associated with the tensor value, then returns nullptr.
112+
*/
113+
const uint8_t* getConstantDataPtr(
114+
const fb_xnnpack::XNNTensorValue* tensor_value,
115+
GraphPtr flatbuffer_graph,
116+
const uint8_t* constant_data_ptr) {
117+
auto buffer_idx = tensor_value->constant_buffer_idx();
118+
if (buffer_idx) {
119+
if (!constant_data_ptr) {
120+
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
121+
// window
122+
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
123+
return constant_buffer[buffer_idx]->storage()->data();
124+
} else {
125+
const auto& constant_data_offsets = *flatbuffer_graph->constant_data();
126+
uint64_t constant_data_offset =
127+
constant_data_offsets[buffer_idx]->offset();
128+
return constant_data_ptr + constant_data_offset;
129+
}
130+
}
131+
132+
return nullptr;
133+
}
134+
106135
/**
107136
Define serialized tensor value into
108137
the subgraph. While also keeping track of the remapped ids from
@@ -113,6 +142,7 @@ Error defineTensor(
113142
std::unordered_map<uint32_t, uint32_t>& remapped_ids,
114143
ValuePtr value,
115144
GraphPtr flatbuffer_graph,
145+
const uint8_t* constant_data_ptr,
116146
XNNExecutor* executor,
117147
MemoryAllocator* runtime_allocator) {
118148
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
@@ -151,11 +181,9 @@ Error defineTensor(
151181

152182
// Get Pointer to constant data from flatbuffer, if its non-constant
153183
// it is a nullptr
154-
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
155-
auto buffer_idx = tensor_value->constant_buffer_idx();
156-
const auto buffer_ptr = buffer_idx == 0
157-
? nullptr
158-
: constant_buffer[buffer_idx]->storage()->data();
184+
const uint8_t* buffer_ptr =
185+
getConstantDataPtr(tensor_value, flatbuffer_graph, constant_data_ptr);
186+
159187
xnn_status status;
160188
// The type we might have to convert to
161189
auto dq_datatype = getDataType(tensor_value->dq_datatype());
@@ -1429,14 +1457,31 @@ __ET_NODISCARD Error XNNCompiler::compileModel(
14291457
size_t num_bytes,
14301458
XNNExecutor* executor,
14311459
MemoryAllocator* runtime_allocator) {
1460+
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
1461+
const uint8_t* flatbuffer_data = nullptr;
1462+
const uint8_t* constant_data = nullptr;
1463+
1464+
// Header status can only either be Error::Ok or Error::NotFound
1465+
if (header.ok()) {
1466+
flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1467+
header->flatbuffer_offset;
1468+
constant_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1469+
header->constant_data_offset;
1470+
} else if (header.error() == Error::NotFound) {
1471+
flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer);
1472+
} else {
1473+
ET_LOG(Error, "XNNHeader may be corrupt");
1474+
return header.error();
1475+
}
1476+
14321477
ET_CHECK_OR_RETURN_ERROR(
1433-
fb_xnnpack::XNNGraphBufferHasIdentifier(buffer_pointer),
1478+
fb_xnnpack::XNNGraphBufferHasIdentifier(flatbuffer_data),
14341479
DelegateInvalidCompatibility,
14351480
"XNNPACK Delegate Serialization Format version identifier '%.4s' != expected '%.4s'",
1436-
flatbuffers::GetBufferIdentifier(buffer_pointer),
1481+
flatbuffers::GetBufferIdentifier(flatbuffer_data),
14371482
fb_xnnpack::XNNGraphIdentifier());
14381483

1439-
auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer);
1484+
auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data);
14401485
// initialize xnnpack
14411486
xnn_status status = xnn_initialize(/*allocator =*/nullptr);
14421487
ET_CHECK_OR_RETURN_ERROR(
@@ -1476,6 +1521,7 @@ __ET_NODISCARD Error XNNCompiler::compileModel(
14761521
remapped_ids,
14771522
value,
14781523
flatbuffer_graph,
1524+
constant_data,
14791525
executor,
14801526
runtime_allocator);
14811527

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/xnnpack/runtime/XNNHeader.h>
10+
11+
#include <cstring>
12+
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/result.h>
15+
16+
#pragma clang diagnostic ignored "-Wdeprecated"
17+
18+
namespace torch {
19+
namespace executor {
20+
namespace xnnpack {
21+
namespace delegate {
22+
23+
namespace {
24+
/// Interprets the 8 bytes at `data` as a little-endian uint64_t.
///
/// The bytes are combined arithmetically rather than through a pointer cast,
/// so the result does not depend on host endianness and `data` does not need
/// to be aligned.
///
/// @param data pointer to at least 8 readable bytes.
/// @returns the decoded value, with data[0] as the least-significant byte.
uint64_t GetUInt64LE(const uint8_t* data) {
  uint64_t value = 0;
  for (int i = 0; i < 8; ++i) {
    // Widen before shifting so the high bytes are not lost to int promotion.
    value |= static_cast<uint64_t>(data[i]) << (8 * i);
  }
  return value;
}
31+
32+
/// Interprets the 4 bytes at `data` as a little-endian uint32_t.
///
/// The bytes are combined arithmetically rather than through a pointer cast,
/// so the result does not depend on host endianness and `data` does not need
/// to be aligned.
///
/// @param data pointer to at least 4 readable bytes.
/// @returns the decoded value, with data[0] as the least-significant byte.
uint32_t GetUInt32LE(const uint8_t* data) {
  uint32_t value = 0;
  for (int i = 0; i < 4; ++i) {
    // Widen to unsigned 32-bit before shifting; shifting a promoted (signed)
    // int by 24 could overflow into the sign bit.
    value |= static_cast<uint32_t>(data[i]) << (8 * i);
  }
  return value;
}
37+
38+
} // namespace
39+
40+
/**
 * Looks for an XNNHeader at the start of `data` and decodes its fields.
 *
 * @param[in] data Head of the serialized XNNPACK payload, starting at byte 0.
 * @param[in] size Number of readable bytes at `data`.
 *
 * @returns
 *   - Error::InvalidArgument if `data` is null or `size` is smaller than
 *     XNNHeader::kMinSize;
 *   - Error::NotFound if the bytes at kMagicOffset do not match kMagic (the
 *     payload does not start with an XNNHeader);
 *   - otherwise an XNNHeader populated from the little-endian fields.
 *
 * Note: the decoded offsets and sizes are returned as-is; they are not checked
 * against `size` here, and the header-length field is not read.
 */
Result<XNNHeader> XNNHeader::Parse(const void* data, size_t size) {
  // Reject a null buffer up front instead of dereferencing it below.
  if (data == nullptr || size < XNNHeader::kMinSize) {
    return Error::InvalidArgument;
  }

  const uint8_t* header_data = static_cast<const uint8_t*>(data);

  // A magic mismatch is not corruption: it means there is no XNNHeader.
  const uint8_t* magic_start = header_data + XNNHeader::kMagicOffset;
  if (std::memcmp(magic_start, XNNHeader::kMagic, XNNHeader::kMagicSize) != 0) {
    return Error::NotFound;
  }

  const uint32_t flatbuffer_offset =
      GetUInt32LE(header_data + XNNHeader::kFlatbufferDataOffsetOffset);

  const uint32_t flatbuffer_size =
      GetUInt32LE(header_data + XNNHeader::kFlatbufferDataSizeOffset);

  const uint32_t constant_data_offset =
      GetUInt32LE(header_data + XNNHeader::kConstantDataOffsetOffset);

  const uint64_t constant_data_size =
      GetUInt64LE(header_data + XNNHeader::kConstantDataSizeOffset);

  return XNNHeader{
      flatbuffer_offset,
      flatbuffer_size,
      constant_data_offset,
      constant_data_size};
}
70+
71+
// Define storage for the static.
72+
// kMagic is declared `static constexpr` inside XNNHeader and is passed by
// address to std::memcmp in Parse, so it needs a namespace-scope definition
// when built as pre-C++17.
// NOTE(review): under C++17 the in-class declaration is implicitly inline and
// this redeclaration is deprecated; presumably that is why -Wdeprecated is
// silenced at the top of this file. Confirm before removing either.
constexpr char XNNHeader::kMagic[kMagicSize];
73+
74+
} // namespace delegate
75+
} // namespace xnnpack
76+
} // namespace executor
77+
} // namespace torch

backends/xnnpack/runtime/XNNHeader.h

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/result.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace xnnpack {
16+
namespace delegate {
17+
18+
/**
19+
* An extended XNNPACK-header that is embeded before the flatbuffer payload
20+
*
21+
*/
22+
struct XNNHeader {
23+
/**
24+
* The minimum size of the XNNHeader. The caller should provide at least this
25+
* many bytes of the head of the serialized XNNPACK Data
26+
*/
27+
static constexpr size_t kMinSize = 30;
28+
29+
/**
30+
* The magic offset. This offset is the same as the offset for flatbuffer
31+
* header so we will be able to check if the header is is either the
32+
* flatbuffer head or the wrapper header we introduce here
33+
*/
34+
static constexpr size_t kMagicOffset = 4;
35+
36+
/**
37+
* The magic bytes that identify the header.
38+
*
39+
* This is the canonical definition of the expected value. If the header
40+
* layout ever changes in a compatibility-breaking way, increment the digits
41+
* in the magic. But, doing so will prevent older binaries from recognizing
42+
* the presence of the header. The compatibility-preserving way to make
43+
* changes is to increase the header's length field and add new fields at the
44+
* end.
45+
*/
46+
static constexpr size_t kMagicSize = 4;
47+
static constexpr char kMagic[kMagicSize] = {'X', 'H', '0', '0'};
48+
49+
/**
50+
* The size in bytes of the header length. We store 2 bytes for the header
51+
* length
52+
*/
53+
static constexpr size_t kHeaderLengthSize = 2;
54+
55+
/**
56+
* The expected location of the header length field relative to the beginning
57+
* of the header.
58+
*/
59+
static constexpr size_t kHeaderLengthOffset =
60+
XNNHeader::kMagicOffset + XNNHeader::kMagicSize;
61+
62+
/**
63+
* The expected location of the flatbuffer data offset field relative to the
64+
* beginning of the header.
65+
*/
66+
static constexpr size_t kFlatbufferDataOffsetOffset =
67+
kHeaderLengthOffset + sizeof(uint16_t);
68+
69+
/**
70+
* The expected location of the flatbuffer data size field relative to the
71+
* beginning of the header.
72+
*/
73+
static constexpr size_t kFlatbufferDataSizeOffset =
74+
kFlatbufferDataOffsetOffset + sizeof(uint32_t);
75+
76+
/*
77+
* The expected location of the constant data offset field relative to the
78+
* beginning of the header.
79+
*/
80+
static constexpr size_t kConstantDataOffsetOffset =
81+
kFlatbufferDataSizeOffset + sizeof(uint32_t);
82+
83+
/*
84+
* The expected location of the constant data size field relative to the
85+
* beginning of the header.
86+
*/
87+
static constexpr size_t kConstantDataSizeOffset =
88+
kConstantDataOffsetOffset + sizeof(uint32_t);
89+
90+
/**
91+
* Look for and parse an ExtendedHeader in the provided data.
92+
*
93+
* @param[in] data The contents of the beginning of the serialized binary
94+
* Program data, starting at offset 0 (i.e., the head of the file).
95+
* @param[in] size Length of `data` in bytes.
96+
*
97+
* @returns an XNNHeader if the header was found and is valid. Returns an
98+
* error if size was too short, if the header was not found, or if the
99+
* header appeared to be corrupt.
100+
*/
101+
static Result<XNNHeader> Parse(const void* data, size_t size);
102+
103+
/**
104+
* The offset in bytes to the beginning of the flatbuffer data.
105+
*/
106+
uint32_t flatbuffer_offset;
107+
/**
108+
* The size in bytes of the flatbuffer data.
109+
*/
110+
uint32_t flatbuffer_size;
111+
112+
/**
113+
* The offset in bytes to the beginning of the constant data.
114+
*/
115+
uint32_t constant_data_offset;
116+
/**
117+
* The size in bytes of the constant data.
118+
*/
119+
uint64_t constant_data_size;
120+
};
121+
122+
} // namespace delegate
123+
} // namespace xnnpack
124+
} // namespace executor
125+
} // namespace torch

0 commit comments

Comments
 (0)