Skip to content

Commit 6982c03

Browse files
authored
Add the get method to Module to return a single EValue.
Differential Revision: D61170093 Pull Request resolved: #4686
1 parent 85b7869 commit 6982c03

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

extension/module/module.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,41 @@ class Module final {
158158
return execute(method_name, {});
159159
}
160160

161+
/**
162+
* Retrieve the output value of a specific method with the given input.
163+
* Loads the program and method before execution if needed.
164+
*
165+
* @param[in] method_name The name of the method to execute.
166+
* @param[in] input A vector of input values to be passed to the method.
167+
*
168+
* @returns A Result object containing either the first output value from the
169+
* method or an error to indicate failure.
170+
*/
171+
__ET_NODISCARD
172+
Result<EValue> get(
173+
const std::string& method_name,
174+
const std::vector<EValue>& input) {
175+
auto result = ET_UNWRAP(execute(method_name, input));
176+
if (result.empty()) {
177+
return Error::InvalidArgument;
178+
}
179+
return result[0];
180+
}
181+
182+
/**
183+
* Retrieve the output value of a specific method without any input values.
184+
* Loads the program and method before execution if needed.
185+
*
186+
* @param[in] method_name The name of the method to execute.
187+
*
188+
* @returns A Result object containing either the first output value from the
189+
* method or an error to indicate failure.
190+
*/
191+
__ET_NODISCARD
192+
Result<EValue> get(const std::string& method_name) {
193+
return get(method_name, {});
194+
}
195+
161196
/**
162197
* Execute the 'forward' method with the given input and retrieve output.
163198
* Loads the program and method before executing if needed.

extension/module/test/module_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,21 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
202202
EXPECT_FALSE(result.ok());
203203
}
204204

205+
TEST_F(ModuleTest, TestGet) {
206+
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
207+
208+
std::array<float, 2> input{1, 2};
209+
std::array<int32_t, 2> sizes{1, 2};
210+
TensorImpl tensor(
211+
ScalarType::Float, sizes.size(), sizes.data(), input.data());
212+
213+
const auto result = module.get("forward", {EValue(Tensor(&tensor))});
214+
215+
EXPECT_TRUE(result.ok());
216+
const auto data = result->toTensor().const_data_ptr<float>();
217+
EXPECT_NEAR(data[0], 1.5, 1e-5);
218+
}
219+
205220
TEST_F(ModuleTest, TestForward) {
206221
auto module = std::make_unique<Module>(
207222
std::getenv("RESOURCES_PATH") + std::string("/model.pte"));

0 commit comments

Comments
 (0)