Skip to content

Constant segment runtime tests #1505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions runtime/executor/test/method_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class MethodTest : public ::testing::Test {
load_program(std::getenv("ET_MODULE_INDEX_PATH"), "index");
load_program(
std::getenv("ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH"), "cat");
load_program(
std::getenv("ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH"),
"linear_constant_segment");
load_program(
std::getenv("ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH"),
"linear_constant_buffer");
}

private:
Expand Down Expand Up @@ -196,6 +202,30 @@ TEST_F(MethodTest, AliasedIOTest) {
}
}

TEST_F(MethodTest, ConstantSegmentTest) {
// Execute model with constants stored in segment.
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method =
programs_["linear_constant_segment"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

// Can execute the method.
Error err = method->execute();
ASSERT_EQ(err, Error::Ok);
}

TEST_F(MethodTest, ConstantBufferTest) {
// Execute model with constants stored in the program flatbuffer.
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method =
programs_["linear_constant_buffer"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

// Can execute the method.
Error err = method->execute();
ASSERT_EQ(err, Error::Ok);
}

// TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
// the portable op lib

Expand Down
76 changes: 73 additions & 3 deletions runtime/executor/test/program_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/schema/program_generated.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -89,6 +90,11 @@ class ProgramTestFriend final {
size_t index) {
return program->LoadSegment(index);
}

const static executorch_flatbuffer::Program* GetInternalProgram(
const Program* program) {
return program->internal_program_;
}
};
} // namespace testing
} // namespace executor
Expand Down Expand Up @@ -299,9 +305,6 @@ TEST_F(ProgramTest, HeaderNotPresent) {
Program::HeaderStatus::NotPresent);
}

// TODO(T144120904): Add tests for programs with segments once we can generate
// them.

TEST_F(ProgramTest, getMethods) {
// Parse the Program from the data.
Result<Program> program_res =
Expand All @@ -326,3 +329,70 @@ TEST_F(ProgramTest, DEPRECATEDLoad) {
Result<Program> program_res = Program::Load(multi_loader_.get());
EXPECT_EQ(program_res.error(), Error::Ok);
}

TEST_F(ProgramTest, LoadConstantSegment) {
// Load the serialized ModuleLinear data, with constants in the segment and no
// constants in the flatbuffer.
const char* linear_path =
std::getenv("ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH");
Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
ASSERT_EQ(linear_loader.error(), Error::Ok);

// This file should always be compatible.
Result<FreeableBuffer> linear_header =
linear_loader->Load(/*offset=*/0, Program::kMinHeadBytes);
ASSERT_EQ(linear_header.error(), Error::Ok);
EXPECT_EQ(
Program::check_header(linear_header->data(), linear_header->size()),
Program::HeaderStatus::CompatibleVersion);

Result<Program> program = Program::load(&linear_loader.get());
ASSERT_EQ(program.error(), Error::Ok);

// Load constant segment data.
Result<FreeableBuffer> segment =
ProgramTestFriend::LoadSegment(&program.get(), 0);
EXPECT_EQ(segment.error(), Error::Ok);

const executorch_flatbuffer::Program* flatbuffer_program =
ProgramTestFriend::GetInternalProgram(&program.get());

// Expect one segment containing the constants.
EXPECT_EQ(flatbuffer_program->segments()->size(), 1);

// The constant buffer should be empty.
EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0);

// Check constant segment offsets.
EXPECT_EQ(flatbuffer_program->constant_segment()->segment_index(), 0);
EXPECT_GE(flatbuffer_program->constant_segment()->offsets()->size(), 1);
}

TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
// Load the serialized ModuleLinear data, with constants in the flatbuffer and
// no constants in the segment.
const char* linear_path =
std::getenv("ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH");
Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
ASSERT_EQ(linear_loader.error(), Error::Ok);

// This file should always be compatible.
Result<FreeableBuffer> linear_header =
linear_loader->Load(/*offset=*/0, Program::kMinHeadBytes);
ASSERT_EQ(linear_header.error(), Error::Ok);
EXPECT_EQ(
Program::check_header(linear_header->data(), linear_header->size()),
Program::HeaderStatus::CompatibleVersion);

Result<Program> program = Program::load(&linear_loader.get());
ASSERT_EQ(program.error(), Error::Ok);

const executorch_flatbuffer::Program* flatbuffer_program =
ProgramTestFriend::GetInternalProgram(&program.get());

// Expect no segments.
EXPECT_EQ(flatbuffer_program->segments()->size(), 0);

// The constant buffer should exist.
EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1);
}
3 changes: 3 additions & 0 deletions runtime/executor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def define_common_targets(is_fbcode = False):
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
"ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleDynamicCatUnallocatedIO.pte])",
"ET_MODULE_INDEX_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleIndex.pte])",
"ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear-no-constant-segment.pte])",
"ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear.pte])",
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
}

Expand Down Expand Up @@ -140,6 +142,7 @@ def define_common_targets(is_fbcode = False):
"//executorch/runtime/executor:program",
"//executorch/extension/data_loader:buffer_data_loader",
"//executorch/extension/data_loader:file_data_loader",
"//executorch/schema:program",
],
env = modules_env,
)
Expand Down