Skip to content

Commit ab61a06

Browse files
committed
Add mapping from C++ program::verification to Python (#5915)
Summary: As titled. This enables `portable_lib._load_for_executorch[_from_buffer]` to accept `Program::Verification` argument. See added test, now we can do something like: ``` from executorch.extension.pybindings.portable_lib import Verification module = load_fn( exported_program.buffer, enable_etdump=False, debug_buffer_size=0, program_verification=Verification.Minimal, ) ``` Pull Request resolved: #5915 Test Plan: See unit test Reviewed By: dbort Differential Revision: D63987538 Pulled By: larryliu0820 fbshipit-source-id: b68d8d1149e2d46b90544679707f420179e72b19
1 parent eca44f0 commit ab61a06

File tree

5 files changed

+110
-36
lines changed

5 files changed

+110
-36
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
Verification, # noqa: F401
4849
)
4950

5051
# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`

extension/pybindings/pybindings.cpp

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,15 @@ class Module final {
157157
explicit Module(
158158
std::unique_ptr<DataLoader> loader,
159159
std::unique_ptr<ETDumpGen> tracer = nullptr,
160-
size_t debug_buffer_size = 0)
160+
size_t debug_buffer_size = 0,
161+
Program::Verification program_verification =
162+
Program::Verification::InternalConsistency)
161163
: loader_(std::move(loader)),
162164
event_tracer_(std::move(tracer)),
163165
debug_buffer_size_(debug_buffer_size) {
164166
::executorch::runtime::runtime_init();
165-
Result<Program> program = Program::load(
166-
loader_.get(), Program::Verification::InternalConsistency);
167+
Result<Program> program =
168+
Program::load(loader_.get(), program_verification);
167169
THROW_IF_ERROR(
168170
program.error(),
169171
"loading program failed with error: 0x%" PRIx32,
@@ -375,19 +377,22 @@ inline std::unique_ptr<Module> load_module_from_buffer(
375377
const void* ptr,
376378
size_t ptr_len,
377379
bool enable_etdump,
378-
size_t debug_buffer_size) {
380+
size_t debug_buffer_size,
381+
Program::Verification program_verification) {
379382
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
380383
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
381384
return std::make_unique<Module>(
382385
std::move(loader),
383386
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
384-
debug_buffer_size);
387+
debug_buffer_size,
388+
program_verification);
385389
}
386390

387391
inline std::unique_ptr<Module> load_module_from_file(
388392
const std::string& path,
389393
bool enable_etdump,
390-
size_t debug_buffer_size) {
394+
size_t debug_buffer_size,
395+
Program::Verification program_verification) {
391396
EXECUTORCH_SCOPE_PROF("load_module_from_file");
392397

393398
Result<MmapDataLoader> res = MmapDataLoader::from(
@@ -402,7 +407,8 @@ inline std::unique_ptr<Module> load_module_from_file(
402407
return std::make_unique<Module>(
403408
std::move(loader),
404409
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
405-
debug_buffer_size);
410+
debug_buffer_size,
411+
program_verification);
406412
}
407413

408414
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
@@ -452,30 +458,41 @@ struct PyModule final {
452458
explicit PyModule(
453459
const py::bytes& buffer,
454460
bool enable_etdump,
455-
size_t debug_buffer_size = 0)
461+
size_t debug_buffer_size = 0,
462+
Program::Verification program_verification =
463+
Program::Verification::InternalConsistency)
456464
: module_(load_module_from_buffer(
457465
buffer.cast<std::string_view>().data(),
458466
py::len(buffer),
459467
enable_etdump,
460-
debug_buffer_size)) {}
468+
debug_buffer_size,
469+
program_verification)) {}
461470

462471
explicit PyModule(
463472
const void* ptr,
464473
size_t ptr_len,
465474
bool enable_etdump,
466-
size_t debug_buffer_size = 0)
475+
size_t debug_buffer_size = 0,
476+
Program::Verification program_verification =
477+
Program::Verification::InternalConsistency)
467478
: module_(load_module_from_buffer(
468479
ptr,
469480
ptr_len,
470481
enable_etdump,
471-
debug_buffer_size)) {}
482+
debug_buffer_size,
483+
program_verification)) {}
472484

473485
explicit PyModule(
474486
const std::string& path,
475487
bool enable_etdump,
476-
size_t debug_buffer_size = 0)
477-
: module_(load_module_from_file(path, enable_etdump, debug_buffer_size)) {
478-
}
488+
size_t debug_buffer_size = 0,
489+
Program::Verification program_verification =
490+
Program::Verification::InternalConsistency)
491+
: module_(load_module_from_file(
492+
path,
493+
enable_etdump,
494+
debug_buffer_size,
495+
program_verification)) {}
479496

480497
PyModule(const PyModule&) = delete;
481498
PyModule& operator=(const PyModule&) = delete;
@@ -486,14 +503,20 @@ struct PyModule final {
486503
static std::unique_ptr<PyModule> load_from_buffer(
487504
const py::bytes& buffer,
488505
bool enable_etdump,
489-
size_t debug_buffer_size = 0) {
490-
return std::make_unique<PyModule>(buffer, enable_etdump, debug_buffer_size);
506+
size_t debug_buffer_size = 0,
507+
Program::Verification program_verification =
508+
Program::Verification::InternalConsistency) {
509+
return std::make_unique<PyModule>(
510+
buffer, enable_etdump, debug_buffer_size, program_verification);
491511
}
492512
static std::unique_ptr<PyModule> load_from_file(
493513
const std::string& path,
494514
bool enable_etdump,
495-
size_t debug_buffer_size = 0) {
496-
return std::make_unique<PyModule>(path, enable_etdump, debug_buffer_size);
515+
size_t debug_buffer_size = 0,
516+
Program::Verification program_verification =
517+
Program::Verification::InternalConsistency) {
518+
return std::make_unique<PyModule>(
519+
path, enable_etdump, debug_buffer_size, program_verification);
497520
}
498521

499522
static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -805,19 +828,29 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
805828
// Redirects cout and cerr for function calls this guards to the python env.
806829
auto call_guard = py::
807830
call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();
831+
832+
// Bind the verification enum to python.
833+
py::enum_<Program::Verification>(m, "Verification")
834+
.value("Minimal", Program::Verification::Minimal)
835+
.value("InternalConsistency", Program::Verification::InternalConsistency);
836+
808837
m.def(
809838
"_load_for_executorch",
810839
PyModule::load_from_file,
811840
py::arg("path"),
812841
py::arg("enable_etdump") = false,
813842
py::arg("debug_buffer_size") = 0,
843+
py::arg("program_verification") =
844+
Program::Verification::InternalConsistency,
814845
call_guard);
815846
m.def(
816847
"_load_for_executorch_from_buffer",
817848
&PyModule::load_from_buffer,
818849
py::arg("buffer"),
819850
py::arg("enable_etdump") = false,
820851
py::arg("debug_buffer_size") = 0,
852+
py::arg("program_verification") =
853+
Program::Verification::InternalConsistency,
821854
call_guard);
822855
m.def(
823856
"_load_for_executorch_from_bundled_program",

extension/pybindings/pybindings.pyi

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,22 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
from typing import Any, Dict, List, Optional, Sequence, Tuple
8+
from typing import Any, Dict, Enum, List, Optional, Sequence, Tuple
99

1010
from executorch.exir._warnings import experimental
1111

12+
@experimental("This API is experimental and subject to change without notice.")
13+
class Verification(Enum):
14+
"""Verification maps C++ Program::Verification to Python.
15+
16+
.. warning::
17+
18+
This API is experimental and subject to change without notice.
19+
"""
20+
21+
Minimal: ...
22+
InternalConsistency: ...
23+
1224
@experimental("This API is experimental and subject to change without notice.")
1325
class ExecuTorchModule:
1426
"""ExecuTorchModule is a Python wrapper around a C++ ExecuTorch program.
@@ -56,7 +68,10 @@ class BundledModule:
5668

5769
@experimental("This API is experimental and subject to change without notice.")
5870
def _load_for_executorch(
59-
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
71+
path: str,
72+
enable_etdump: bool = False,
73+
debug_buffer_size: int = 0,
74+
program_verification: Verification = Verification.InternalConsistency,
6075
) -> ExecuTorchModule:
6176
"""Load an ExecuTorch Program from a file.
6277
@@ -79,7 +94,10 @@ def _load_for_executorch(
7994

8095
@experimental("This API is experimental and subject to change without notice.")
8196
def _load_for_executorch_from_buffer(
82-
buffer: bytes, enable_etdump: bool = False, debug_buffer_size: int = 0
97+
buffer: bytes,
98+
enable_etdump: bool = False,
99+
debug_buffer_size: int = 0,
100+
program_verification: Verification = Verification.InternalConsistency,
83101
) -> ExecuTorchModule:
84102
"""Same as _load_for_executorch, but takes a byte buffer instead of a file path.
85103

extension/pybindings/test/make_test.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
# pyre-unsafe
88

99
import unittest
10-
from typing import Any, Callable, Tuple
10+
from types import ModuleType
11+
from typing import Any, Callable, Optional, Tuple
1112

1213
import torch
1314
from executorch.exir import ExecutorchProgramManager, to_edge
@@ -16,7 +17,7 @@
1617

1718
def make_test( # noqa: C901
1819
tester: unittest.TestCase,
19-
load_fn: Callable,
20+
runtime: ModuleType,
2021
) -> Callable[[unittest.TestCase], None]:
2122
"""
2223
Returns a function that operates as a test case within a unittest.TestCase class.
@@ -25,6 +26,7 @@ def make_test( # noqa: C901
2526
which will all have different load functions. In this case each individual test case is a
2627
subfunction of wrapper.
2728
"""
29+
load_fn: Callable = runtime._load_for_executorch_from_buffer
2830

2931
def wrapper(tester: unittest.TestCase) -> None:
3032
class ModuleAdd(torch.nn.Module):
@@ -251,12 +253,37 @@ def test_quantized_ops(tester):
251253
expected = example_inputs[0] + example_inputs[1]
252254
tester.assertEqual(str(expected), str(executorch_output))
253255

256+
def test_verification_config(tester) -> None:
257+
# Create an ExecuTorch program from ModuleAdd.
258+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
259+
exported_program, inputs = create_program(ModuleAdd())
260+
Verification = runtime.Verification
261+
262+
# Use pybindings to load and execute the program.
263+
for config in [Verification.Minimal, Verification.InternalConsistency]:
264+
executorch_module = load_fn(
265+
exported_program.buffer,
266+
enable_etdump=False,
267+
debug_buffer_size=0,
268+
program_verification=config,
269+
)
270+
271+
executorch_output = executorch_module.forward(inputs)[0]
272+
273+
# The test module adds the two inputs, so its output should be the same
274+
# as adding them directly.
275+
expected = inputs[0] + inputs[1]
276+
277+
tester.assertEqual(str(expected), str(executorch_output))
278+
279+
######### RUN TEST CASES #########
254280
test_e2e(tester)
255281
test_multiple_entry(tester)
256282
test_output_lifespan(tester)
257283
test_module_callable(tester)
258284
test_module_single_input(tester)
259285
test_stderr_redirect(tester)
260286
test_quantized_ops(tester)
287+
test_verification_config(tester)
261288

262289
return wrapper

extension/pybindings/test/test_pybindings.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,19 @@
1010

1111
kernel_mode = None # either aten mode or portable mode
1212
try:
13-
from executorch.extension.pybindings.portable_lib import (
14-
_load_for_executorch_from_buffer,
15-
)
13+
from executorch.extension.pybindings import portable_lib as runtime
1614

1715
kernel_mode = "portable"
1816
except Exception:
1917
print("can't load portable lib")
2018

21-
try:
22-
from executorch.extension.pybindings.aten_lib import ( # noqa: F811
23-
_load_for_executorch_from_buffer,
24-
)
25-
26-
assert kernel_mode is None
19+
if kernel_mode is None:
20+
try:
21+
from executorch.extension.pybindings import aten_lib as runtime # noqa: F811
2722

28-
kernel_mode = "aten"
29-
except Exception:
30-
print("can't load aten lib")
23+
kernel_mode = "aten"
24+
except Exception:
25+
print("can't load aten lib")
3126

3227
assert kernel_mode is not None
3328

@@ -37,4 +32,4 @@
3732

3833
class PybindingsTest(unittest.TestCase):
3934
def test(self):
40-
make_test(self, _load_for_executorch_from_buffer)(self)
35+
make_test(self, runtime)(self)

0 commit comments

Comments
 (0)