Skip to content

Commit 8c599a8

Browse files
committed
Add mapping from C++ program::verification to Python (#5915)
Summary: As titled. This enables `portable_lib._load_for_executorch[_from_buffer]` to accept `Program::Verification` argument. See added test, now we can do something like: ``` from executorch.extension.pybindings.portable_lib import Verification module = load_fn( exported_program.buffer, enable_etdump=False, debug_buffer_size=0, program_verification=Verification.Minimal, ) ``` Pull Request resolved: #5915 Test Plan: See unit test Reviewed By: dbort Differential Revision: D63987538 Pulled By: larryliu0820 fbshipit-source-id: b68d8d1149e2d46b90544679707f420179e72b19
1 parent 02c1b3d commit 8c599a8

File tree

5 files changed

+121
-36
lines changed

5 files changed

+121
-36
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
Verification, # noqa: F401
4849
)
4950

5051
# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`

extension/pybindings/pybindings.cpp

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,15 @@ class Module final {
168168
explicit Module(
169169
std::unique_ptr<DataLoader> loader,
170170
std::unique_ptr<ETDumpGen> tracer = nullptr,
171-
size_t debug_buffer_size = 0)
171+
size_t debug_buffer_size = 0,
172+
Program::Verification program_verification =
173+
Program::Verification::InternalConsistency)
172174
: loader_(std::move(loader)),
173175
event_tracer_(std::move(tracer)),
174176
debug_buffer_size_(debug_buffer_size) {
175177
::executorch::runtime::runtime_init();
176-
Result<Program> program = Program::load(
177-
loader_.get(), Program::Verification::InternalConsistency);
178+
Result<Program> program =
179+
Program::load(loader_.get(), program_verification);
178180
THROW_IF_ERROR(
179181
program.error(),
180182
"loading program failed with error: 0x%" PRIx32,
@@ -386,19 +388,22 @@ inline std::unique_ptr<Module> load_module_from_buffer(
386388
const void* ptr,
387389
size_t ptr_len,
388390
bool enable_etdump,
389-
size_t debug_buffer_size) {
391+
size_t debug_buffer_size,
392+
Program::Verification program_verification) {
390393
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
391394
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
392395
return std::make_unique<Module>(
393396
std::move(loader),
394397
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
395-
debug_buffer_size);
398+
debug_buffer_size,
399+
program_verification);
396400
}
397401

398402
inline std::unique_ptr<Module> load_module_from_file(
399403
const std::string& path,
400404
bool enable_etdump,
401-
size_t debug_buffer_size) {
405+
size_t debug_buffer_size,
406+
Program::Verification program_verification) {
402407
EXECUTORCH_SCOPE_PROF("load_module_from_file");
403408

404409
Result<MmapDataLoader> res = MmapDataLoader::from(
@@ -413,7 +418,8 @@ inline std::unique_ptr<Module> load_module_from_file(
413418
return std::make_unique<Module>(
414419
std::move(loader),
415420
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
416-
debug_buffer_size);
421+
debug_buffer_size,
422+
program_verification);
417423
}
418424

419425
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
@@ -576,30 +582,41 @@ struct PyModule final {
576582
explicit PyModule(
577583
const py::bytes& buffer,
578584
bool enable_etdump,
579-
size_t debug_buffer_size = 0)
585+
size_t debug_buffer_size = 0,
586+
Program::Verification program_verification =
587+
Program::Verification::InternalConsistency)
580588
: module_(load_module_from_buffer(
581589
buffer.cast<std::string_view>().data(),
582590
py::len(buffer),
583591
enable_etdump,
584-
debug_buffer_size)) {}
592+
debug_buffer_size,
593+
program_verification)) {}
585594

586595
explicit PyModule(
587596
const void* ptr,
588597
size_t ptr_len,
589598
bool enable_etdump,
590-
size_t debug_buffer_size = 0)
599+
size_t debug_buffer_size = 0,
600+
Program::Verification program_verification =
601+
Program::Verification::InternalConsistency)
591602
: module_(load_module_from_buffer(
592603
ptr,
593604
ptr_len,
594605
enable_etdump,
595-
debug_buffer_size)) {}
606+
debug_buffer_size,
607+
program_verification)) {}
596608

597609
explicit PyModule(
598610
const std::string& path,
599611
bool enable_etdump,
600-
size_t debug_buffer_size = 0)
601-
: module_(load_module_from_file(path, enable_etdump, debug_buffer_size)) {
602-
}
612+
size_t debug_buffer_size = 0,
613+
Program::Verification program_verification =
614+
Program::Verification::InternalConsistency)
615+
: module_(load_module_from_file(
616+
path,
617+
enable_etdump,
618+
debug_buffer_size,
619+
program_verification)) {}
603620

604621
PyModule(const PyModule&) = delete;
605622
PyModule& operator=(const PyModule&) = delete;
@@ -610,14 +627,20 @@ struct PyModule final {
610627
static std::unique_ptr<PyModule> load_from_buffer(
611628
const py::bytes& buffer,
612629
bool enable_etdump,
613-
size_t debug_buffer_size = 0) {
614-
return std::make_unique<PyModule>(buffer, enable_etdump, debug_buffer_size);
630+
size_t debug_buffer_size = 0,
631+
Program::Verification program_verification =
632+
Program::Verification::InternalConsistency) {
633+
return std::make_unique<PyModule>(
634+
buffer, enable_etdump, debug_buffer_size, program_verification);
615635
}
616636
static std::unique_ptr<PyModule> load_from_file(
617637
const std::string& path,
618638
bool enable_etdump,
619-
size_t debug_buffer_size = 0) {
620-
return std::make_unique<PyModule>(path, enable_etdump, debug_buffer_size);
639+
size_t debug_buffer_size = 0,
640+
Program::Verification program_verification =
641+
Program::Verification::InternalConsistency) {
642+
return std::make_unique<PyModule>(
643+
path, enable_etdump, debug_buffer_size, program_verification);
621644
}
622645

623646
static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -934,19 +957,29 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
934957
// Redirects cout and cerr for function calls this guards to the python env.
935958
auto call_guard = py::
936959
call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();
960+
961+
// Bind the verification enum to python.
962+
py::enum_<Program::Verification>(m, "Verification")
963+
.value("Minimal", Program::Verification::Minimal)
964+
.value("InternalConsistency", Program::Verification::InternalConsistency);
965+
937966
m.def(
938967
"_load_for_executorch",
939968
PyModule::load_from_file,
940969
py::arg("path"),
941970
py::arg("enable_etdump") = false,
942971
py::arg("debug_buffer_size") = 0,
972+
py::arg("program_verification") =
973+
Program::Verification::InternalConsistency,
943974
call_guard);
944975
m.def(
945976
"_load_for_executorch_from_buffer",
946977
&PyModule::load_from_buffer,
947978
py::arg("buffer"),
948979
py::arg("enable_etdump") = false,
949980
py::arg("debug_buffer_size") = 0,
981+
py::arg("program_verification") =
982+
Program::Verification::InternalConsistency,
950983
call_guard);
951984
m.def(
952985
"_load_for_executorch_from_bundled_program",

extension/pybindings/pybindings.pyi

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,22 @@
77
# pyre-strict
88
from __future__ import annotations
99

10-
from typing import Any, Dict, List, Optional, Sequence, Tuple
10+
from typing import Any, Dict, Enum, List, Optional, Sequence, Tuple
1111

1212
from executorch.exir._warnings import experimental
1313

14+
@experimental("This API is experimental and subject to change without notice.")
15+
class Verification(Enum):
16+
"""Verification maps C++ Program::Verification to Python.
17+
18+
.. warning::
19+
20+
This API is experimental and subject to change without notice.
21+
"""
22+
23+
Minimal: ...
24+
InternalConsistency: ...
25+
1426
@experimental("This API is experimental and subject to change without notice.")
1527
class ExecuTorchModule:
1628
"""ExecuTorchModule is a Python wrapper around a C++ ExecuTorch program.
@@ -125,7 +137,10 @@ class MethodMeta:
125137

126138
@experimental("This API is experimental and subject to change without notice.")
127139
def _load_for_executorch(
128-
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
140+
path: str,
141+
enable_etdump: bool = False,
142+
debug_buffer_size: int = 0,
143+
program_verification: Verification = Verification.InternalConsistency,
129144
) -> ExecuTorchModule:
130145
"""Load an ExecuTorch Program from a file.
131146
@@ -148,7 +163,10 @@ def _load_for_executorch(
148163

149164
@experimental("This API is experimental and subject to change without notice.")
150165
def _load_for_executorch_from_buffer(
151-
buffer: bytes, enable_etdump: bool = False, debug_buffer_size: int = 0
166+
buffer: bytes,
167+
enable_etdump: bool = False,
168+
debug_buffer_size: int = 0,
169+
program_verification: Verification = Verification.InternalConsistency,
152170
) -> ExecuTorchModule:
153171
"""Same as _load_for_executorch, but takes a byte buffer instead of a file path.
154172

extension/pybindings/test/make_test.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
# pyre-unsafe
88

99
import unittest
10-
from typing import Any, Callable, Tuple
10+
from types import ModuleType
11+
from typing import Any, Callable, Optional, Tuple
1112

1213
import torch
1314
from executorch.exir import ExecutorchProgramManager, to_edge
@@ -16,7 +17,7 @@
1617

1718
def make_test( # noqa: C901
1819
tester: unittest.TestCase,
19-
load_fn: Callable,
20+
runtime: ModuleType,
2021
) -> Callable[[unittest.TestCase], None]:
2122
"""
2223
Returns a function that operates as a test case within a unittest.TestCase class.
@@ -25,6 +26,7 @@ def make_test( # noqa: C901
2526
which will all have different load functions. In this case each individual test case is a
2627
subfunction of wrapper.
2728
"""
29+
load_fn: Callable = runtime._load_for_executorch_from_buffer
2830

2931
def wrapper(tester: unittest.TestCase) -> None:
3032
class ModuleAdd(torch.nn.Module):
@@ -323,6 +325,40 @@ def test_method_meta(tester) -> None:
323325
tester.assertEqual(output_tensor.nbytes(), 16)
324326
tester.assertEqual(str(output_tensor), tensor_info)
325327

328+
def test_bad_name(tester) -> None:
329+
# Create an ExecuTorch program from ModuleAdd.
330+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
331+
exported_program, inputs = create_program(ModuleAdd())
332+
333+
# Use pybindings to load and execute the program.
334+
executorch_module = load_fn(exported_program.buffer)
335+
# Invoke the callable on executorch_module instead of calling module.forward.
336+
with tester.assertRaises(RuntimeError):
337+
executorch_module.run_method("not_a_real_method", inputs)
338+
339+
def test_verification_config(tester) -> None:
340+
# Create an ExecuTorch program from ModuleAdd.
341+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
342+
exported_program, inputs = create_program(ModuleAdd())
343+
Verification = runtime.Verification
344+
345+
# Use pybindings to load and execute the program.
346+
for config in [Verification.Minimal, Verification.InternalConsistency]:
347+
executorch_module = load_fn(
348+
exported_program.buffer,
349+
enable_etdump=False,
350+
debug_buffer_size=0,
351+
program_verification=config,
352+
)
353+
354+
executorch_output = executorch_module.forward(inputs)[0]
355+
356+
# The test module adds the two inputs, so its output should be the same
357+
# as adding them directly.
358+
expected = inputs[0] + inputs[1]
359+
360+
tester.assertEqual(str(expected), str(executorch_output))
361+
326362
######### RUN TEST CASES #########
327363
test_e2e(tester)
328364
test_multiple_entry(tester)
@@ -333,5 +369,7 @@ def test_method_meta(tester) -> None:
333369
test_quantized_ops(tester)
334370
test_constant_output_not_memory_planned(tester)
335371
test_method_meta(tester)
372+
test_bad_name(tester)
373+
test_verification_config(tester)
336374

337375
return wrapper

extension/pybindings/test/test_pybindings.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,19 @@
1010

1111
kernel_mode = None # either aten mode or portable mode
1212
try:
13-
from executorch.extension.pybindings.portable_lib import (
14-
_load_for_executorch_from_buffer,
15-
)
13+
from executorch.extension.pybindings import portable_lib as runtime
1614

1715
kernel_mode = "portable"
1816
except Exception:
1917
print("can't load portable lib")
2018

21-
try:
22-
from executorch.extension.pybindings.aten_lib import ( # noqa: F811
23-
_load_for_executorch_from_buffer,
24-
)
25-
26-
assert kernel_mode is None
19+
if kernel_mode is None:
20+
try:
21+
from executorch.extension.pybindings import aten_lib as runtime # noqa: F811
2722

28-
kernel_mode = "aten"
29-
except Exception:
30-
print("can't load aten lib")
23+
kernel_mode = "aten"
24+
except Exception:
25+
print("can't load aten lib")
3126

3227
assert kernel_mode is not None
3328

@@ -37,4 +32,4 @@
3732

3833
class PybindingsTest(unittest.TestCase):
3934
def test(self):
40-
make_test(self, _load_for_executorch_from_buffer)(self)
35+
make_test(self, runtime)(self)

0 commit comments

Comments
 (0)