Skip to content

Commit 93a7725

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Pass in correct backend id into data load from runtime (pytorch#4218)
Summary: Pull Request resolved: pytorch#4218 The runtime now passes in the backend id into the SegmentInfo for its load() call. Diff stack: - ( ) [[1/n][executorch] Introduce new DataLoader::load() with segment info](https://www.internalfb.com/diff/D59399538?dst_version_fbid=2979141262226992) - ( ) [[2/n][executorch] Switch to DataLoader::load in runtime](https://www.internalfb.com/diff/D59594142) - (**x**) [[3/n][executorch] Pass in correct backend id into data load from runtime](https://www.internalfb.com/diff/D59606243) Reviewed By: dbort Differential Revision: D59606243 fbshipit-source-id: b4af5e1bb2e53f576ab75a65d2ef950cc051e8c9
1 parent c7e407e commit 93a7725

File tree

5 files changed

+32
-19
lines changed

5 files changed

+32
-19
lines changed

runtime/executor/method.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,11 @@ class BackendDelegate final {
182182
/*free_fn=*/nullptr);
183183
}
184184
case executorch_flatbuffer::DataLocation::SEGMENT: {
185-
return program->LoadSegment(processed->index());
185+
const char* backend_id = delegate.id()->c_str();
186+
return program->LoadSegment(DataLoader::SegmentInfo(
187+
DataLoader::SegmentInfo::Type::Backend,
188+
processed->index(),
189+
backend_id));
186190
}
187191
default:
188192
ET_LOG(

runtime/executor/program.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,10 @@ Error Program::get_backend_delegate_data(
414414
return HeaderStatus::NotPresent;
415415
}
416416

417-
Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
417+
Result<FreeableBuffer> Program::LoadSegment(
418+
const DataLoader::SegmentInfo& segment_info) const {
418419
EXECUTORCH_SCOPE_PROF("Program::LoadSegment");
420+
size_t index = segment_info.segment_index;
419421
if (loader_ == nullptr || segment_base_offset_ == 0) {
420422
ET_LOG(Error, "No segments in program: requested index %zu", index);
421423
return Error::NotFound;
@@ -431,12 +433,8 @@ Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
431433
// Could fail if offset and size are out of bound for the data, or if this
432434
// is reading from a file and fails, or for many other reasons depending on
433435
// the implementation of the loader.
434-
// TODO(jackzhxng): "backend_segment" is a hardcode, pass in real backend id.
435436
return loader_->load(
436-
segment_base_offset_ + segment->offset(),
437-
segment->size(),
438-
DataLoader::SegmentInfo(
439-
DataLoader::SegmentInfo::Type::Backend, index, "backend_segment"));
437+
segment_base_offset_ + segment->offset(), segment->size(), segment_info);
440438
}
441439

442440
} // namespace executor

runtime/executor/program.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ class Program final {
239239
/**
240240
* Loads a segment by index.
241241
*
242-
* @param[in] index The sement index to load. This should be an index into
243-
* the Program.segments list.
242+
* @param[in] SegmentInfo Struct containing an index to load from the
243+
* Program.segments list. The other fields of the struct, such as
244+
* `segment_type` and `descriptor`, need to also be correct.
244245
*
245246
* @returns The data as a FreeableBuffer, if the index is valid.
246247
* @retval Error::NotFound The program does not contain any segments or the
@@ -249,7 +250,8 @@ class Program final {
249250
* DataLoader: The Program.segment table is inconsistent, or the
250251
* data cannot be accessed.
251252
*/
252-
__ET_NODISCARD Result<FreeableBuffer> LoadSegment(size_t index) const;
253+
__ET_NODISCARD Result<FreeableBuffer> LoadSegment(
254+
const DataLoader::SegmentInfo& segment_info) const;
253255

254256
private:
255257
Program(

runtime/executor/test/backend_integration_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,8 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
545545
// Expect that load was called correctly on a backend segment.
546546
bool backend_load_was_called = spy_loader.UsedLoad(
547547
DataLoader::SegmentInfo::Type::Backend,
548-
"backend_segment"); // TODO(jackzhxng): replace with actual mock PTE
549-
// file's backend_id in next chained PR.
548+
"StubBackend"); // This backend id is taken from the StubBackend defined
549+
// in export_delegated_program.py.
550550

551551
EXPECT_TRUE(program_load_was_called);
552552
EXPECT_EQ(backend_load_was_called, using_segments());

runtime/executor/test/program_test.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <gtest/gtest.h>
2525

2626
using namespace ::testing;
27+
using torch::executor::DataLoader;
2728
using torch::executor::Error;
2829
using torch::executor::FreeableBuffer;
2930
using torch::executor::Program;
@@ -87,8 +88,8 @@ class ProgramTestFriend final {
8788
public:
8889
__ET_NODISCARD static Result<FreeableBuffer> LoadSegment(
8990
const Program* program,
90-
size_t index) {
91-
return program->LoadSegment(index);
91+
const DataLoader::SegmentInfo& segment_info) {
92+
return program->LoadSegment(segment_info);
9293
}
9394

9495
const static executorch_flatbuffer::Program* GetInternalProgram(
@@ -227,14 +228,18 @@ TEST_F(ProgramTest, UnalignedProgramDataFails) {
227228
}
228229

229230
TEST_F(ProgramTest, LoadSegmentWithNoSegments) {
230-
// Load a program with no segments.
231+
// Load a program with no appended segments.
231232
Result<Program> program =
232233
Program::load(add_loader_.get(), kDefaultVerification);
233234
EXPECT_EQ(program.error(), Error::Ok);
234235

235-
// Loading a segment should fail.
236+
// Loading a non-program segment should fail.
237+
const auto segment_info = DataLoader::SegmentInfo(
238+
DataLoader::SegmentInfo::Type::Backend,
239+
/*segment_index=*/0,
240+
"some-backend");
236241
Result<FreeableBuffer> segment =
237-
ProgramTestFriend::LoadSegment(&program.get(), 0);
242+
ProgramTestFriend::LoadSegment(&program.get(), segment_info);
238243
EXPECT_NE(segment.error(), Error::Ok);
239244
}
240245

@@ -350,9 +355,13 @@ TEST_F(ProgramTest, LoadConstantSegment) {
350355
Result<Program> program = Program::load(&linear_loader.get());
351356
ASSERT_EQ(program.error(), Error::Ok);
352357

353-
// Load constant segment data.
358+
// Load constant segment data, which is currently always in segment index
359+
// zero.
360+
const auto segment_info = DataLoader::SegmentInfo(
361+
DataLoader::SegmentInfo::Type::Constant,
362+
/*segment_index=*/0);
354363
Result<FreeableBuffer> segment =
355-
ProgramTestFriend::LoadSegment(&program.get(), 0);
364+
ProgramTestFriend::LoadSegment(&program.get(), segment_info);
356365
EXPECT_EQ(segment.error(), Error::Ok);
357366

358367
const executorch_flatbuffer::Program* flatbuffer_program =

0 commit comments

Comments
 (0)