Skip to content

Expand Program Interface #4680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions runtime/executor/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,5 +410,90 @@ Result<FreeableBuffer> Program::LoadSegment(
segment_base_offset_ + segment->offset(), segment->size(), segment_info);
}

Error Program::load_mutable_subsegment_into(
size_t mutable_data_segments_index,
size_t offset_index,
size_t size,
void* buffer) const {
EXECUTORCH_SCOPE_PROF("Program::load_subsegment_into");
// Check that the program has segments.
if (loader_ == nullptr || segment_base_offset_ == 0) {
ET_LOG(Error, "No segments in program");
return Error::NotFound;
}

// Check that the program has mutable data segments.
if (internal_program_->mutable_data_segments() == nullptr) {
ET_LOG(Error, "No mutable data segments in program");
return Error::NotFound;
}
if (mutable_data_segments_index >=
internal_program_->mutable_data_segments()->size()) {
ET_LOG(
Error,
"mutable_data_segments_index %zu out of range >= %" PRIu64,
mutable_data_segments_index,
(uint64_t)internal_program_->mutable_data_segments()->size());
return Error::NotFound;
}

// Grab the mutable data segment info.
const auto& segment_offsets = internal_program_->mutable_data_segments()->Get(
mutable_data_segments_index);

// Check that the offset is valid.
if (segment_offsets->offsets() == nullptr) {
ET_LOG(Error, "No offsets in mutable data segment");
return Error::NotFound;
}
if (offset_index >= segment_offsets->offsets()->size()) {
ET_LOG(
Error,
"offset index %zu out of range >= %" PRIu64,
offset_index,
(uint64_t)segment_offsets->offsets()->size());
return Error::NotFound;
}

// Grab the offset. Note: This offset is relative to the start of the segment,
// so we will need to adjust when calling the loader.
size_t offset = segment_offsets->offsets()->Get(offset_index);

// Grab the segment index
size_t num_segments = internal_program_->segments()->size();
if (segment_offsets->segment_index() >= num_segments) {
ET_LOG(
Error,
"Segment index %u out of range (>= %zu)",
segment_offsets->segment_index(),
num_segments);
return Error::NotFound;
}

// Grab the segment
auto segment =
internal_program_->segments()->Get(segment_offsets->segment_index());

// Check size
if (offset + size > segment->size()) {
ET_LOG(
Error,
"offset %zu + size %zu out of range > %" PRIu64,
offset,
size,
segment->size());
return Error::InvalidArgument;
}

DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
DataLoader::SegmentInfo::Type::Mutable,
segment_offsets->segment_index(),
nullptr);

// Load the data
return loader_->load_into(
segment_base_offset_ + segment->offset() + offset, size, info, buffer);
}

} // namespace executor
} // namespace torch
24 changes: 24 additions & 0 deletions runtime/executor/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,30 @@ class Program final {
__ET_NODISCARD Result<FreeableBuffer> LoadSegment(
const DataLoader::SegmentInfo& segment_info) const;

/**
* Loads a portion of a mutable segment into the provided buffer.
*
* @param[in] mutable_data_segments_index The index into the
* mutable_data_segments_array.
* @param[in] offset_index The index into the segment's offsets array.
* @param[in] size The number of bytes to load.
* @param[in] buffer The buffer to load data into. Must point to at least
* `size` bytes of memory.
*
* @returns An error code on if the load was successful.
* @retval Error::Ok The load was successful.
* @retval Error::NotFound The program does not contain any segments or the
* indices are out of range.
* @returns Other errors depending on the implementation of
* DataLoader: The Program.segment table is inconsistent, or the
* data cannot be accessed.
*/
__ET_NODISCARD Error load_mutable_subsegment_into(
size_t mutable_data_segments_index,
size_t offset_index,
size_t size,
void* buffer) const;

private:
Program(
DataLoader* loader,
Expand Down
98 changes: 97 additions & 1 deletion runtime/executor/test/program_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ProgramTest : public ::testing::Test {

add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));

// Load the serialized ModuleAdd data.
// Load the serialized ModuleMultiEntry data.
path = std::getenv("ET_MODULE_MULTI_ENTRY_PATH");
Result<FileDataLoader> multi_loader = FileDataLoader::from(path);
ASSERT_EQ(multi_loader.error(), Error::Ok);
Expand Down Expand Up @@ -98,6 +98,16 @@ class ProgramTestFriend final {
return program->LoadSegment(segment_info);
}

__ET_NODISCARD static Error load_mutable_subsegment_into(
const Program* program,
size_t mutable_data_segments_index,
size_t offset_index,
size_t size,
void* buffer) {
return program->load_mutable_subsegment_into(
mutable_data_segments_index, offset_index, size, buffer);
}

const static executorch_flatbuffer::Program* GetInternalProgram(
const Program* program) {
return program->internal_program_;
Expand Down Expand Up @@ -444,3 +454,89 @@ TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
// The constant buffer should exist.
EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1);
}

TEST_F(ProgramTest, LoadFromMutableSegment) {
// Load the serialized ModuleSimpleTrain data.
auto path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
Result<FileDataLoader> training_loader = FileDataLoader::from(path);
ASSERT_EQ(training_loader.error(), Error::Ok);

// This file should always be compatible.
Result<FreeableBuffer> training_header = training_loader->load(
/*offset=*/0,
Program::kMinHeadBytes,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
ASSERT_EQ(training_header.error(), Error::Ok);
EXPECT_EQ(
Program::check_header(training_header->data(), training_header->size()),
Program::HeaderStatus::CompatibleVersion);

Result<Program> program = Program::load(&training_loader.get());
ASSERT_EQ(program.error(), Error::Ok);

// dummy buffers to load into
uint8_t buffer[1] = {0};
uint8_t buffer2[1] = {0};

// Load some mutable segment data
Error err = ProgramTestFriend::load_mutable_subsegment_into(
&program.get(), 0, 1, 1, buffer);
EXPECT_EQ(err, Error::Ok);

// Check that the data loaded correctly, and then mutate it
EXPECT_EQ(buffer[0], 232); // 232 comes from inspecting the file itself. The
// file is seeded so this value should be stable.
buffer[0] = 0;

// Load the same mutable segment data from file into a different buffer.
err = ProgramTestFriend::load_mutable_subsegment_into(
&program.get(),
0, // mutable_data_segments_index
1, // offset_index
1, // size
buffer2);
EXPECT_EQ(err, Error::Ok);

// Check that new data loaded from the file does not reflect the change to
// buffer.
EXPECT_EQ(buffer2[0], 232);

const executorch_flatbuffer::Program* flatbuffer_program =
ProgramTestFriend::GetInternalProgram(&program.get());

// Expect 1 segment. 1 mutable segment and no constant segment.
EXPECT_EQ(flatbuffer_program->segments()->size(), 1);

// Expect a mutable data segment.
EXPECT_EQ(flatbuffer_program->mutable_data_segments()->size(), 1);

// Expect the 0 index to be reserved and the offsets for weight and bias of
// linear to be indices 1 and 2.
EXPECT_EQ(
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->size(),
3);
EXPECT_EQ(
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(0),
0);
EXPECT_EQ(
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(1),
0);
EXPECT_EQ(
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(2),
36);

// Loading beyond file should fail
err = ProgramTestFriend::load_mutable_subsegment_into(
&program.get(), 0, 1, 500, buffer);
EXPECT_NE(err, Error::Ok);

// Loading beyond offsets should fail
err = ProgramTestFriend::load_mutable_subsegment_into(
&program.get(), 0, 500, 1, buffer);
EXPECT_NE(err, Error::Ok);

// Loading beyond segments should fail
err = ProgramTestFriend::load_mutable_subsegment_into(
&program.get(), 500, 1, 1, buffer);
EXPECT_NE(err, Error::Ok);
}
1 change: 1 addition & 0 deletions runtime/executor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def define_common_targets(is_fbcode = False):
"ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear-no-constant-segment.pte])",
"ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear.pte])",
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
}

runtime.cxx_test(
Expand Down
Loading