Skip to content

Switch to DataLoader::load in runtime #4217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions runtime/core/data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,57 @@ namespace executor {
*/
class DataLoader {
public:
/**
 * Describes the content of the segment that a load request refers to.
 *
 * Every member has a default member initializer, so a default-constructed
 * `SegmentInfo` has well-defined values instead of indeterminate ones.
 */
struct SegmentInfo {
  /**
   * Represents the purpose of the segment.
   */
  enum class Type {
    /**
     * Data for the actual program.
     */
    Program,
    /**
     * Holds constant tensor data.
     */
    Constant,
    /**
     * Data used for initializing a backend.
     */
    Backend,
  };

  /// Type of the segment. Defaults to `Type::Program`.
  Type segment_type = Type::Program;

  /// Index of the segment within the segment list. Not meaningful for
  /// program segments. Defaults to zero.
  size_t segment_index = 0;

  /// An optional, null-terminated string describing the segment. For
  /// `Backend` segments, this is the backend ID. Null for other segment
  /// types. Only the pointer is stored; the string itself is not copied.
  const char* descriptor = nullptr;

  /// Creates a `Program` segment info with index zero and no descriptor.
  SegmentInfo() = default;

  /**
   * Creates a segment info from its parts.
   *
   * @param segment_type The purpose of the segment.
   * @param segment_index Index of the segment within the segment list.
   * @param descriptor Optional null-terminated description of the segment;
   *     may be null.
   */
  explicit SegmentInfo(
      Type segment_type,
      size_t segment_index = 0,
      const char* descriptor = nullptr)
      : segment_type(segment_type),
        segment_index(segment_index),
        descriptor(descriptor) {}
};

virtual ~DataLoader() = default;

/**
* DEPRECATED: Use `load()` going forward for access to segment info during
* the load.
*
* Loads `size` bytes at byte offset `offset` from the underlying data source
* into a `FreeableBuffer`, which owns the memory.
*
Expand All @@ -37,6 +85,24 @@ class DataLoader {
size_t offset,
size_t size) = 0;

/**
 * Loads `size` bytes starting at byte offset `offset` from the underlying
 * data source.
 *
 * This default implementation ignores `segment_info` and forwards the
 * request to the deprecated `Load()`. Subclasses can override it to make use
 * of the segment information.
 *
 * NOTE: This must be thread-safe. If this call modifies common state, the
 * implementation must do its own locking.
 *
 * @param offset The byte offset in the data source to start loading from.
 * @param size The number of bytes to load.
 * @param segment_info Information about the segment being loaded.
 *
 * @returns a `FreeableBuffer` that owns the loaded data.
 */
__ET_NODISCARD virtual Result<FreeableBuffer>
load(size_t offset, size_t size, const SegmentInfo& segment_info) {
  // The fallback path has no use for the segment description; reference it
  // once so that it does not trigger an unused-parameter warning.
  static_cast<void>(segment_info);
  return Load(offset, size); // NOLINT(facebook-hte-Deprecated)
}

/**
* Returns the length of the underlying data source, typically the file size.
*/
Expand Down
28 changes: 20 additions & 8 deletions runtime/executor/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
size_t segment_base_offset = 0;
{
EXECUTORCH_SCOPE_PROF("Program::check_header");
Result<FreeableBuffer> header =
loader->Load(/*offset=*/0, ExtendedHeader::kNumHeadBytes);
Result<FreeableBuffer> header = loader->load(
/*offset=*/0,
ExtendedHeader::kNumHeadBytes,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
if (!header.ok()) {
return header.error();
}
Expand All @@ -95,8 +97,10 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(

// Load the flatbuffer data as a segment.
uint32_t prof_tok = EXECUTORCH_BEGIN_PROF("Program::load_data");
Result<FreeableBuffer> program_data =
loader->Load(/*offset=*/0, program_size);
Result<FreeableBuffer> program_data = loader->load(
/*offset=*/0,
program_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
if (!program_data.ok()) {
return program_data.error();
}
Expand Down Expand Up @@ -173,8 +177,12 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(

const executorch_flatbuffer::DataSegment* data_segment =
segments->Get(constant_segment->segment_index());
Result<FreeableBuffer> constant_segment_data = loader->Load(
segment_base_offset + data_segment->offset(), data_segment->size());
Result<FreeableBuffer> constant_segment_data = loader->load(
segment_base_offset + data_segment->offset(),
data_segment->size(),
DataLoader::SegmentInfo(
DataLoader::SegmentInfo::Type::Constant,
constant_segment->segment_index()));
if (!constant_segment_data.ok()) {
return constant_segment_data.error();
}
Expand Down Expand Up @@ -423,8 +431,12 @@ Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
// Could fail if offset and size are out of bound for the data, or if this
// is reading from a file and fails, or for many other reasons depending on
// the implementation of the loader.
return loader_->Load(
segment_base_offset_ + segment->offset(), segment->size());
// TODO(jackzhxng): "backend_segment" is a hardcode, pass in real backend id.
return loader_->load(
segment_base_offset_ + segment->offset(),
segment->size(),
DataLoader::SegmentInfo(
DataLoader::SegmentInfo::Type::Backend, index, "backend_segment"));
}

} // namespace executor
Expand Down
122 changes: 115 additions & 7 deletions runtime/executor/test/backend_integration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,52 @@ class DataLoaderSpy : public DataLoader {
public:
/// A record of an operation performed on this DataLoader.
struct Operation {
  /// Which call produced this record: `Load` for the segment-aware `load()`,
  /// `DeprecatedLoad` for the deprecated `Load()`, and `Free` for a buffer
  /// release.
  enum { Load, Free, DeprecatedLoad } op;
  size_t offset; // Set for Load/DeprecatedLoad; zero for Free.
  void* data; // Set for Free; nullptr for Load/DeprecatedLoad.
  size_t size; // Set for Load/DeprecatedLoad and Free.
  std::unique_ptr<const DataLoader::SegmentInfo>
      segment_info; // Set for Load; nullptr for Free/DeprecatedLoad.
};

explicit DataLoaderSpy(DataLoader* delegate) : delegate_(delegate) {}

/**
 * Override the deprecated "Load" method. We will be looking to test that
 * this function is not called if the new "load" method is called.
 *
 * On success, records exactly one `DeprecatedLoad` operation and returns a
 * buffer whose free function is `FreeBuffer`, so the eventual release is
 * recorded as well.
 *
 * @param offset The byte offset to load from; forwarded to the delegate.
 * @param size The number of bytes to load; forwarded to the delegate.
 *
 * @returns the delegate's error if its load failed, otherwise a
 *     `FreeableBuffer` exposing the loaded data.
 */
Result<FreeableBuffer> Load(size_t offset, size_t size) override {
  Result<FreeableBuffer> buf = delegate_->Load(offset, size);
  if (!buf.ok()) {
    return buf.error();
  }
  operations_.push_back(
      {Operation::DeprecatedLoad,
       offset,
       /*data=*/nullptr,
       size,
       /*segment_info=*/nullptr});
  // The context takes over the delegate's buffer and is handed to
  // FreeBuffer, which deletes it.
  auto* context = new SpyContext(&operations_, std::move(buf.get()));
  // Use context->buffer since buf has been moved.
  return FreeableBuffer(
      context->buffer.data(), context->buffer.size(), FreeBuffer, context);
}

Result<FreeableBuffer>
load(size_t offset, size_t size, const SegmentInfo& segment_info) override {
Result<FreeableBuffer> buf = delegate_->load(offset, size, segment_info);
if (!buf.ok()) {
return buf.error();
}

auto segment_info_cpy =
std::make_unique<const DataLoader::SegmentInfo>(segment_info);
operations_.push_back(
{Operation::Load,
offset,
/*data=*/nullptr,
size,
/*segment_info=*/std::move(segment_info_cpy)});
auto* context = new SpyContext(&operations_, std::move(buf.get()));
// Use context->buffer since buf has been moved.
return FreeableBuffer(
Expand All @@ -200,6 +232,36 @@ class DataLoaderSpy : public DataLoader {
return operations_;
}

/**
 * Returns true if the DataLoader::load() method was called with the correct
 * segment info.
 *
 * Operations are scanned in recorded order. The scan returns true at the
 * first matching `Load` record and returns false as soon as a
 * `DeprecatedLoad` record is seen before any match.
 *
 * @param segment_type The segment type that a recorded load must carry.
 * @param descriptor The descriptor expected for `Backend` segments; ignored
 *     for other segment types. A null value only matches a recorded load
 *     whose descriptor is also null.
 *
 * @returns true if a matching load was recorded, false otherwise.
 */
bool UsedLoad(
    DataLoader::SegmentInfo::Type segment_type,
    const char* descriptor = nullptr) const {
  for (const auto& op : operations_) {
    // We should not be using the deprecated DataLoader::Load() function.
    if (op.op == Operation::DeprecatedLoad) {
      return false;
    }
    if (op.op != Operation::Load) {
      continue;
    }
    // We have a load op. Skip records without segment info instead of
    // dereferencing a null pointer.
    const DataLoader::SegmentInfo* info = op.segment_info.get();
    if (info == nullptr || info->segment_type != segment_type) {
      continue;
    }
    if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
      // For non-backend segments, the descriptor is irrelevant / a nullptr.
      return true;
    }
    const char* recorded = info->descriptor;
    if (recorded == nullptr || descriptor == nullptr) {
      // strcmp() must not be given a null pointer; two descriptors where at
      // least one is null only match when both are null.
      if (recorded == descriptor) {
        return true;
      }
      continue;
    }
    if (strcmp(recorded, descriptor) == 0) {
      return true;
    }
  }
  return false;
}

/**
* Returns true if the operations list shows that the provided data pointer
* was freed.
Expand All @@ -223,7 +285,8 @@ class DataLoaderSpy : public DataLoader {

/**
 * Free function installed on the buffers handed out by this spy. Records
 * exactly one `Free` operation and then destroys the context.
 *
 * @param context The `SpyContext` created when the buffer was loaded.
 * @param data Pointer to the data being freed; stored in the record.
 * @param size Size of the data being freed; stored in the record.
 */
static void FreeBuffer(void* context, void* data, size_t size) {
  auto* sc = static_cast<SpyContext*>(context);
  sc->operations->push_back(
      {Operation::Free, /*offset=*/0, data, size, /*segment_info=*/nullptr});
  delete sc;
}

Expand Down Expand Up @@ -333,7 +396,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
EXPECT_EQ(method_res.error(), Error::Ok);

// Demonstrate that our installed init was called.
EXPECT_EQ(init_called, true);
EXPECT_TRUE(init_called);

// See if the processed data was freed.
bool processed_was_freed = spy_loader.WasFreed(processed_data);
Expand Down Expand Up @@ -444,6 +507,51 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
EXPECT_EQ(execute_handle, destroy_handle);
}

/**
 * Tests that the DataLoader's load is receiving the correct segment info for
 * different types of segments.
 */
TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
  // Install an init that just releases the processed buffer. Nothing from the
  // enclosing scope is needed, so the lambda captures nothing.
  StubBackend::singleton().install_init(
      [](FreeableBuffer* processed,
         __ET_UNUSED ArrayRef<CompileSpec> compile_specs,
         __ET_UNUSED MemoryAllocator* runtime_allocator)
          -> Result<DelegateHandle*> {
        processed->Free();
        return nullptr;
      });

  // Wrap the real loader in a spy so we can see which operations were
  // performed.
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);
  DataLoaderSpy spy_loader(&loader.get());

  // Load the program.
  Result<Program> program = Program::load(&spy_loader);
  ASSERT_EQ(program.error(), Error::Ok);
  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

  // Expect that load was called correctly on program segments. This is
  // sampled before the method is loaded, so it only reflects the operations
  // performed by Program::load().
  bool program_load_was_called =
      spy_loader.UsedLoad(DataLoader::SegmentInfo::Type::Program, nullptr);

  // Load a method.
  Result<Method> method_res = program->load_method("forward", &mmm.get());
  EXPECT_EQ(method_res.error(), Error::Ok);

  // Expect that load was called correctly on a backend segment.
  bool backend_load_was_called = spy_loader.UsedLoad(
      DataLoader::SegmentInfo::Type::Backend,
      "backend_segment"); // TODO(jackzhxng): replace with actual mock PTE
                          // file's backend_id in next chained PR.

  EXPECT_TRUE(program_load_was_called);
  // A backend segment load is expected exactly when using_segments() is true.
  EXPECT_EQ(backend_load_was_called, using_segments());
}

// TODO: Add more tests for the runtime-to-backend interface. E.g.:
// - Errors during init() or execute() result in runtime init/execution failures
// - Correct values are passed to init()/execute()
Expand Down
Loading