File tree Expand file tree Collapse file tree 2 files changed +50
-0
lines changed Expand file tree Collapse file tree 2 files changed +50
-0
lines changed Original file line number Diff line number Diff line change @@ -158,6 +158,41 @@ class Module final {
158
158
return execute (method_name, {});
159
159
}
160
160
161
+ /* *
162
+ * Retrieve the output value of a specific method with the given input.
163
+ * Loads the program and method before execution if needed.
164
+ *
165
+ * @param[in] method_name The name of the method to execute.
166
+ * @param[in] input A vector of input values to be passed to the method.
167
+ *
168
+ * @returns A Result object containing either the first output value from the
169
+ * method or an error to indicate failure.
170
+ */
171
+ __ET_NODISCARD
172
+ Result<EValue> get (
173
+ const std::string& method_name,
174
+ const std::vector<EValue>& input) {
175
+ auto result = ET_UNWRAP (execute (method_name, input));
176
+ if (result.empty ()) {
177
+ return Error::InvalidArgument;
178
+ }
179
+ return result[0 ];
180
+ }
181
+
182
+ /* *
183
+ * Retrieve the output value of a specific method without any input values.
184
+ * Loads the program and method before execution if needed.
185
+ *
186
+ * @param[in] method_name The name of the method to execute.
187
+ *
188
+ * @returns A Result object containing either the first output value from the
189
+ * method or an error to indicate failure.
190
+ */
191
+ __ET_NODISCARD
192
+ Result<EValue> get (const std::string& method_name) {
193
+ return get (method_name, {});
194
+ }
195
+
161
196
/* *
162
197
* Execute the 'forward' method with the given input and retrieve output.
163
198
* Loads the program and method before executing if needed.
Original file line number Diff line number Diff line change @@ -202,6 +202,21 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
202
202
EXPECT_FALSE (result.ok ());
203
203
}
204
204
205
+ TEST_F (ModuleTest, TestGet) {
206
+ Module module (std::getenv (" RESOURCES_PATH" ) + std::string (" /model.pte" ));
207
+
208
+ std::array<float , 2 > input{1 , 2 };
209
+ std::array<int32_t , 2 > sizes{1 , 2 };
210
+ TensorImpl tensor (
211
+ ScalarType::Float, sizes.size (), sizes.data (), input.data ());
212
+
213
+ const auto result = module .get (" forward" , {EValue (Tensor (&tensor))});
214
+
215
+ EXPECT_TRUE (result.ok ());
216
+ const auto data = result->toTensor ().const_data_ptr <float >();
217
+ EXPECT_NEAR (data[0 ], 1.5 , 1e-5 );
218
+ }
219
+
205
220
TEST_F (ModuleTest, TestForward) {
206
221
auto module = std::make_unique<Module>(
207
222
std::getenv (" RESOURCES_PATH" ) + std::string (" /model.pte" ));
You can’t perform that action at this time.
0 commit comments