Skip to content

Add API to set inputs independently from execution. #5356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ runtime::Error Module::load_method(
temp_allocator_.get());
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
method_name.c_str(), method_holder.memory_manager.get(), tracer));
method_holder.inputs.resize(method_holder.method->inputs_size());
methods_.emplace(method_name, std::move(method_holder));
}
return runtime::Error::Ok;
Expand All @@ -170,10 +171,19 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
const std::vector<runtime::EValue>& input_values) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& method = methods_.at(method_name).method;
auto& inputs = methods_.at(method_name).inputs;

ET_CHECK_OK_OR_RETURN_ERROR(
method->set_inputs(exec_aten::ArrayRef<runtime::EValue>(
input_values.data(), input_values.size())));
for (size_t i = 0; i < input_values.size(); ++i) {
if (!input_values[i].isNone()) {
inputs[i] = input_values[i];
}
}
for (size_t i = 0; i < inputs.size(); ++i) {
ET_CHECK_OR_RETURN_ERROR(
!inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
}
ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());

const auto outputs_size = method->outputs_size();
Expand All @@ -184,6 +194,30 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
return outputs;
}

runtime::Error Module::set_input(
const std::string& method_name,
const runtime::EValue& input_value,
size_t input_index) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
methods_.at(method_name).inputs.at(input_index) = input_value;
return runtime::Error::Ok;
}

runtime::Error Module::set_inputs(
const std::string& method_name,
const std::vector<runtime::EValue>& input_values) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& inputs = methods_.at(method_name).inputs;
ET_CHECK_OR_RETURN_ERROR(
inputs.size() == input_values.size(),
InvalidArgument,
"input size: %zu does not match method input size: %zu",
input_values.size(),
inputs.size());
inputs = input_values;
return runtime::Error::Ok;
}

runtime::Error Module::set_output_data_ptr(
runtime::EValue output_value,
size_t output_index,
Expand Down
57 changes: 57 additions & 0 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,62 @@ class Module {
return forward(std::vector<runtime::EValue>{});
}

/**
* Sets a single input value for a specific method.
*
* @param[in] method_name The name of the method.
* @param[in] input_value The EValue to set as the method input.
* @param[in] input_index Zero-based index of the input to set.
*
* @returns An Error to indicate success or failure.
*/
ET_NODISCARD
runtime::Error set_input(
const std::string& method_name,
const runtime::EValue& input_value,
size_t input_index);

/**
* Sets a single input value for the "forward" method.
*
* @param[in] input_value The EValue to set as the method input.
* @param[in] input_index Zero-based index of the input to set.
*
* @returns An Error to indicate success or failure.
*/
ET_NODISCARD
inline runtime::Error set_input(
const runtime::EValue& input_value,
size_t input_index) {
return set_input("forward", input_value, input_index);
}

/**
* Sets all input values for a specific method.
*
* @param[in] method_name The name of the method.
* @param[in] input_values A vector of EValues to set as the method inputs.
*
* @returns An Error to indicate success or failure.
*/
ET_NODISCARD
runtime::Error set_inputs(
const std::string& method_name,
const std::vector<runtime::EValue>& input_values);

/**
* Sets all input values for the "forward" method.
*
* @param[in] input_values A vector of EValues to set as the method inputs.
*
* @returns An Error to indicate success or failure.
*/
ET_NODISCARD
inline runtime::Error set_inputs(
const std::vector<runtime::EValue>& input_values) {
return set_inputs("forward", input_values);
}

/**
* Retrieves the EventTracer instance being used by the Module.
* EventTracer is used for tracking and logging events during the execution
Expand Down Expand Up @@ -332,6 +388,7 @@ class Module {
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
std::unique_ptr<runtime::MemoryManager> memory_manager;
std::unique_ptr<runtime::Method> method;
std::vector<runtime::EValue> inputs;
};

private:
Expand Down
48 changes: 48 additions & 0 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,51 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
t4.join();
t5.join();
}

TEST_F(ModuleTest, TestSetInputsBeforeExecute) {
Module module(model_path_);

auto tensor1 = make_tensor_ptr({4.f});
auto tensor2 = make_tensor_ptr({5.f});

EXPECT_EQ(module.set_inputs({tensor1, tensor2}), Error::Ok);

const auto result = module.forward();
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();
EXPECT_NEAR(data[0], 9, 1e-5);
}

TEST_F(ModuleTest, TestSetInputCombinedWithExecute) {
Module module(model_path_);

auto tensor1 = make_tensor_ptr({2.f});
auto tensor2 = make_tensor_ptr({3.f});

EXPECT_EQ(module.set_input(tensor2, 1), Error::Ok);

const auto result = module.forward(tensor1);
EXPECT_EQ(result.error(), Error::Ok);

const auto data = result->at(0).toTensor().const_data_ptr<float>();
EXPECT_NEAR(data[0], 5, 1e-5);
}

TEST_F(ModuleTest, TestPartiallySetInputs) {
Module module(model_path_);

auto tensor = make_tensor_ptr({1.f});

EXPECT_EQ(module.set_input(tensor, 0), Error::Ok);

const auto result = module.forward();
EXPECT_NE(result.error(), Error::Ok);
}

TEST_F(ModuleTest, TestUnsetInputs) {
Module module(model_path_);

const auto result = module.forward();
EXPECT_NE(result.error(), Error::Ok);
}
Loading