10
10
#include < executorch/backends/xnnpack/runtime/XNNHeader.h>
11
11
#include < executorch/backends/xnnpack/serialization/schema_generated.h>
12
12
#include < executorch/extension/threadpool/threadpool.h>
13
- #include < executorch/runtime/core/exec_aten/util/scalar_type_util .h>
13
+ #include < executorch/runtime/executor/pte_data_map .h>
14
14
#include < unordered_map>
15
15
16
16
#pragma clang diagnostic ignored "-Wmissing-prototypes"
@@ -22,7 +22,9 @@ namespace xnnpack {
22
22
namespace delegate {
23
23
24
24
using executorch::runtime::Error;
25
+ using executorch::runtime::FreeableBuffer;
25
26
using executorch::runtime::MemoryAllocator;
27
+ using executorch::runtime::NamedDataMap;
26
28
using executorch::runtime::Result;
27
29
28
30
/*
@@ -48,6 +50,7 @@ class CompileAllocator {
48
50
using ValuePtr = const fb_xnnpack::XValue*;
49
51
using NodePtr = const fb_xnnpack::XNode*;
50
52
using GraphPtr = const fb_xnnpack::XNNGraph*;
53
+ using ConstantDataOffsetPtr = const fb_xnnpack::ConstantDataOffset*;
51
54
using DataType = fb_xnnpack::XNNDatatype;
52
55
53
56
// Type for define node function. This is the function signature
@@ -162,7 +165,9 @@ data associated with the tensor value, then returns nullptr.
162
165
const uint8_t * getConstantDataPtr (
163
166
const fb_xnnpack::XNNTensorValue* tensor_value,
164
167
GraphPtr flatbuffer_graph,
165
- const uint8_t * constant_data_ptr) {
168
+ const uint8_t * constant_data_ptr,
169
+ const NamedDataMap* named_data_map,
170
+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
166
171
auto buffer_idx = tensor_value->constant_buffer_idx ();
167
172
if (buffer_idx) {
168
173
if (!constant_data_ptr) {
@@ -171,10 +176,31 @@ const uint8_t* getConstantDataPtr(
171
176
const auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
172
177
return constant_buffer[buffer_idx]->storage ()->data ();
173
178
} else {
174
- const auto & constant_data_offsets = *flatbuffer_graph->constant_data ();
175
- uint64_t constant_data_offset =
176
- constant_data_offsets[buffer_idx]->offset ();
177
- return constant_data_ptr + constant_data_offset;
179
+ ConstantDataOffsetPtr constant_data_offset =
180
+ flatbuffer_graph->constant_data ()->Get (buffer_idx);
181
+ uint64_t offset = constant_data_offset->offset ();
182
+
183
+ bool has_named_key = flatbuffers::IsFieldPresent (
184
+ constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
185
+ // No named key is present: the data is located at `offset` from constant_data_ptr
186
+ if (!has_named_key) {
187
+ return constant_data_ptr + offset;
188
+ } else {
189
+ const std::string& data_name = constant_data_offset->named_key ()->str ();
190
+ Result<FreeableBuffer> buffer =
191
+ named_data_map->get_data (data_name.c_str ());
192
+ if (!buffer.ok ()) {
193
+ ET_LOG (
194
+ Error,
195
+ " Failed to get constant data for key %s" ,
196
+ data_name.c_str ());
197
+ return nullptr ;
198
+ }
199
+ const uint8_t * data_ptr =
200
+ static_cast <const uint8_t *>(buffer.get ().data ());
201
+ loaded_buffers_from_map.push_back (std::move (buffer.get ()));
202
+ return data_ptr;
203
+ }
178
204
}
179
205
}
180
206
@@ -194,7 +220,9 @@ Error defineTensor(
194
220
const uint8_t * constant_data_ptr,
195
221
std::vector<uint32_t >& input_ids,
196
222
std::vector<uint32_t >& output_ids,
197
- CompileAllocator& allocator) {
223
+ CompileAllocator& allocator,
224
+ const NamedDataMap* named_data_map,
225
+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
198
226
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
199
227
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
200
228
@@ -231,8 +259,12 @@ Error defineTensor(
231
259
232
260
// Get pointer to constant data from flatbuffer; if it's non-constant,
233
261
// it is a nullptr
234
- const uint8_t * buffer_ptr =
235
- getConstantDataPtr (tensor_value, flatbuffer_graph, constant_data_ptr);
262
+ const uint8_t * buffer_ptr = getConstantDataPtr (
263
+ tensor_value,
264
+ flatbuffer_graph,
265
+ constant_data_ptr,
266
+ named_data_map,
267
+ loaded_buffers_from_map);
236
268
237
269
xnn_status status;
238
270
// The type we might have to convert to
@@ -1968,6 +2000,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
1968
2000
size_t num_bytes,
1969
2001
XNNExecutor* executor,
1970
2002
MemoryAllocator* runtime_allocator,
2003
+ const NamedDataMap* named_data_map,
1971
2004
xnn_workspace_t workspace) {
1972
2005
Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
1973
2006
const uint8_t * flatbuffer_data = nullptr ;
@@ -2036,6 +2069,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
2036
2069
std::vector<uint32_t > input_ids;
2037
2070
std::vector<uint32_t > output_ids;
2038
2071
Error err = Error::Ok;
2072
+ std::vector<FreeableBuffer> loaded_buffers_from_map;
2039
2073
for (auto value : *flatbuffer_graph->xvalues ()) {
2040
2074
err = defineTensor (
2041
2075
subgraph.get (),
@@ -2045,7 +2079,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
2045
2079
constant_data,
2046
2080
input_ids,
2047
2081
output_ids,
2048
- compile_allocator);
2082
+ compile_allocator,
2083
+ named_data_map,
2084
+ loaded_buffers_from_map);
2049
2085
2050
2086
if (err != Error::Ok) {
2051
2087
return err;
0 commit comments