Skip to content

Commit da1d2ca

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add mapping from C++ program::verification to Python (#5915)
Summary: As titled. This enables `portable_lib._load_for_executorch[_from_buffer]` to accept `Program::Verification` argument. See added test, now we can do something like: ``` from executorch.extension.pybindings.portable_lib import Verification module = load_fn( exported_program.buffer, enable_etdump=False, debug_buffer_size=0, program_verification=Verification.Minimal, ) ``` Pull Request resolved: #5915 Test Plan: See unit test Reviewed By: dbort Differential Revision: D63987538 Pulled By: larryliu0820 fbshipit-source-id: b68d8d1149e2d46b90544679707f420179e72b19
1 parent a6b213b commit da1d2ca

File tree

5 files changed

+108
-35
lines changed

5 files changed

+108
-35
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
Verification, # noqa: F401
4849
)
4950

5051
# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`

extension/pybindings/pybindings.cpp

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,15 @@ class Module final {
168168
explicit Module(
169169
std::unique_ptr<DataLoader> loader,
170170
std::unique_ptr<ETDumpGen> tracer = nullptr,
171-
size_t debug_buffer_size = 0)
171+
size_t debug_buffer_size = 0,
172+
Program::Verification program_verification =
173+
Program::Verification::InternalConsistency)
172174
: loader_(std::move(loader)),
173175
event_tracer_(std::move(tracer)),
174176
debug_buffer_size_(debug_buffer_size) {
175177
::executorch::runtime::runtime_init();
176-
Result<Program> program = Program::load(
177-
loader_.get(), Program::Verification::InternalConsistency);
178+
Result<Program> program =
179+
Program::load(loader_.get(), program_verification);
178180
THROW_IF_ERROR(
179181
program.error(),
180182
"loading program failed with error: 0x%" PRIx32,
@@ -388,19 +390,22 @@ inline std::unique_ptr<Module> load_module_from_buffer(
388390
const void* ptr,
389391
size_t ptr_len,
390392
bool enable_etdump,
391-
size_t debug_buffer_size) {
393+
size_t debug_buffer_size,
394+
Program::Verification program_verification) {
392395
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
393396
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
394397
return std::make_unique<Module>(
395398
std::move(loader),
396399
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
397-
debug_buffer_size);
400+
debug_buffer_size,
401+
program_verification);
398402
}
399403

400404
inline std::unique_ptr<Module> load_module_from_file(
401405
const std::string& path,
402406
bool enable_etdump,
403-
size_t debug_buffer_size) {
407+
size_t debug_buffer_size,
408+
Program::Verification program_verification) {
404409
EXECUTORCH_SCOPE_PROF("load_module_from_file");
405410

406411
Result<MmapDataLoader> res = MmapDataLoader::from(
@@ -415,7 +420,8 @@ inline std::unique_ptr<Module> load_module_from_file(
415420
return std::make_unique<Module>(
416421
std::move(loader),
417422
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
418-
debug_buffer_size);
423+
debug_buffer_size,
424+
program_verification);
419425
}
420426

421427
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
@@ -578,30 +584,41 @@ struct PyModule final {
578584
explicit PyModule(
579585
const py::bytes& buffer,
580586
bool enable_etdump,
581-
size_t debug_buffer_size = 0)
587+
size_t debug_buffer_size = 0,
588+
Program::Verification program_verification =
589+
Program::Verification::InternalConsistency)
582590
: module_(load_module_from_buffer(
583591
buffer.cast<std::string_view>().data(),
584592
py::len(buffer),
585593
enable_etdump,
586-
debug_buffer_size)) {}
594+
debug_buffer_size,
595+
program_verification)) {}
587596

588597
explicit PyModule(
589598
const void* ptr,
590599
size_t ptr_len,
591600
bool enable_etdump,
592-
size_t debug_buffer_size = 0)
601+
size_t debug_buffer_size = 0,
602+
Program::Verification program_verification =
603+
Program::Verification::InternalConsistency)
593604
: module_(load_module_from_buffer(
594605
ptr,
595606
ptr_len,
596607
enable_etdump,
597-
debug_buffer_size)) {}
608+
debug_buffer_size,
609+
program_verification)) {}
598610

599611
explicit PyModule(
600612
const std::string& path,
601613
bool enable_etdump,
602-
size_t debug_buffer_size = 0)
603-
: module_(load_module_from_file(path, enable_etdump, debug_buffer_size)) {
604-
}
614+
size_t debug_buffer_size = 0,
615+
Program::Verification program_verification =
616+
Program::Verification::InternalConsistency)
617+
: module_(load_module_from_file(
618+
path,
619+
enable_etdump,
620+
debug_buffer_size,
621+
program_verification)) {}
605622

606623
PyModule(const PyModule&) = delete;
607624
PyModule& operator=(const PyModule&) = delete;
@@ -612,14 +629,20 @@ struct PyModule final {
612629
static std::unique_ptr<PyModule> load_from_buffer(
613630
const py::bytes& buffer,
614631
bool enable_etdump,
615-
size_t debug_buffer_size = 0) {
616-
return std::make_unique<PyModule>(buffer, enable_etdump, debug_buffer_size);
632+
size_t debug_buffer_size = 0,
633+
Program::Verification program_verification =
634+
Program::Verification::InternalConsistency) {
635+
return std::make_unique<PyModule>(
636+
buffer, enable_etdump, debug_buffer_size, program_verification);
617637
}
618638
static std::unique_ptr<PyModule> load_from_file(
619639
const std::string& path,
620640
bool enable_etdump,
621-
size_t debug_buffer_size = 0) {
622-
return std::make_unique<PyModule>(path, enable_etdump, debug_buffer_size);
641+
size_t debug_buffer_size = 0,
642+
Program::Verification program_verification =
643+
Program::Verification::InternalConsistency) {
644+
return std::make_unique<PyModule>(
645+
path, enable_etdump, debug_buffer_size, program_verification);
623646
}
624647

625648
static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -944,19 +967,29 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
944967
// Redirects cout and cerr for function calls this guards to the python env.
945968
auto call_guard = py::
946969
call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();
970+
971+
// Bind the verification enum to python.
972+
py::enum_<Program::Verification>(m, "Verification")
973+
.value("Minimal", Program::Verification::Minimal)
974+
.value("InternalConsistency", Program::Verification::InternalConsistency);
975+
947976
m.def(
948977
"_load_for_executorch",
949978
PyModule::load_from_file,
950979
py::arg("path"),
951980
py::arg("enable_etdump") = false,
952981
py::arg("debug_buffer_size") = 0,
982+
py::arg("program_verification") =
983+
Program::Verification::InternalConsistency,
953984
call_guard);
954985
m.def(
955986
"_load_for_executorch_from_buffer",
956987
&PyModule::load_from_buffer,
957988
py::arg("buffer"),
958989
py::arg("enable_etdump") = false,
959990
py::arg("debug_buffer_size") = 0,
991+
py::arg("program_verification") =
992+
Program::Verification::InternalConsistency,
960993
call_guard);
961994
m.def(
962995
"_load_for_executorch_from_bundled_program",

extension/pybindings/pybindings.pyi

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,22 @@
77
# pyre-strict
88
from __future__ import annotations
99

10-
from typing import Any, Dict, List, Optional, Sequence, Tuple
10+
from typing import Any, Dict, Enum, List, Optional, Sequence, Tuple
1111

1212
from executorch.exir._warnings import experimental
1313

14+
@experimental("This API is experimental and subject to change without notice.")
15+
class Verification(Enum):
16+
"""Verification maps C++ Program::Verification to Python.
17+
18+
.. warning::
19+
20+
This API is experimental and subject to change without notice.
21+
"""
22+
23+
Minimal: ...
24+
InternalConsistency: ...
25+
1426
@experimental("This API is experimental and subject to change without notice.")
1527
class ExecuTorchModule:
1628
"""ExecuTorchModule is a Python wrapper around a C++ ExecuTorch program.
@@ -125,7 +137,10 @@ class MethodMeta:
125137

126138
@experimental("This API is experimental and subject to change without notice.")
127139
def _load_for_executorch(
128-
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
140+
path: str,
141+
enable_etdump: bool = False,
142+
debug_buffer_size: int = 0,
143+
program_verification: Verification = Verification.InternalConsistency,
129144
) -> ExecuTorchModule:
130145
"""Load an ExecuTorch Program from a file.
131146
@@ -148,7 +163,10 @@ def _load_for_executorch(
148163

149164
@experimental("This API is experimental and subject to change without notice.")
150165
def _load_for_executorch_from_buffer(
151-
buffer: bytes, enable_etdump: bool = False, debug_buffer_size: int = 0
166+
buffer: bytes,
167+
enable_etdump: bool = False,
168+
debug_buffer_size: int = 0,
169+
program_verification: Verification = Verification.InternalConsistency,
152170
) -> ExecuTorchModule:
153171
"""Same as _load_for_executorch, but takes a byte buffer instead of a file path.
154172

extension/pybindings/test/make_test.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-unsafe
88

99
import unittest
10+
from types import ModuleType
1011
from typing import Any, Callable, Optional, Tuple
1112

1213
import torch
@@ -17,7 +18,7 @@
1718

1819
def make_test( # noqa: C901
1920
tester: unittest.TestCase,
20-
load_fn: Callable,
21+
runtime: ModuleType,
2122
) -> Callable[[unittest.TestCase], None]:
2223
"""
2324
Returns a function that operates as a test case within a unittest.TestCase class.
@@ -26,6 +27,7 @@ def make_test( # noqa: C901
2627
which will all have different load functions. In this case each individual test case is a
2728
subfunction of wrapper.
2829
"""
30+
load_fn: Callable = runtime._load_for_executorch_from_buffer
2931

3032
def wrapper(tester: unittest.TestCase) -> None:
3133
class ModuleAdd(torch.nn.Module):
@@ -352,6 +354,29 @@ def test_bad_name(tester) -> None:
352354
with tester.assertRaises(RuntimeError):
353355
executorch_module.run_method("not_a_real_method", inputs)
354356

357+
def test_verification_config(tester) -> None:
358+
# Create an ExecuTorch program from ModuleAdd.
359+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
360+
exported_program, inputs = create_program(ModuleAdd())
361+
Verification = runtime.Verification
362+
363+
# Use pybindings to load and execute the program.
364+
for config in [Verification.Minimal, Verification.InternalConsistency]:
365+
executorch_module = load_fn(
366+
exported_program.buffer,
367+
enable_etdump=False,
368+
debug_buffer_size=0,
369+
program_verification=config,
370+
)
371+
372+
executorch_output = executorch_module.forward(inputs)[0]
373+
374+
# The test module adds the two inputs, so its output should be the same
375+
# as adding them directly.
376+
expected = inputs[0] + inputs[1]
377+
378+
tester.assertEqual(str(expected), str(executorch_output))
379+
355380
######### RUN TEST CASES #########
356381
test_e2e(tester)
357382
test_multiple_entry(tester)
@@ -363,5 +388,6 @@ def test_bad_name(tester) -> None:
363388
test_constant_output_not_memory_planned(tester)
364389
test_method_meta(tester)
365390
test_bad_name(tester)
391+
test_verification_config(tester)
366392

367393
return wrapper

extension/pybindings/test/test_pybindings.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,19 @@
1010

1111
kernel_mode = None # either aten mode or portable mode
1212
try:
13-
from executorch.extension.pybindings.portable_lib import (
14-
_load_for_executorch_from_buffer,
15-
)
13+
from executorch.extension.pybindings import portable_lib as runtime
1614

1715
kernel_mode = "portable"
1816
except Exception:
1917
print("can't load portable lib")
2018

21-
try:
22-
from executorch.extension.pybindings.aten_lib import ( # noqa: F811
23-
_load_for_executorch_from_buffer,
24-
)
25-
26-
assert kernel_mode is None
19+
if kernel_mode is None:
20+
try:
21+
from executorch.extension.pybindings import aten_lib as runtime # noqa: F811
2722

28-
kernel_mode = "aten"
29-
except Exception:
30-
print("can't load aten lib")
23+
kernel_mode = "aten"
24+
except Exception:
25+
print("can't load aten lib")
3126

3227
assert kernel_mode is not None
3328

@@ -37,4 +32,4 @@
3732

3833
class PybindingsTest(unittest.TestCase):
3934
def test(self):
40-
make_test(self, _load_for_executorch_from_buffer)(self)
35+
make_test(self, runtime)(self)

0 commit comments

Comments
 (0)