Skip to content

Commit 8fc3f8c

Browse files
authored
Support to init BundledProgram from pte file
Differential Revision: D67013542 Pull Request resolved: #7278
1 parent 62d2e37 commit 8fc3f8c

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

devtools/bundled_program/core.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict, List, Optional, Sequence, Type, Union
1010

1111
import executorch.devtools.bundled_program.schema as bp_schema
12+
from pyre_extensions import none_throws
1213

1314
import executorch.exir.schema as core_schema
1415

@@ -19,7 +20,7 @@
1920
from executorch.devtools.bundled_program.version import BUNDLED_PROGRAM_SCHEMA_VERSION
2021

2122
from executorch.exir import ExecutorchProgram, ExecutorchProgramManager
22-
from executorch.exir._serialize import _serialize_pte_binary
23+
from executorch.exir._serialize import _deserialize_pte_binary, _serialize_pte_binary
2324
from executorch.exir.tensor import get_scalar_type, scalar_type_enum, TensorSpec
2425

2526
# pyre-ignore
@@ -43,23 +44,35 @@ class BundledProgram:
4344

4445
def __init__(
4546
self,
46-
executorch_program: Union[
47+
executorch_program: Optional[Union[
4748
ExecutorchProgram,
4849
ExecutorchProgramManager,
49-
],
50+
]],
5051
method_test_suites: Sequence[MethodTestSuite],
52+
pte_file_path: Optional[str] = None,
5153
):
5254
"""Create BundledProgram by bundling the given program and method_test_suites together.
5355
5456
Args:
5557
executorch_program: The program to be bundled.
5658
method_test_suites: The testcases for certain methods to be bundled.
59+
pte_file_path: The path to pte file to deserialize program if executorch_program is not provided.
5760
"""
61+
if not executorch_program and not pte_file_path:
62+
raise RuntimeError("Either executorch_program or pte_file_path must be provided")
63+
64+
if executorch_program and pte_file_path:
65+
raise RuntimeError("Only one of executorch_program or pte_file_path can be used")
5866

5967
method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
60-
self._assert_valid_bundle(executorch_program, method_test_suites)
68+
if executorch_program:
69+
self._assert_valid_bundle(executorch_program, method_test_suites)
70+
self.executorch_program: Optional[Union[
71+
ExecutorchProgram,
72+
ExecutorchProgramManager,
73+
]] = executorch_program
74+
self._pte_file_path: Optional[str] = pte_file_path
6175

62-
self.executorch_program = executorch_program
6376
self.method_test_suites = method_test_suites
6477

6578
# This is the cache for bundled program in schema type.
@@ -72,7 +85,13 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
7285
if self._bundled_program_in_schema is not None:
7386
return self._bundled_program_in_schema
7487

75-
program = self._extract_program(self.executorch_program)
88+
if self.executorch_program:
89+
program = self._extract_program(self.executorch_program)
90+
else:
91+
with open(none_throws(self._pte_file_path), "rb") as f:
92+
p_bytes = f.read()
93+
program = _deserialize_pte_binary(p_bytes)
94+
7695
bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = []
7796

7897
# Emit data and metadata of bundled tensor

devtools/bundled_program/test/test_bundle_data.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import unittest
1010
from typing import List
11-
11+
import tempfile
1212
import executorch.devtools.bundled_program.schema as bp_schema
1313

1414
import torch
@@ -73,6 +73,43 @@ def test_bundled_program(self) -> None:
7373
bundled_program.serialize_to_schema().program,
7474
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
7575
)
76+
77+
def test_bundled_program_from_pte(self) -> None:
78+
executorch_program, method_test_suites = get_common_executorch_program()
79+
80+
with tempfile.TemporaryDirectory() as tmp_dir:
81+
executorch_model_path = f"{tmp_dir}/executorch_model.pte"
82+
with open(executorch_model_path, "wb") as f:
83+
f.write(executorch_program.buffer)
84+
85+
bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path)
86+
87+
method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)
88+
89+
for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
90+
bundled_plan_test = (
91+
bundled_program.serialize_to_schema().method_test_suites[plan_id]
92+
)
93+
method_test_suite = method_test_suites[plan_id]
94+
95+
self.assertEqual(
96+
len(bundled_plan_test.test_cases), len(method_test_suite.test_cases)
97+
)
98+
for bundled_program_ioset, method_test_case in zip(
99+
bundled_plan_test.test_cases, method_test_suite.test_cases
100+
):
101+
self.assertIOsetDataEqual(
102+
bundled_program_ioset.inputs, method_test_case.inputs
103+
)
104+
self.assertIOsetDataEqual(
105+
bundled_program_ioset.expected_outputs,
106+
method_test_case.expected_outputs,
107+
)
108+
109+
self.assertEqual(
110+
bundled_program.serialize_to_schema().program,
111+
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
112+
)
76113

77114
def test_bundled_miss_methods(self) -> None:
78115
executorch_program, method_test_suites = get_common_executorch_program()

0 commit comments

Comments
 (0)