Skip to content

Commit f7b8034

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Pass in correct backend id into data load from runtime
Summary: The runtime now passes in the backend id into the SegmentInfo for its load() call. Diff stack: - ( ) [1/n][executorch] Introduce new DataLoader::load() with segment info - ( ) [2/n][executorch] Switch to DataLoader::load in runtime - (**x**) [3/n][executorch] Pass in correct backend id into data load from runtime Differential Revision: D59606243
1 parent e71bcb0 commit f7b8034

File tree

5 files changed

+29
-15
lines changed

5 files changed

+29
-15
lines changed

runtime/executor/method.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,11 @@ class BackendDelegate final {
182182
/*free_fn=*/nullptr);
183183
}
184184
case executorch_flatbuffer::DataLocation::SEGMENT: {
185-
return program->LoadSegment(processed->index());
185+
const char* backend_id = delegate.id()->c_str();
186+
return program->LoadSegment(DataLoader::SegmentInfo(
187+
DataLoader::SegmentInfo::Type::Backend,
188+
processed->index(),
189+
backend_id));
186190
}
187191
default:
188192
ET_LOG(

runtime/executor/program.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,10 @@ Error Program::get_backend_delegate_data(
415415
return HeaderStatus::NotPresent;
416416
}
417417

418-
Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
418+
Result<FreeableBuffer> Program::LoadSegment(
419+
const DataLoader::SegmentInfo& segment_info) const {
419420
EXECUTORCH_SCOPE_PROF("Program::LoadSegment");
421+
size_t index = segment_info.segment_index;
420422
if (loader_ == nullptr || segment_base_offset_ == 0) {
421423
ET_LOG(Error, "No segments in program: requested index %zu", index);
422424
return Error::NotFound;
@@ -432,8 +434,6 @@ Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
432434
// Could fail if offset and size are out of bound for the data, or if this
433435
// is reading from a file and fails, or for many other reasons depending on
434436
// the implementation of the loader.
435-
auto segment_info = DataLoader::SegmentInfo(
436-
DataLoader::SegmentInfo::Type::Backend, index, "backend_segment");
437437
return loader_->load(
438438
segment_base_offset_ + segment->offset(), segment->size(), segment_info);
439439
}

runtime/executor/program.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ class Program final {
239239
/**
240240
* Loads a segment by index.
241241
*
242-
* @param[in] index The sement index to load. This should be an index into
243-
* the Program.segments list.
242+
* @param SegmentInfo Struct containing an index to load from the
243+
* Program.segments list.
244244
*
245245
* @returns The data as a FreeableBuffer, if the index is valid.
246246
* @retval Error::NotFound The program does not contain any segments or the
@@ -249,7 +249,8 @@ class Program final {
249249
* DataLoader: The Program.segment table is inconsistent, or the
250250
* data cannot be accessed.
251251
*/
252-
__ET_NODISCARD Result<FreeableBuffer> LoadSegment(size_t index) const;
252+
__ET_NODISCARD Result<FreeableBuffer> LoadSegment(
253+
const DataLoader::SegmentInfo& segment_info) const;
253254

254255
private:
255256
Program(

runtime/executor/test/backend_integration_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,8 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
542542
// Expect that load was called correctly on a backend segment.
543543
bool backend_load_was_called = spy_loader.UsedLoad(
544544
DataLoader::SegmentInfo::Type::Backend,
545-
"backend_segment"); // TODO(jackzhxng): replace with actual mock PTE
546-
// file's backend_id in next chained PR.
545+
"StubBackend"); // This backend id is taken from the StubBackend defined
546+
// in export_delegated_program.py.
547547

548548
EXPECT_TRUE(program_load_was_called);
549549
if (using_segments()) {

runtime/executor/test/program_test.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <gtest/gtest.h>
2525

2626
using namespace ::testing;
27+
using torch::executor::DataLoader;
2728
using torch::executor::Error;
2829
using torch::executor::FreeableBuffer;
2930
using torch::executor::Program;
@@ -87,8 +88,8 @@ class ProgramTestFriend final {
8788
public:
8889
__ET_NODISCARD static Result<FreeableBuffer> LoadSegment(
8990
const Program* program,
90-
size_t index) {
91-
return program->LoadSegment(index);
91+
const DataLoader::SegmentInfo& segment_info) {
92+
return program->LoadSegment(segment_info);
9293
}
9394

9495
const static executorch_flatbuffer::Program* GetInternalProgram(
@@ -227,14 +228,18 @@ TEST_F(ProgramTest, UnalignedProgramDataFails) {
227228
}
228229

229230
TEST_F(ProgramTest, LoadSegmentWithNoSegments) {
230-
// Load a program with no segments.
231+
// Load a program with no appended segments.
231232
Result<Program> program =
232233
Program::load(add_loader_.get(), kDefaultVerification);
233234
EXPECT_EQ(program.error(), Error::Ok);
234235

235-
// Loading a segment should fail.
236+
// Loading a non-program segment should fail.
237+
const auto segment_info = DataLoader::SegmentInfo(
238+
DataLoader::SegmentInfo::Type::Backend,
239+
/*segment_index=*/0,
240+
"some-backend");
236241
Result<FreeableBuffer> segment =
237-
ProgramTestFriend::LoadSegment(&program.get(), 0);
242+
ProgramTestFriend::LoadSegment(&program.get(), segment_info);
238243
EXPECT_NE(segment.error(), Error::Ok);
239244
}
240245

@@ -351,8 +356,12 @@ TEST_F(ProgramTest, LoadConstantSegment) {
351356
ASSERT_EQ(program.error(), Error::Ok);
352357

353358
// Load constant segment data.
359+
const auto segment_info = DataLoader::SegmentInfo(
360+
DataLoader::SegmentInfo::Type::Constant,
361+
/*segment_index=*/0,
362+
/*descriptor=*/nullptr);
354363
Result<FreeableBuffer> segment =
355-
ProgramTestFriend::LoadSegment(&program.get(), 0);
364+
ProgramTestFriend::LoadSegment(&program.get(), segment_info);
356365
EXPECT_EQ(segment.error(), Error::Ok);
357366

358367
const executorch_flatbuffer::Program* flatbuffer_program =

0 commit comments

Comments
 (0)