Skip to content

Commit 9720715

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
throw instead of segfault with invalid args in pybindings (#5726)
Summary: Pull Request resolved: #5726 Fixed some Error() call sites and added better checking on the method name Reviewed By: larryliu0820, dbort Differential Revision: D63547425 fbshipit-source-id: 69e32f90cbc2607b75df5186b0289fe385ce95e3
1 parent e2f1aca commit 9720715

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ void setup_output_storage(
140140
const std::vector<Span<uint8_t>>& output_storages) {
141141
if (output_storages.size() != method.outputs_size()) {
142142
THROW_IF_ERROR(
143-
Error(),
143+
Error::InvalidArgument,
144144
"number of output storages %zu does not match number of outputs %zu",
145145
output_storages.size(),
146146
method.outputs_size());
@@ -249,10 +249,10 @@ class Module final {
249249
const std::vector<EValue>& args,
250250
const std::optional<std::vector<Span<uint8_t>>>& output_storages =
251251
std::nullopt) {
252-
auto& method = methods_[method_name];
252+
auto& method = get_method(method_name);
253253
exec_aten::ArrayRef<EValue> input_evalue_list(args.data(), args.size());
254254

255-
Error set_inputs_status = method->set_inputs(input_evalue_list);
255+
Error set_inputs_status = method.set_inputs(input_evalue_list);
256256
THROW_IF_ERROR(
257257
set_inputs_status,
258258
"method->set_inputs() for method '%s' failed with error 0x%" PRIx32,
@@ -273,9 +273,9 @@ class Module final {
273273
c10::autograd_dispatch_keyset);
274274
#endif
275275
if (output_storages) {
276-
setup_output_storage(*method, *output_storages);
276+
setup_output_storage(method, *output_storages);
277277
}
278-
Error execute_status = method->execute();
278+
Error execute_status = method.execute();
279279
THROW_IF_ERROR(
280280
execute_status,
281281
"method->execute() failed with error 0x%" PRIx32,
@@ -302,7 +302,9 @@ class Module final {
302302
Method& get_method(const std::string& method_name) {
303303
if (methods_.count(method_name) == 0) {
304304
THROW_IF_ERROR(
305-
Error(), "no such method in program: %s", method_name.c_str());
305+
Error::InvalidArgument,
306+
"no such method in program: %s",
307+
method_name.c_str());
306308
}
307309
return *methods_[method_name].get();
308310
}

extension/pybindings/test/make_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,17 @@ def test_method_meta(tester) -> None:
341341
tester.assertEqual(output_tensor.nbytes(), 16)
342342
tester.assertEqual(str(output_tensor), tensor_info)
343343

344+
def test_bad_name(tester) -> None:
345+
# Create an ExecuTorch program from ModuleAdd.
346+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
347+
exported_program, inputs = create_program(ModuleAdd())
348+
349+
# Use pybindings to load and execute the program.
350+
executorch_module = load_fn(exported_program.buffer)
351+
# Invoke the callable on executorch_module instead of calling module.forward.
352+
with tester.assertRaises(RuntimeError):
353+
executorch_module.run_method("not_a_real_method", inputs)
354+
344355
######### RUN TEST CASES #########
345356
test_e2e(tester)
346357
test_multiple_entry(tester)
@@ -351,5 +362,6 @@ def test_method_meta(tester) -> None:
351362
test_quantized_ops(tester)
352363
test_constant_output_not_memory_planned(tester)
353364
test_method_meta(tester)
365+
test_bad_name(tester)
354366

355367
return wrapper

0 commit comments

Comments
 (0)