Skip to content

Commit f560682

Browse files
authored
Allow single EValue to be passed to Module execute.
Differential Revision: D61799496 Pull Request resolved: #4907
1 parent 6feb639 commit f560682

File tree

7 files changed

+74
-23
lines changed

7 files changed

+74
-23
lines changed

docs/source/extension-module.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Tensor::SizesType sizes[] = {1, 3, 256, 256};
2222
TensorImpl tensor(ScalarType::Float, std::size(sizes), sizes, input);
2323

2424
// Perform an inference.
25-
const auto result = module.forward({EValue(Tensor(&tensor))});
25+
const auto result = module.forward(Tensor(&tensor));
2626

2727
// Check for success or failure.
2828
if (result.ok()) {
@@ -105,13 +105,13 @@ Note: `method_meta()` will try to force-load the `Method` when called for the fi
105105
Assuming that the `Program`'s method names and their input format is known ahead of time, we rarely need to query for those and can run the methods directly by name using the `execute()` function:
106106

107107
```cpp
108-
const auto result = module.execute("forward", {EValue(Tensor(&tensor))});
108+
const auto result = module.execute("forward", Tensor(&tensor));
109109
```
110110

111111
Which can also be simplified for the standard `forward()` method name as:
112112

113113
```cpp
114-
const auto result = module.forward({EValue(Tensor(&tensor))});
114+
const auto result = module.forward(Tensor(&tensor));
115115
```
116116

117117
Note: `execute()` or `forward()` will try to force load the `Program` and the `Method` when called for the first time. Therefore, the first inference will take more time than subsequent ones as it loads the model lazily and prepares it for execution unless the `Program` or `Method` was loaded explicitly earlier using the corresponding functions.

examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo/Sources/MobileNet/MobileNetClassifier.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ - (BOOL)classifyWithInput:(float*)input
3535
error:(NSError**)error {
3636
int32_t sizes[] = {1, kChannels, kSize, kSize};
3737
TensorImpl inputTensor(ScalarType::Float, std::size(sizes), sizes, input);
38-
const auto result = _module->forward({EValue(Tensor(&inputTensor))});
38+
const auto result = _module->forward(Tensor(&inputTensor));
3939

4040
if (!result.ok()) {
4141
if (error) {

examples/models/llava/runner/llava_image_prefiller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
3030
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
3131
// Run image encoder
3232
std::vector<EValue> image_encoder_outputs = ET_UNWRAP(module_->execute(
33-
kImageEncoderMethod, {managed_images.get_aliasing_tensor()}));
33+
kImageEncoderMethod, managed_images.get_aliasing_tensor()));
3434

3535
// inputs:[start_pos, embeds]
3636
ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class LlavaTextDecoderRunner : public TextDecoderRunner {
2727

2828
// run token embedding
2929
std::vector<EValue> token_embedding_outputs =
30-
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, {tokens}));
30+
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));
3131

3232
// run text model
3333
std::vector<EValue> outputs_res = ET_UNWRAP(module_->execute(

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ ::executorch::runtime::Result<exec_aten::Tensor> TextDecoderRunner::step(
6060
(void)managed_start_pos; // unused
6161

6262
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
63-
outputs_res = module_->forward({tokens});
63+
outputs_res = module_->forward(tokens);
6464
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
6565
ET_CHECK_MSG(
6666
outputs_res.get().size() == 1,

extension/module/module.h

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,25 @@ class Module final {
181181
const std::string& method_name,
182182
const std::vector<::executorch::runtime::EValue>& input);
183183

184+
/**
185+
* Execute a specific method with a single input value.
186+
* Loads the program and method before executing if needed.
187+
*
188+
* @param[in] method_name The name of the method to execute.
189+
* @param[in] input A value to be passed to the method.
190+
*
191+
* @returns A Result object containing either a vector of output values
192+
* from the method or an error to indicate failure.
193+
*/
194+
ET_NODISCARD
195+
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
196+
execute(
197+
const std::string& method_name,
198+
const ::executorch::runtime::EValue& input) {
199+
return execute(
200+
method_name, std::vector<::executorch::runtime::EValue>{input});
201+
}
202+
184203
/**
185204
* Execute a specific method without any input values.
186205
* Loads the program and method before executing if needed.
@@ -193,7 +212,7 @@ class Module final {
193212
ET_NODISCARD
194213
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
195214
execute(const std::string& method_name) {
196-
return execute(method_name, {});
215+
return execute(method_name, std::vector<::executorch::runtime::EValue>{});
197216
}
198217

199218
/**
@@ -217,6 +236,23 @@ class Module final {
217236
return result[0];
218237
}
219238

239+
/**
240+
* Retrieve the output value of a specific method with a single input value.
241+
* Loads the program and method before execution if needed.
242+
*
243+
* @param[in] method_name The name of the method to execute.
244+
* @param[in] input A value to be passed to the method.
245+
*
246+
* @returns A Result object containing either the first output value from the
247+
* method or an error to indicate failure.
248+
*/
249+
ET_NODISCARD
250+
::executorch::runtime::Result<::executorch::runtime::EValue> get(
251+
const std::string& method_name,
252+
const ::executorch::runtime::EValue& input) {
253+
return get(method_name, std::vector<::executorch::runtime::EValue>{input});
254+
}
255+
220256
/**
221257
* Retrieve the output value of a specific method without any input values.
222258
* Loads the program and method before execution if needed.
@@ -229,7 +265,7 @@ class Module final {
229265
ET_NODISCARD
230266
::executorch::runtime::Result<::executorch::runtime::EValue> get(
231267
const std::string& method_name) {
232-
return get(method_name, {});
268+
return get(method_name, std::vector<::executorch::runtime::EValue>{});
233269
}
234270

235271
/**
@@ -247,6 +283,21 @@ class Module final {
247283
return execute("forward", input);
248284
}
249285

286+
/**
287+
* Execute the 'forward' method with a single value.
288+
* Loads the program and method before executing if needed.
289+
*
290+
* @param[in] input A value for the 'forward' method.
291+
*
292+
* @returns A Result object containing either a vector of output values
293+
* from the 'forward' method or an error to indicate failure.
294+
*/
295+
ET_NODISCARD
296+
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
297+
forward(const ::executorch::runtime::EValue& input) {
298+
return forward(std::vector<::executorch::runtime::EValue>{input});
299+
}
300+
250301
/**
251302
* Execute the 'forward' method without any input values.
252303
* Loads the program and method before executing if needed.
@@ -257,7 +308,7 @@ class Module final {
257308
ET_NODISCARD
258309
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
259310
forward() {
260-
return forward({});
311+
return forward(std::vector<::executorch::runtime::EValue>{});
261312
}
262313

263314
/**

extension/module/test/module_test.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ TEST_F(ModuleTest, TestExecute) {
129129
TensorImpl tensor(
130130
ScalarType::Float, sizes.size(), sizes.data(), input.data());
131131

132-
const auto result = module.execute("forward", {EValue(Tensor(&tensor))});
132+
const auto result = module.execute("forward", Tensor(&tensor));
133133
EXPECT_TRUE(result.ok());
134134
EXPECT_TRUE(module.is_loaded());
135135
EXPECT_TRUE(module.is_method_loaded("forward"));
@@ -150,7 +150,7 @@ TEST_F(ModuleTest, TestExecutePreload) {
150150
TensorImpl tensor(
151151
ScalarType::Float, sizes.size(), sizes.data(), input.data());
152152

153-
const auto result = module.execute("forward", {EValue(Tensor(&tensor))});
153+
const auto result = module.execute("forward", Tensor(&tensor));
154154
EXPECT_TRUE(result.ok());
155155

156156
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -169,7 +169,7 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
169169
TensorImpl tensor(
170170
ScalarType::Float, sizes.size(), sizes.data(), input.data());
171171

172-
const auto result = module.execute("forward", {EValue(Tensor(&tensor))});
172+
const auto result = module.execute("forward", Tensor(&tensor));
173173
EXPECT_TRUE(result.ok());
174174

175175
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -191,7 +191,7 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
191191
TensorImpl tensor(
192192
ScalarType::Float, sizes.size(), sizes.data(), input.data());
193193

194-
const auto result = module.execute("forward", {EValue(Tensor(&tensor))});
194+
const auto result = module.execute("forward", Tensor(&tensor));
195195
EXPECT_TRUE(result.ok());
196196

197197
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -223,7 +223,7 @@ TEST_F(ModuleTest, TestGet) {
223223
TensorImpl tensor(
224224
ScalarType::Float, sizes.size(), sizes.data(), input.data());
225225

226-
const auto result = module.get("forward", {EValue(Tensor(&tensor))});
226+
const auto result = module.get("forward", Tensor(&tensor));
227227

228228
EXPECT_TRUE(result.ok());
229229
const auto data = result->toTensor().const_data_ptr<float>();
@@ -237,7 +237,7 @@ TEST_F(ModuleTest, TestForward) {
237237
std::array<int32_t, 2> sizes{1, 2};
238238
TensorImpl tensor(
239239
ScalarType::Float, sizes.size(), sizes.data(), input.data());
240-
const auto result = module->forward({EValue(Tensor(&tensor))});
240+
const auto result = module->forward(Tensor(&tensor));
241241
EXPECT_TRUE(result.ok());
242242

243243
const auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -247,7 +247,7 @@ TEST_F(ModuleTest, TestForward) {
247247
std::array<float, 2> input2{2, 3};
248248
TensorImpl tensor2(
249249
ScalarType::Float, sizes.size(), sizes.data(), input2.data());
250-
const auto result2 = module->forward({EValue(Tensor(&tensor2))});
250+
const auto result2 = module->forward(Tensor(&tensor2));
251251
EXPECT_TRUE(result2.ok());
252252

253253
const auto data2 = result->at(0).toTensor().const_data_ptr<float>();
@@ -258,7 +258,7 @@ TEST_F(ModuleTest, TestForward) {
258258
TEST_F(ModuleTest, TestForwardWithInvalidInputs) {
259259
Module module(model_path_);
260260

261-
const auto result = module.forward({EValue()});
261+
const auto result = module.forward(EValue());
262262

263263
EXPECT_FALSE(result.ok());
264264
}
@@ -308,18 +308,18 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
308308
TensorImpl tensor(
309309
ScalarType::Float, sizes.size(), sizes.data(), input.data());
310310

311-
auto result1 = module1->execute("forward", {EValue(Tensor(&tensor))});
311+
auto result1 = module1->execute("forward", Tensor(&tensor));
312312
EXPECT_TRUE(result1.ok());
313313

314314
auto module2 = std::make_unique<Module>(module1->program());
315315

316-
auto result2 = module2->execute("forward", {EValue(Tensor(&tensor))});
316+
auto result2 = module2->execute("forward", Tensor(&tensor));
317317
EXPECT_TRUE(result2.ok());
318318

319319
module1 = std::make_unique<Module>("/path/to/nonexistent/file.pte");
320320
EXPECT_FALSE(module1->is_loaded());
321321

322-
auto result3 = module2->execute("forward", {EValue(Tensor(&tensor))});
322+
auto result3 = module2->execute("forward", Tensor(&tensor));
323323
EXPECT_TRUE(result3.ok());
324324
}
325325

@@ -356,7 +356,7 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
356356
TensorImpl tensor(
357357
ScalarType::Float, sizes.size(), sizes.data(), input.data());
358358

359-
auto result = module.execute("forward", {EValue(Tensor(&tensor))});
359+
auto result = module.execute("forward", Tensor(&tensor));
360360
EXPECT_TRUE(result.ok());
361361

362362
auto data = result->at(0).toTensor().const_data_ptr<float>();
@@ -385,7 +385,7 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
385385
TensorImpl tensor(
386386
ScalarType::Float, sizes.size(), sizes.data(), (void*)input.data());
387387

388-
const auto result = module.forward({EValue(Tensor(&tensor))});
388+
const auto result = module.forward(Tensor(&tensor));
389389
EXPECT_TRUE(result.ok());
390390

391391
const auto data = result->at(0).toTensor().const_data_ptr<float>();

0 commit comments

Comments
 (0)