Skip to content

Commit bb5d21c

Browse files
pytorchbotlucylq
andauthored
[executorch][runtime] Add get_named_data_map to Program (#8961)
* [executorch][runtime] Introduce PteDataMap for weight sharing Pull Request resolved: #8887 PteDataMap is the NamedDataMap that will live in the runtime. It is used to give delegates access to opaque named data stored in the PTE file. Open to alternative naming suggestions, maybe 'PTEDataMap' or 'ProgramDataMap'? **Usage** The PteDataMap is owned by the program, and instantiated at program load time if named_data exists in the PTE file. We introduce usage of 'std::optional' here. I think we can also use executorch::aten::optional to avoid adding standard lib ? When initializing delegates, the PteDataMap is given to delegate_init. Delegates can retrieve opaque delegate data by key using 'get_data'. This gives them a FreeableBuffer that they can free later. **Testing** This test uses the C++ flatbuffer API to build a fake program containing named data. We also creates a temp file with sample data that the data loader can wrap around. TODO: e2e test once delegate aot is ready and we can generate a file with named data. **Note** As the PteDataMap wraps around flatbuffer constructs, the Program must outlive the PteDataMap. PteDataMap does not implement - get_metadata; currently, all data stored is opaque. Later, we can implement get_metadata if a backend stores plain tensor data. - load_into; this is mostly used for the training case, and isn't used by delegates, at least not at the moment Differential Revision: [D70213646](https://our.internmc.facebook.com/intern/diff/D70213646/) ghstack-source-id: 269691307 * [executorch][runtime] Add get_named_data_map to Program Pull Request resolved: #8853 Add to the program interface, to allow users to retrieve the NDM. Differential Revision: [D70276106](https://our.internmc.facebook.com/intern/diff/D70276106/) ghstack-source-id: 269693108 --------- Co-authored-by: lucylq <[email protected]>
1 parent 6346348 commit bb5d21c

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

runtime/executor/program.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,13 @@ Result<const void*> Program::get_constant_buffer_data(
373373
}
374374
}
375375

376+
Result<const NamedDataMap*> Program::get_named_data_map() const {
377+
if (pte_data_map_.has_value()) {
378+
return &pte_data_map_.value();
379+
}
380+
return Error::NotFound;
381+
}
382+
376383
Result<const char*> Program::get_output_flattening_encoding(
377384
const char* method_name) const {
378385
auto plan = get_execution_plan(internal_program_, method_name);

runtime/executor/program.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ class Program final {
110110
Result<const void*> get_constant_buffer_data(size_t buffer_idx, size_t nbytes)
111111
const;
112112

113+
/**
114+
* Get the named data map from the program.
115+
* @return The named data map.
116+
*/
117+
Result<const NamedDataMap*> get_named_data_map() const;
118+
113119
/**
114120
* Returns the number of methods in the program.
115121
*/

runtime/executor/test/program_test.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,18 @@ TEST_F(ProgramTest, getMethods) {
371371
EXPECT_EQ(strcmp(res2.get(), "forward2"), 0);
372372
}
373373

374+
TEST_F(ProgramTest, GetNamedDataMap_Fail) {
375+
Result<Program> program =
376+
Program::load(add_loader_.get(), kDefaultVerification);
377+
ASSERT_EQ(program.error(), Error::Ok);
378+
379+
// Get the named data map. Expect to fail, as add.pte does not have any
380+
// named data segments.
381+
Result<const executorch::runtime::NamedDataMap*> named_data_map =
382+
program->get_named_data_map();
383+
EXPECT_EQ(named_data_map.error(), Error::NotFound);
384+
}
385+
374386
// Test that the deprecated Load method (capital 'L') still works.
375387
TEST_F(ProgramTest, DEPRECATEDLoad) {
376388
// Parse the Program from the data.

0 commit comments

Comments
 (0)