Skip to content

Commit c7e407e

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Switch to DataLoader::load in runtime (pytorch#4217)
Summary: Pull Request resolved: pytorch#4217 Description - Makes the appropriate changes in runtime code to deprecate DataLoader::Load in favor of DataLoader::load with SegmentInfo. - Adds tests for the load function Follow-ups planned in diff stack - (D59606243) Currently `program.cpp` is passing in a hardcoded`"backend_segment"` as the descriptor, but in a follow-up will pass in the actual backend id. - Test `Constant` segment case Diff stack: - ( ) [[1/n][executorch] Introduce new DataLoader::load() with segment info](https://www.internalfb.com/diff/D59399538?dst_version_fbid=2979141262226992) - (**x**) [[2/n][executorch] Switch to DataLoader::load in runtime](https://www.internalfb.com/diff/D59594142) - ( ) [[3/n][executorch] Pass in correct backend id into data load from runtime](https://www.internalfb.com/diff/D59606243) Reviewed By: dbort Differential Revision: D59594142 fbshipit-source-id: d579ea2eb9ee4b7f6e38877b156f340292efef30
1 parent bda0f7c commit c7e407e

File tree

2 files changed

+135
-15
lines changed

2 files changed

+135
-15
lines changed

runtime/executor/program.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
7272
size_t segment_base_offset = 0;
7373
{
7474
EXECUTORCH_SCOPE_PROF("Program::check_header");
75-
Result<FreeableBuffer> header =
76-
loader->Load(/*offset=*/0, ExtendedHeader::kNumHeadBytes);
75+
Result<FreeableBuffer> header = loader->load(
76+
/*offset=*/0,
77+
ExtendedHeader::kNumHeadBytes,
78+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
7779
if (!header.ok()) {
7880
return header.error();
7981
}
@@ -95,8 +97,10 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
9597

9698
// Load the flatbuffer data as a segment.
9799
uint32_t prof_tok = EXECUTORCH_BEGIN_PROF("Program::load_data");
98-
Result<FreeableBuffer> program_data =
99-
loader->Load(/*offset=*/0, program_size);
100+
Result<FreeableBuffer> program_data = loader->load(
101+
/*offset=*/0,
102+
program_size,
103+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
100104
if (!program_data.ok()) {
101105
return program_data.error();
102106
}
@@ -173,8 +177,12 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
173177

174178
const executorch_flatbuffer::DataSegment* data_segment =
175179
segments->Get(constant_segment->segment_index());
176-
Result<FreeableBuffer> constant_segment_data = loader->Load(
177-
segment_base_offset + data_segment->offset(), data_segment->size());
180+
Result<FreeableBuffer> constant_segment_data = loader->load(
181+
segment_base_offset + data_segment->offset(),
182+
data_segment->size(),
183+
DataLoader::SegmentInfo(
184+
DataLoader::SegmentInfo::Type::Constant,
185+
constant_segment->segment_index()));
178186
if (!constant_segment_data.ok()) {
179187
return constant_segment_data.error();
180188
}
@@ -423,8 +431,12 @@ Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
423431
// Could fail if offset and size are out of bound for the data, or if this
424432
// is reading from a file and fails, or for many other reasons depending on
425433
// the implementation of the loader.
426-
return loader_->Load(
427-
segment_base_offset_ + segment->offset(), segment->size());
434+
// TODO(jackzhxng): "backend_segment" is a hardcode, pass in real backend id.
435+
return loader_->load(
436+
segment_base_offset_ + segment->offset(),
437+
segment->size(),
438+
DataLoader::SegmentInfo(
439+
DataLoader::SegmentInfo::Type::Backend, index, "backend_segment"));
428440
}
429441

430442
} // namespace executor

runtime/executor/test/backend_integration_test.cpp

Lines changed: 115 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,52 @@ class DataLoaderSpy : public DataLoader {
168168
public:
169169
/// A record of an operation performed on this DataLoader.
170170
struct Operation {
171-
enum { Load, Free } op;
172-
size_t offset; // Set for Load; zero for Free.
173-
void* data; // Set for Free; nullptr for Load.
174-
size_t size; // Set for Load and Free.
171+
enum { Load, Free, DeprecatedLoad } op;
172+
size_t offset; // Set for Load/DeprecatedLoad; zero for Free.
173+
void* data; // Set for Free; nullptr for Load/DeprecatedLoad.
174+
size_t size; // Set for Load/DeprecatedLoad and Free.
175+
std::unique_ptr<const DataLoader::SegmentInfo>
176+
segment_info; // Set for Load; nullptr for Free/DeprecatedLoad.
175177
};
176178

177179
explicit DataLoaderSpy(DataLoader* delegate) : delegate_(delegate) {}
178180

181+
/**
182+
* Override the deprecated "Load" method. We will be looking to test that
183+
* this function is not called if the new "load" method is called.
184+
*/
179185
Result<FreeableBuffer> Load(size_t offset, size_t size) override {
180186
Result<FreeableBuffer> buf = delegate_->Load(offset, size);
181187
if (!buf.ok()) {
182188
return buf.error();
183189
}
184-
operations_.push_back({Operation::Load, offset, /*data=*/nullptr, size});
190+
operations_.push_back(
191+
{Operation::DeprecatedLoad,
192+
offset,
193+
/*data=*/nullptr,
194+
size,
195+
/*segment_info=*/nullptr});
196+
auto* context = new SpyContext(&operations_, std::move(buf.get()));
197+
// Use context->buffer since buf has been moved.
198+
return FreeableBuffer(
199+
context->buffer.data(), context->buffer.size(), FreeBuffer, context);
200+
}
201+
202+
Result<FreeableBuffer>
203+
load(size_t offset, size_t size, const SegmentInfo& segment_info) override {
204+
Result<FreeableBuffer> buf = delegate_->load(offset, size, segment_info);
205+
if (!buf.ok()) {
206+
return buf.error();
207+
}
208+
209+
auto segment_info_cpy =
210+
std::make_unique<const DataLoader::SegmentInfo>(segment_info);
211+
operations_.push_back(
212+
{Operation::Load,
213+
offset,
214+
/*data=*/nullptr,
215+
size,
216+
/*segment_info=*/std::move(segment_info_cpy)});
185217
auto* context = new SpyContext(&operations_, std::move(buf.get()));
186218
// Use context->buffer since buf has been moved.
187219
return FreeableBuffer(
@@ -200,6 +232,36 @@ class DataLoaderSpy : public DataLoader {
200232
return operations_;
201233
}
202234

235+
/**
236+
* Returns true if the DataLoader::load() method was called with the correct
237+
* segment info.
238+
*/
239+
bool UsedLoad(
240+
DataLoader::SegmentInfo::Type segment_type,
241+
const char* descriptor = nullptr) const {
242+
for (const auto& op : operations_) {
243+
// We should not be using the deprecated DataLoader::Load() function.
244+
if (op.op == Operation::DeprecatedLoad) {
245+
return false;
246+
}
247+
if (op.op != Operation::Load) {
248+
continue;
249+
}
250+
// We have a load op.
251+
if (op.segment_info->segment_type == segment_type) {
252+
if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
253+
// For non-backend segments, the descriptor is irrelevant / a nullptr.
254+
return true;
255+
} else {
256+
if (strcmp(op.segment_info->descriptor, descriptor) == 0) {
257+
return true;
258+
}
259+
}
260+
}
261+
}
262+
return false;
263+
}
264+
203265
/**
204266
* Returns true if the operations list shows that the provided data pointer
205267
* was freed.
@@ -223,7 +285,8 @@ class DataLoaderSpy : public DataLoader {
223285

224286
static void FreeBuffer(void* context, void* data, size_t size) {
225287
auto* sc = reinterpret_cast<SpyContext*>(context);
226-
sc->operations->push_back({Operation::Free, /*offset=*/0, data, size});
288+
sc->operations->push_back(
289+
{Operation::Free, /*offset=*/0, data, size, /*segment_info=*/nullptr});
227290
delete sc;
228291
}
229292

@@ -333,7 +396,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
333396
EXPECT_EQ(method_res.error(), Error::Ok);
334397

335398
// Demonstrate that our installed init was called.
336-
EXPECT_EQ(init_called, true);
399+
EXPECT_TRUE(init_called);
337400

338401
// See if the processed data was freed.
339402
bool processed_was_freed = spy_loader.WasFreed(processed_data);
@@ -444,6 +507,51 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
444507
EXPECT_EQ(execute_handle, destroy_handle);
445508
}
446509

510+
/**
511+
* Tests that the DataLoader's load is receiving the correct segment info for
512+
* different types of segments.
513+
*/
514+
TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
515+
const void* processed_data = nullptr;
516+
StubBackend::singleton().install_init(
517+
[&](FreeableBuffer* processed,
518+
__ET_UNUSED ArrayRef<CompileSpec> compile_specs,
519+
__ET_UNUSED MemoryAllocator* runtime_allocator)
520+
-> Result<DelegateHandle*> {
521+
processed_data = processed->data();
522+
processed->Free();
523+
return nullptr;
524+
});
525+
526+
// Wrap the real loader in a spy so we can see which operations were
527+
// performed.
528+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
529+
ASSERT_EQ(loader.error(), Error::Ok);
530+
DataLoaderSpy spy_loader(&loader.get());
531+
532+
// Load the program.
533+
Result<Program> program = Program::load(&spy_loader);
534+
ASSERT_EQ(program.error(), Error::Ok);
535+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
536+
537+
// Expect that load was called correctly on program segments.
538+
bool program_load_was_called =
539+
spy_loader.UsedLoad(DataLoader::SegmentInfo::Type::Program, nullptr);
540+
541+
// Load a method.
542+
Result<Method> method_res = program->load_method("forward", &mmm.get());
543+
EXPECT_EQ(method_res.error(), Error::Ok);
544+
545+
// Expect that load was called correctly on a backend segment.
546+
bool backend_load_was_called = spy_loader.UsedLoad(
547+
DataLoader::SegmentInfo::Type::Backend,
548+
"backend_segment"); // TODO(jackzhxng): replace with actual mock PTE
549+
// file's backend_id in next chained PR.
550+
551+
EXPECT_TRUE(program_load_was_called);
552+
EXPECT_EQ(backend_load_was_called, using_segments());
553+
}
554+
447555
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
448556
// - Errors during init() or execute() result in runtime init/execution failures
449557
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)