
Support mutable tensors in TensorParser #4713


Merged: 1 commit, Aug 16, 2024
6 changes: 6 additions & 0 deletions runtime/executor/program.h
@@ -35,6 +35,11 @@ namespace testing {
class ProgramTestFriend;
} // namespace testing

namespace deserialization {
// Provides Tensor deserialization access to private Program methods.
class TensorParser;
} // namespace deserialization

/**
* A deserialized ExecuTorch program binary.
*/
@@ -194,6 +199,7 @@ class Program final {
friend class BackendDelegate;
friend class Executor;
friend class Method;
friend class deserialization::TensorParser;
friend class testing::ProgramTestFriend;

const executorch_flatbuffer::Program* get_internal_program() const {
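For readers unfamiliar with this idiom: the forward declaration of `deserialization::TensorParser` plus the `friend class` declaration gives exactly one named collaborator access to `Program`'s private members without widening the public API. Below is a minimal, self-contained sketch of the same pattern with toy types (the names `Demo` and `load_subsegment` are illustrative, not part of ExecuTorch):

```cpp
#include <cstddef>
#include <iostream>

namespace demo {

namespace deserialization {
// Forward declaration so the friend below can be named before it is defined.
class TensorParser;
} // namespace deserialization

class Demo final {
 private:
  // Private: only friends of Demo may call this.
  int load_subsegment(std::size_t index) const {
    return static_cast<int>(index) * 2;
  }

  friend class deserialization::TensorParser;
};

namespace deserialization {
// The friend wraps the private call in a static method, so other code can go
// through TensorParser without itself being a friend of Demo.
class TensorParser final {
 public:
  static int load_subsegment(const Demo& d, std::size_t index) {
    return d.load_subsegment(index);
  }
};
} // namespace deserialization

} // namespace demo

int main() {
  demo::Demo d;
  // d.load_subsegment(3);  // would not compile: the method is private
  std::cout << demo::deserialization::TensorParser::load_subsegment(d, 3) << "\n";
  return 0;
}
```

The real change has the same shape: `program.h` forward-declares and befriends `deserialization::TensorParser`, and `tensor_parser_exec_aten.cpp` below defines it with a single static forwarding method.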
31 changes: 24 additions & 7 deletions runtime/executor/tensor_parser_exec_aten.cpp
@@ -19,6 +19,20 @@ namespace torch {
namespace executor {
namespace deserialization {

// Provides access to private Program methods.
class TensorParser final {
public:
__ET_NODISCARD static Error load_mutable_subsegment_into(
const Program* program,
size_t mutable_data_segments_index,
size_t offset_index,
size_t size,
void* buffer) {
return program->load_mutable_subsegment_into(
mutable_data_segments_index, offset_index, size, buffer);
}
};

namespace {

// Retrieve the buffer specified by the allocation_info
@@ -94,14 +108,17 @@ __ET_NODISCARD Result<void*> getTensorDataPtr(

// Memory Planned, with initial state
if (data_buffer_idx > 0 && allocation_info != nullptr) {
// Stub case for now.

// Get memory planned data pointer

// Call something like program.load_into_buffer(s_tensor->segment_idx,
// s_tensor->data_buffer_idx, mem_planned_buffer, nbytes)
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
if (!planned_ptr.ok()) {
return planned_ptr.error();
}
auto err = TensorParser::load_mutable_subsegment_into(
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());

return Error::NotImplemented;
if (err != Error::Ok) {
return err;
}
return planned_ptr;

// Constant
} else if (data_buffer_idx > 0 && allocation_info == nullptr) {
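Put differently, the previously stubbed branch now (1) resolves the tensor's memory-planned buffer via `getMemPlannedPtr`, (2) asks the `Program` to copy the tensor's serialized initial state out of mutable data segment `0` into that buffer through the new `TensorParser::load_mutable_subsegment_into` friend wrapper, and (3) returns the planned pointer. Here is a rough, self-contained sketch of that flow with toy stand-in types (the real code uses `Program`, `MemoryAllocator`, and `Result<void*>`, and passes an offset-table index rather than a raw byte offset; everything prefixed `Toy` or `toy_` is illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Toy stand-ins for the runtime types involved in the changed branch.
enum class ToyError { Ok, OutOfRange };

struct ToyProgram {
  // Pretend mutable data segment holding serialized initial tensor values.
  std::vector<std::uint8_t> mutable_segment;

  // Analogue of Program::load_mutable_subsegment_into(): copy `size` bytes of
  // initial state into `buffer`. (The real method takes an index into an
  // offset table, not a byte offset; this is simplified.)
  ToyError load_mutable_subsegment_into(std::size_t /*segment_index*/,
                                        std::size_t byte_offset,
                                        std::size_t size,
                                        void* buffer) const {
    if (byte_offset + size > mutable_segment.size()) {
      return ToyError::OutOfRange;
    }
    std::memcpy(buffer, mutable_segment.data() + byte_offset, size);
    return ToyError::Ok;
  }
};

// Analogue of the new "memory planned, with initial state" branch in
// getTensorDataPtr(): pre-populate the planned buffer, then hand it out.
void* toy_get_mutable_tensor_data_ptr(const ToyProgram& program,
                                      std::vector<std::uint8_t>& planned_buffer,
                                      std::size_t byte_offset,
                                      std::size_t nbytes) {
  if (program.load_mutable_subsegment_into(
          /*segment_index=*/0, byte_offset, nbytes, planned_buffer.data()) !=
      ToyError::Ok) {
    return nullptr;  // the real code propagates the Error instead
  }
  return planned_buffer.data();
}
```

Because the initial state lands in the memory-planned buffer, later writes through `mutable_data_ptr()` persist until the tensor is parsed again, which is exactly the behavior the new test below exercises.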
69 changes: 69 additions & 0 deletions runtime/executor/test/tensor_parser_test.cpp
@@ -120,3 +120,72 @@ TEST_F(TensorParserTest, TestModuleAddHalf) {
torch::executor::ScalarType::Half,
sizeof(torch::executor::Half));
}

TEST_F(TensorParserTest, TestMutableState) {
// Load the serialized ModuleSimpleTrain data.
const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
Result<FileDataLoader> train_loader = FileDataLoader::from(path);
ASSERT_EQ(train_loader.error(), Error::Ok);

Result<Program> program =
Program::load(&train_loader.get(), Program::Verification::Minimal);
EXPECT_EQ(program.error(), Error::Ok);

ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
ManagedMemoryManager mmm_copy(
kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

const executorch_flatbuffer::Program* internal_program =
ProgramTestFriend::GetInternalProgram(&program.get());
executorch_flatbuffer::ExecutionPlan* execution_plan =
internal_program->execution_plan()->GetMutableObject(0);
auto flatbuffer_values = execution_plan->values();

size_t num_mutable_tensors = 0;
for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
auto serialization_value = flatbuffer_values->Get(i);
if (serialization_value->val_type() ==
executorch_flatbuffer::KernelTypes::Tensor &&
serialization_value->val_as_Tensor()->allocation_info() != nullptr &&
serialization_value->val_as_Tensor()->data_buffer_idx() > 0) {
num_mutable_tensors++;
Result<torch::executor::Tensor> tensor = parseTensor(
&program.get(), &mmm.get(), serialization_value->val_as_Tensor());
torch::executor::Tensor t = tensor.get();
float loaded_value = t.const_data_ptr<float>()[0];
ASSERT_NE(nullptr, t.const_data_ptr());
ASSERT_NE(t.mutable_data_ptr<float>()[0], 0.5);
t.mutable_data_ptr<float>()[0] = 0.5;
ASSERT_EQ(
t.mutable_data_ptr<float>()[0],
0.5); // 0.5 is exactly representable as a float, so EQ and NE work
// fine here. Dyadic rationals (ratios with a power-of-two
// denominator) are exactly representable within float's precision.

// Load the same tensor using the same mem manager and show the value is
// updated again.
Result<torch::executor::Tensor> tensor1_alias = parseTensor(
&program.get(), &mmm.get(), serialization_value->val_as_Tensor());
torch::executor::Tensor t2 = tensor1_alias.get();
ASSERT_NE(t2.mutable_data_ptr<float>()[0], 0.5);

// Show the tensors are equivalent
ASSERT_EQ(t.const_data_ptr(), t2.const_data_ptr());
// Set mutable tensor value back to 0.5 since it got overwritten by second
// parse.
t.mutable_data_ptr<float>()[0] = 0.5;

// Load the same tensor using a different mem manager and show the value
// is not the same as t.
Result<torch::executor::Tensor> tensor_new = parseTensor(
&program.get(),
&mmm_copy.get(),
serialization_value->val_as_Tensor());
torch::executor::Tensor t3 = tensor_new.get();
ASSERT_NE(t3.mutable_data_ptr<float>()[0], 0.5);
ASSERT_NE(t3.const_data_ptr(), t.const_data_ptr());
ASSERT_EQ(loaded_value, t3.const_data_ptr<float>()[0]);
}
}
ASSERT_EQ(num_mutable_tensors, 2);
}
4 changes: 3 additions & 1 deletion test/run_oss_cpp_tests.sh
@@ -52,7 +52,7 @@ build_gtest() {
}

export_test_model() {
python3 -m test.models.export_program --modules "ModuleAdd,ModuleAddHalf,ModuleDynamicCatUnallocatedIO,ModuleIndex,ModuleLinear,ModuleMultipleEntry" --outdir "cmake-out" 2> /dev/null
python3 -m test.models.export_program --modules "ModuleAdd,ModuleAddHalf,ModuleDynamicCatUnallocatedIO,ModuleIndex,ModuleLinear,ModuleMultipleEntry,ModuleSimpleTrain" --outdir "cmake-out" 2> /dev/null
python3 -m test.models.export_delegated_program --modules "ModuleAddMul" --backend_id "StubBackend" --outdir "cmake-out" || true

ET_MODULE_ADD_HALF_PATH="$(realpath cmake-out/ModuleAddHalf.pte)"
@@ -65,6 +65,7 @@ export_test_model() {
ET_MODULE_ADD_MUL_NOSEGMENTS_DA1024_PATH="$(realpath cmake-out/ModuleAddMul-nosegments-da1024.pte)"
ET_MODULE_ADD_MUL_NOSEGMENTS_PATH="$(realpath cmake-out/ModuleAddMul-nosegments.pte)"
ET_MODULE_ADD_MUL_PATH="$(realpath cmake-out/ModuleAddMul.pte)"
ET_MODULE_SIMPLE_TRAIN_PATH="$(realpath cmake-out/ModuleSimpleTrain.pte)"
export ET_MODULE_ADD_HALF_PATH
export ET_MODULE_ADD_PATH
export ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH
@@ -75,6 +76,7 @@ export_test_model() {
export ET_MODULE_ADD_MUL_NOSEGMENTS_DA1024_PATH
export ET_MODULE_ADD_MUL_NOSEGMENTS_PATH
export ET_MODULE_ADD_MUL_PATH
export ET_MODULE_SIMPLE_TRAIN_PATH
}

build_and_run_test() {