Skip to content

Commit 3227b42

Browse files
shoumikhin authored and
facebook-github-bot committed
Simplify setting output. (#5334)
Summary: Pull Request resolved: #5334 Reviewed By: dbort Differential Revision: D62617066
1 parent 0d0b14a commit 3227b42

File tree

7 files changed

+92
-63
lines changed

7 files changed

+92
-63
lines changed

examples/qualcomm/oss_scripts/llama2/runner/runner.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
187187
*kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
188188
"Failed to set output tensor when updating v_cache");
189189
ET_CHECK_MSG(
190-
module_->set_output_data_ptr(*kv_outputs[j], j + 1) == Error::Ok,
190+
module_->set_output(kv_outputs[j], j + 1) == Error::Ok,
191191
"Failed to set llama output data pointer");
192192
}
193193

@@ -291,7 +291,7 @@ Error Runner::generate(
291291
sizes,
292292
kv_tensors.back()->scalar_type()));
293293
ET_CHECK_MSG(
294-
module_->set_output_data_ptr(kv_outputs.back(), i + 1) == Error::Ok,
294+
module_->set_output(kv_outputs.back(), i + 1) == Error::Ok,
295295
"Failed to set output tensor for kv cache");
296296
}
297297

@@ -323,7 +323,7 @@ Error Runner::generate(
323323
sizes,
324324
kv_tensors.back()->scalar_type()));
325325
ET_CHECK_MSG(
326-
module_->set_output_data_ptr(kv_outputs.back(), output_index) ==
326+
module_->set_output(kv_outputs.back(), output_index) ==
327327
Error::Ok,
328328
"Failed to set output tensor for llama block");
329329
}
@@ -333,7 +333,7 @@ Error Runner::generate(
333333
logits_data_shape,
334334
ScalarType::Float);
335335
ET_CHECK_MSG(
336-
module_->set_output_data_ptr(affine_logits, 0) == Error::Ok,
336+
module_->set_output(affine_logits) == Error::Ok,
337337
"Failed to set output tensor for affine module - logits");
338338

339339
// Start consuming user's prompts and generating new tokens

examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ void KVCachedMemory::update_io(
427427
// k, v are placed interleaved
428428
int index = (cache_stride << 1) + (cache_group << 5) + head;
429429
ET_CHECK_MSG(
430-
modules_[shard]->set_output_data_ptr(
430+
modules_[shard]->set_output(
431431
output_tensors[shard][index], index) == Error::Ok,
432432
"failed to set output tensor for module %d's %d'th output "
433433
"while updating kv_cache output tensors",
@@ -450,8 +450,7 @@ void KVCachedMemory::update_io(
450450
for (int shard = 0; shard < output_tensors.size(); shard++) {
451451
for (int index = 0; index < output_tensors[shard].size(); index++) {
452452
ET_CHECK_MSG(
453-
modules_[shard]->set_output_data_ptr(
454-
output_tensors[shard][index], index) == Error::Ok,
453+
modules_[shard]->set_output(output_tensors[shard][index], index) == Error::Ok,
455454
"failed to set output tensor for module %d's %d'th output "
456455
"while updating kv_cache output tensors",
457456
shard,

examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,7 @@ Error Runner::generate(
177177
output_tensors.emplace_back(io_mem_->get_output_tensors(i));
178178
for (size_t j = 0; j < output_tensors[i].size(); ++j) {
179179
ET_CHECK_MSG(
180-
modules_[i]->set_output_data_ptr(output_tensors[i][j], j) ==
181-
Error::Ok,
180+
modules_[i]->set_output(output_tensors[i][j], j) == Error::Ok,
182181
"failed to set output tensor for module %d's %zu'th output",
183182
i,
184183
j);

examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,11 @@ Error Runner::generate(std::string prompt) {
373373
uncond_emb_vec.data(),
374374
{1, 77, 1024},
375375
encoder_method_meta.output_tensor_meta(0)->scalar_type());
376-
modules_[0]->set_output_data_ptr(cond_emb_tensor, 0);
376+
modules_[0]->set_output(cond_emb_tensor);
377377
long encoder_start = util::time_in_ms();
378378
auto cond_res = modules_[0]->forward(cond_tokens_tensor);
379379
stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start);
380-
modules_[0]->set_output_data_ptr(uncond_emb_tensor, 0);
380+
modules_[0]->set_output(uncond_emb_tensor);
381381
encoder_start = util::time_in_ms();
382382
auto uncond_res = modules_[0]->forward(uncond_tokens_tensor);
383383
stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start);
@@ -462,13 +462,13 @@ Error Runner::generate(std::string prompt) {
462462

463463
stats_.unet_aggregate_post_processing_time +=
464464
(util::time_in_ms() - start_post_process);
465-
modules_[1]->set_output_data_ptr(noise_pred_text_tensor, 0);
465+
modules_[1]->set_output(noise_pred_text_tensor);
466466
long start_unet_execution = util::time_in_ms();
467467
auto cond_res = modules_[1]->forward(
468468
{latent_tensor, time_emb_tensors[step_index], cond_emb_tensor});
469469
stats_.unet_aggregate_execution_time +=
470470
(util::time_in_ms() - start_unet_execution);
471-
modules_[1]->set_output_data_ptr(noise_pred_uncond_tensor, 0);
471+
modules_[1]->set_output(noise_pred_uncond_tensor);
472472
start_unet_execution = util::time_in_ms();
473473
auto uncond_res = modules_[1]->forward(
474474
{latent_tensor,
@@ -519,7 +519,7 @@ Error Runner::generate(std::string prompt) {
519519

520520
quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_);
521521

522-
modules_[2]->set_output_data_ptr(output_tensor, 0);
522+
modules_[2]->set_output(output_tensor);
523523
long start_vae_execution = util::time_in_ms();
524524
auto vae_res = modules_[2]->forward(vae_input_tensor);
525525
stats_.vae_execution_time = (util::time_in_ms() - start_vae_execution);

extension/module/module.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,13 @@ runtime::Result<runtime::MethodMeta> Module::method_meta(
167167

168168
runtime::Result<std::vector<runtime::EValue>> Module::execute(
169169
const std::string& method_name,
170-
const std::vector<runtime::EValue>& input) {
170+
const std::vector<runtime::EValue>& input_values) {
171171
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
172172
auto& method = methods_.at(method_name).method;
173173

174-
ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
175-
exec_aten::ArrayRef<runtime::EValue>(input.data(), input.size())));
174+
ET_CHECK_OK_OR_RETURN_ERROR(
175+
method->set_inputs(exec_aten::ArrayRef<runtime::EValue>(
176+
input_values.data(), input_values.size())));
176177
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
177178

178179
const auto outputs_size = method->outputs_size();
@@ -183,13 +184,18 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
183184
return outputs;
184185
}
185186

186-
runtime::Error Module::set_output_data_ptr(
187+
runtime::Error Module::set_output(
188+
const std::string& method_name,
187189
runtime::EValue output_value,
188-
size_t output_index,
189-
const std::string& method_name) {
190+
size_t output_index) {
190191
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
191-
auto& output_tensor = output_value.toTensor();
192192
auto& method = methods_.at(method_name).method;
193+
ET_CHECK_OR_RETURN_ERROR(
194+
output_value.isTensor(),
195+
InvalidArgument,
196+
"output type: %zu is not tensor",
197+
(size_t)output_value.tag);
198+
const auto& output_tensor = output_value.toTensor();
193199
return method->set_output_data_ptr(
194200
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
195201
}

extension/module/module.h

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,35 @@ class Module {
165165
const std::string& method_name);
166166

167167
/**
168-
* Execute a specific method with the given input and retrieve output.
169-
* Loads the program and method before executing if needed.
168+
* Execute a specific method with the given input values and retrieve the
169+
* output values. Loads the program and method before executing if needed.
170170
*
171171
* @param[in] method_name The name of the method to execute.
172-
* @param[in] input A vector of input values to be passed to the method.
172+
* @param[in] input_values A vector of input values to be passed to the
173+
* method.
173174
*
174175
* @returns A Result object containing either a vector of output values
175176
* from the method or an error to indicate failure.
176177
*/
177178
ET_NODISCARD
178179
runtime::Result<std::vector<runtime::EValue>> execute(
179180
const std::string& method_name,
180-
const std::vector<runtime::EValue>& input);
181+
const std::vector<runtime::EValue>& input_values);
181182

182183
/**
183184
* Execute a specific method with a single input value.
184185
* Loads the program and method before executing if needed.
185186
*
186187
* @param[in] method_name The name of the method to execute.
187-
* @param[in] input A value to be passed to the method.
188+
* @param[in] input_value A value to be passed to the method.
188189
*
189190
* @returns A Result object containing either a vector of output values
190191
* from the method or an error to indicate failure.
191192
*/
192193
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
193194
const std::string& method_name,
194-
const runtime::EValue& input) {
195-
return execute(method_name, std::vector<runtime::EValue>{input});
195+
const runtime::EValue& input_value) {
196+
return execute(method_name, std::vector<runtime::EValue>{input_value});
196197
}
197198

198199
/**
@@ -210,19 +211,20 @@ class Module {
210211
}
211212

212213
/**
213-
* Retrieve the output value of a specific method with the given input.
214+
* Retrieve the output value of a specific method with the given input values.
214215
* Loads the program and method before execution if needed.
215216
*
216217
* @param[in] method_name The name of the method to execute.
217-
* @param[in] input A vector of input values to be passed to the method.
218+
* @param[in] input_values A vector of input values to be passed to the
219+
* method.
218220
*
219221
* @returns A Result object containing either the first output value from the
220222
* method or an error to indicate failure.
221223
*/
222224
ET_NODISCARD inline runtime::Result<runtime::EValue> get(
223225
const std::string& method_name,
224-
const std::vector<runtime::EValue>& input) {
225-
auto result = ET_UNWRAP(execute(method_name, input));
226+
const std::vector<runtime::EValue>& input_values) {
227+
auto result = ET_UNWRAP(execute(method_name, input_values));
226228
if (result.empty()) {
227229
return runtime::Error::InvalidArgument;
228230
}
@@ -234,15 +236,15 @@ class Module {
234236
* Loads the program and method before execution if needed.
235237
*
236238
* @param[in] method_name The name of the method to execute.
237-
* @param[in] input A value to be passed to the method.
239+
* @param[in] input_value A value to be passed to the method.
238240
*
239241
* @returns A Result object containing either the first output value from the
240242
* method or an error to indicate failure.
241243
*/
242244
ET_NODISCARD inline runtime::Result<runtime::EValue> get(
243245
const std::string& method_name,
244-
const runtime::EValue& input) {
245-
return get(method_name, std::vector<runtime::EValue>{input});
246+
const runtime::EValue& input_value) {
247+
return get(method_name, std::vector<runtime::EValue>{input_value});
246248
}
247249

248250
/**
@@ -260,31 +262,31 @@ class Module {
260262
}
261263

262264
/**
263-
* Execute the 'forward' method with the given input and retrieve output.
264-
* Loads the program and method before executing if needed.
265+
* Execute the 'forward' method with the given input values and retrieve the
266+
* output values. Loads the program and method before executing if needed.
265267
*
266-
* @param[in] input A vector of input values for the 'forward' method.
268+
* @param[in] input_values A vector of input values for the 'forward' method.
267269
*
268270
* @returns A Result object containing either a vector of output values
269271
* from the 'forward' method or an error to indicate failure.
270272
*/
271273
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
272-
const std::vector<runtime::EValue>& input) {
273-
return execute("forward", input);
274+
const std::vector<runtime::EValue>& input_values) {
275+
return execute("forward", input_values);
274276
}
275277

276278
/**
277279
* Execute the 'forward' method with a single value.
278280
* Loads the program and method before executing if needed.
279281
*
280-
* @param[in] input A value for the 'forward' method.
282+
* @param[in] input_value A value for the 'forward' method.
281283
*
282284
* @returns A Result object containing either a vector of output values
283285
* from the 'forward' method or an error to indicate failure.
284286
*/
285287
ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
286-
const runtime::EValue& input) {
287-
return forward(std::vector<runtime::EValue>{input});
288+
const runtime::EValue& input_value) {
289+
return forward(std::vector<runtime::EValue>{input_value});
288290
}
289291

290292
/**
@@ -298,6 +300,42 @@ class Module {
298300
return forward(std::vector<runtime::EValue>{});
299301
}
300302

303+
/**
304+
* Sets the output tensor for a specific method.
305+
*
306+
* @param[in] method_name The name of the method.
307+
* @param[in] output_value The EValue containing the Tensor to set as the
308+
* method output.
309+
* @param[in] output_index Zero-based index of the output to set.
310+
*
311+
* @returns An Error to indicate success or failure.
312+
*
313+
* @note Only Tensor outputs are currently supported for setting.
314+
*/
315+
ET_NODISCARD
316+
runtime::Error set_output(
317+
const std::string& method_name,
318+
runtime::EValue output_value,
319+
size_t output_index = 0);
320+
321+
/**
322+
* Sets the output tensor for the "forward" method.
323+
*
324+
* @param[in] output_value The EValue containing the Tensor to set as the
325+
* method output.
326+
* @param[in] output_index Zero-based index of the output to set.
327+
*
328+
* @returns An Error to indicate success or failure.
329+
*
330+
* @note Only Tensor outputs are currently supported for setting.
331+
*/
332+
ET_NODISCARD
333+
inline runtime::Error set_output(
334+
runtime::EValue output_value,
335+
size_t output_index = 0) {
336+
return set_output("forward", output_value, output_index);
337+
}
338+
301339
/**
302340
* Retrieves the EventTracer instance being used by the Module.
303341
* EventTracer is used for tracking and logging events during the execution
@@ -310,19 +348,6 @@ class Module {
310348
return event_tracer_.get();
311349
}
312350

313-
/**
314-
* Set output data pointer for forward method.
315-
*
316-
* @param[in] output_value A Tensor for the output of 'forward' method.
317-
* @param[in] output_index Index of the output in 'forward' method.
318-
*
319-
* @returns An Error to indicate success or failure of the loading process.
320-
*/
321-
runtime::Error set_output_data_ptr(
322-
runtime::EValue output_value,
323-
size_t output_index,
324-
const std::string& method_name = "forward");
325-
326351
private:
327352
struct MethodHolder {
328353
std::vector<std::vector<uint8_t>> planned_buffers;

extension/module/test/module_test.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) {
122122

123123
TEST_F(ModuleTest, TestExecute) {
124124
Module module(model_path_);
125-
auto tensor = make_tensor_ptr({1}, {1});
125+
auto tensor = make_tensor_ptr({1});
126126

127127
const auto result = module.execute("forward", {tensor, tensor});
128128
EXPECT_TRUE(result.ok());
@@ -141,7 +141,7 @@ TEST_F(ModuleTest, TestExecutePreload) {
141141
const auto error = module.load();
142142
EXPECT_EQ(error, Error::Ok);
143143

144-
auto tensor = make_tensor_ptr({1}, {1});
144+
auto tensor = make_tensor_ptr({1});
145145

146146
const auto result = module.execute("forward", {tensor, tensor});
147147
EXPECT_TRUE(result.ok());
@@ -157,7 +157,7 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
157157
const auto error = module.load_method("forward");
158158
EXPECT_EQ(error, Error::Ok);
159159

160-
auto tensor = make_tensor_ptr({1}, {1});
160+
auto tensor = make_tensor_ptr({1});
161161

162162
const auto result = module.execute("forward", {tensor, tensor});
163163
EXPECT_TRUE(result.ok());
@@ -176,7 +176,7 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
176176
const auto load_method_error = module.load_method("forward");
177177
EXPECT_EQ(load_method_error, Error::Ok);
178178

179-
auto tensor = make_tensor_ptr({1}, {1});
179+
auto tensor = make_tensor_ptr({1});
180180

181181
const auto result = module.execute("forward", {tensor, tensor});
182182
EXPECT_TRUE(result.ok());
@@ -205,7 +205,7 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
205205
TEST_F(ModuleTest, TestGet) {
206206
Module module(model_path_);
207207

208-
auto tensor = make_tensor_ptr({1}, {1});
208+
auto tensor = make_tensor_ptr({1});
209209

210210
const auto result = module.get("forward", {tensor, tensor});
211211
EXPECT_TRUE(result.ok());
@@ -280,7 +280,7 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
280280
EXPECT_EQ(load_error, Error::Ok);
281281
EXPECT_TRUE(module1->is_loaded());
282282

283-
auto tensor = make_tensor_ptr({1}, {1});
283+
auto tensor = make_tensor_ptr({1});
284284

285285
const auto result1 = module1->execute("forward", {tensor, tensor});
286286
EXPECT_TRUE(result1.ok());
@@ -325,7 +325,7 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
325325

326326
EXPECT_EQ(module.program(), shared_program);
327327

328-
auto tensor = make_tensor_ptr({1}, {1});
328+
auto tensor = make_tensor_ptr({1});
329329

330330
const auto result = module.execute("forward", {tensor, tensor});
331331
EXPECT_TRUE(result.ok());

0 commit comments

Comments
 (0)