Skip to content

Commit 1091fc4

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Switch to DataLoader::load in runtime
Summary: Description - Makes the appropriate changes in runtime code to deprecate DataLoader::Load in favor of DataLoader::load with SegmentInfo. - Adds tests for the load function Follow-ups planned in diff stack - (D59606243) Currently `program.cpp` is passing in a hardcoded`"backend_segment"` as the descriptor, but in a follow-up will pass in the actual backend id. - Test `Constant` segment case Diff stack: - ( ) [1/n][executorch] Introduce new DataLoader::load() with segment info - (**x**) [2/n][executorch] Switch to DataLoader::load in runtime - ( ) [3/n][executorch] Pass in correct backend id into data load from runtime Differential Revision: D59594142
1 parent 091ad22 commit 1091fc4

File tree

2 files changed

+132
-14
lines changed

2 files changed

+132
-14
lines changed

runtime/executor/program.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
7272
size_t segment_base_offset = 0;
7373
{
7474
EXECUTORCH_SCOPE_PROF("Program::check_header");
75-
Result<FreeableBuffer> header =
76-
loader->Load(/*offset=*/0, ExtendedHeader::kNumHeadBytes);
75+
auto segment_info =
76+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program);
77+
Result<FreeableBuffer> header = loader->load(
78+
/*offset=*/0, ExtendedHeader::kNumHeadBytes, segment_info);
7779
if (!header.ok()) {
7880
return header.error();
7981
}
@@ -95,8 +97,10 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
9597

9698
// Load the flatbuffer data as a segment.
9799
uint32_t prof_tok = EXECUTORCH_BEGIN_PROF("Program::load_data");
100+
auto program_segment_info =
101+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program);
98102
Result<FreeableBuffer> program_data =
99-
loader->Load(/*offset=*/0, program_size);
103+
loader->load(/*offset=*/0, program_size, program_segment_info);
100104
if (!program_data.ok()) {
101105
return program_data.error();
102106
}
@@ -173,8 +177,13 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
173177

174178
const executorch_flatbuffer::DataSegment* data_segment =
175179
segments->Get(constant_segment->segment_index());
176-
Result<FreeableBuffer> constant_segment_data = loader->Load(
177-
segment_base_offset + data_segment->offset(), data_segment->size());
180+
auto constant_segment_info = DataLoader::SegmentInfo(
181+
DataLoader::SegmentInfo::Type::Constant,
182+
constant_segment->segment_index());
183+
Result<FreeableBuffer> constant_segment_data = loader->load(
184+
segment_base_offset + data_segment->offset(),
185+
data_segment->size(),
186+
constant_segment_info);
178187
if (!constant_segment_data.ok()) {
179188
return constant_segment_data.error();
180189
}
@@ -423,8 +432,10 @@ Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
423432
// Could fail if offset and size are out of bound for the data, or if this
424433
// is reading from a file and fails, or for many other reasons depending on
425434
// the implementation of the loader.
426-
return loader_->Load(
427-
segment_base_offset_ + segment->offset(), segment->size());
435+
auto segment_info = DataLoader::SegmentInfo(
436+
DataLoader::SegmentInfo::Type::Backend, index, "backend_segment");
437+
return loader_->load(
438+
segment_base_offset_ + segment->offset(), segment->size(), segment_info);
428439
}
429440

430441
} // namespace executor

runtime/executor/test/backend_integration_test.cpp

Lines changed: 114 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,52 @@ class DataLoaderSpy : public DataLoader {
168168
public:
169169
/// A record of an operation performed on this DataLoader.
170170
struct Operation {
171-
enum { Load, Free } op;
172-
size_t offset; // Set for Load; zero for Free.
173-
void* data; // Set for Free; nullptr for Load.
174-
size_t size; // Set for Load and Free.
171+
enum { Load, Free, DeprecatedLoad } op;
172+
size_t offset; // Set for Load/DeprecatedLoad; zero for Free.
173+
void* data; // Set for Free; nullptr for Load/DeprecatedLoad.
174+
size_t size; // Set for Load/DeprecatedLoad and Free.
175+
std::unique_ptr<const DataLoader::SegmentInfo>
176+
segment_info; // Set for Load; nullptr for Free/DeprecatedLoad.
175177
};
176178

177179
explicit DataLoaderSpy(DataLoader* delegate) : delegate_(delegate) {}
178180

181+
/**
182+
* Override the deprecated "Load" method. We will be looking to test that
183+
* this function is not called if the new "load" method is called.
184+
*/
179185
Result<FreeableBuffer> Load(size_t offset, size_t size) override {
180186
Result<FreeableBuffer> buf = delegate_->Load(offset, size);
181187
if (!buf.ok()) {
182188
return buf.error();
183189
}
184-
operations_.push_back({Operation::Load, offset, /*data=*/nullptr, size});
190+
operations_.push_back(
191+
{Operation::DeprecatedLoad,
192+
offset,
193+
/*data=*/nullptr,
194+
size,
195+
/*segment_info=*/nullptr});
196+
auto* context = new SpyContext(&operations_, std::move(buf.get()));
197+
// Use context->buffer since buf has been moved.
198+
return FreeableBuffer(
199+
context->buffer.data(), context->buffer.size(), FreeBuffer, context);
200+
}
201+
202+
Result<FreeableBuffer>
203+
load(size_t offset, size_t size, const SegmentInfo& segment_info) override {
204+
Result<FreeableBuffer> buf = delegate_->load(offset, size, segment_info);
205+
if (!buf.ok()) {
206+
return buf.error();
207+
}
208+
209+
auto segment_info_cpy =
210+
std::make_unique<const DataLoader::SegmentInfo>(segment_info);
211+
operations_.push_back(
212+
{Operation::Load,
213+
offset,
214+
/*data=*/nullptr,
215+
size,
216+
/*segment_info=*/std::move(segment_info_cpy)});
185217
auto* context = new SpyContext(&operations_, std::move(buf.get()));
186218
// Use context->buffer since buf has been moved.
187219
return FreeableBuffer(
@@ -200,6 +232,33 @@ class DataLoaderSpy : public DataLoader {
200232
return operations_;
201233
}
202234

235+
/**
236+
* Returns true if the DataLoader::load() method was called with the correct
237+
* segment info.
238+
*/
239+
bool UsedLoad(
240+
DataLoader::SegmentInfo::Type segment_type,
241+
const char* descriptor) const {
242+
for (const auto& op : operations_) {
243+
// We should not be using the deprecated DataLoader::Load() function.
244+
if (op.op == Operation::DeprecatedLoad)
245+
return false;
246+
if (op.op != Operation::Load)
247+
continue;
248+
// We have a load op.
249+
if (op.segment_info->segment_type == segment_type) {
250+
if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
251+
// For non-backend segments, the descriptor is irrelevant / a nullptr.
252+
return true;
253+
} else {
254+
if (strcmp(op.segment_info->descriptor, descriptor) == 0)
255+
return true;
256+
}
257+
}
258+
}
259+
return false;
260+
}
261+
203262
/**
204263
* Returns true if the operations list shows that the provided data pointer
205264
* was freed.
@@ -223,7 +282,8 @@ class DataLoaderSpy : public DataLoader {
223282

224283
static void FreeBuffer(void* context, void* data, size_t size) {
225284
auto* sc = reinterpret_cast<SpyContext*>(context);
226-
sc->operations->push_back({Operation::Free, /*offset=*/0, data, size});
285+
sc->operations->push_back(
286+
{Operation::Free, /*offset=*/0, data, size, /*segment_info=*/nullptr});
227287
delete sc;
228288
}
229289

@@ -333,7 +393,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
333393
EXPECT_EQ(method_res.error(), Error::Ok);
334394

335395
// Demonstrate that our installed init was called.
336-
EXPECT_EQ(init_called, true);
396+
EXPECT_TRUE(init_called);
337397

338398
// See if the processed data was freed.
339399
bool processed_was_freed = spy_loader.WasFreed(processed_data);
@@ -444,6 +504,53 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
444504
EXPECT_EQ(execute_handle, destroy_handle);
445505
}
446506

507+
/**
508+
* Tests that the DataLoader's load is receiving the correct segment info for
509+
* different types of segments.
510+
*/
511+
TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
512+
const void* processed_data = nullptr;
513+
StubBackend::singleton().install_init(
514+
[&](FreeableBuffer* processed,
515+
__ET_UNUSED ArrayRef<CompileSpec> compile_specs,
516+
__ET_UNUSED MemoryAllocator* runtime_allocator)
517+
-> Result<DelegateHandle*> {
518+
processed_data = processed->data();
519+
processed->Free();
520+
return nullptr;
521+
});
522+
523+
// Wrap the real loader in a spy so we can see which operations were
524+
// performed.
525+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
526+
ASSERT_EQ(loader.error(), Error::Ok);
527+
DataLoaderSpy spy_loader(&loader.get());
528+
529+
// Load the program.
530+
Result<Program> program = Program::load(&spy_loader);
531+
ASSERT_EQ(program.error(), Error::Ok);
532+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
533+
534+
// Expect that load was called correctly on program segments.
535+
bool program_load_was_called =
536+
spy_loader.UsedLoad(DataLoader::SegmentInfo::Type::Program, nullptr);
537+
538+
// Load a method.
539+
Result<Method> method_res = program->load_method("forward", &mmm.get());
540+
EXPECT_EQ(method_res.error(), Error::Ok);
541+
542+
// Expect that load was called correctly on a backend segment.
543+
bool backend_load_was_called = spy_loader.UsedLoad(
544+
DataLoader::SegmentInfo::Type::Backend,
545+
"backend_segment"); // TODO(jackzhxng): replace with actual mock PTE
546+
// file's backend_id in next chained PR.
547+
548+
EXPECT_TRUE(program_load_was_called);
549+
if (using_segments()) {
550+
EXPECT_TRUE(backend_load_was_called);
551+
}
552+
}
553+
447554
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
448555
// - Errors during init() or execute() result in runtime init/execution failures
449556
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)