Skip to content

Commit d0d1d44

Browse files
YIWENX14facebook-github-bot
authored andcommitted
Support to init BundledProgram from pte file
Summary: Added an optional `pte_file_path` arg to BundledProgram's init function. This is to allow users to create bundled program with varied inputs after exporting. Differential Revision: D67013542
1 parent 343aa0c commit d0d1d44

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

devtools/bundled_program/core.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict, List, Optional, Sequence, Type, Union
1010

1111
import executorch.devtools.bundled_program.schema as bp_schema
12+
from pyre_extensions import none_throws
1213

1314
import executorch.exir.schema as core_schema
1415

@@ -19,7 +20,8 @@
1920
from executorch.devtools.bundled_program.version import BUNDLED_PROGRAM_SCHEMA_VERSION
2021

2122
from executorch.exir import ExecutorchProgram, ExecutorchProgramManager
22-
from executorch.exir._serialize import _serialize_pte_binary
23+
from executorch.exir._serialize import _deserialize_pte_binary, _serialize_pte_binary
24+
from executorch.exir.schema import Program
2325
from executorch.exir.tensor import get_scalar_type, scalar_type_enum, TensorSpec
2426

2527
# pyre-ignore
@@ -43,23 +45,31 @@ class BundledProgram:
4345

4446
def __init__(
4547
self,
46-
executorch_program: Union[
48+
executorch_program: Optional[Union[
4749
ExecutorchProgram,
4850
ExecutorchProgramManager,
49-
],
51+
]],
5052
method_test_suites: Sequence[MethodTestSuite],
53+
pte_file_path: Optional[str] = None,
5154
):
5255
"""Create BundledProgram by bundling the given program and method_test_suites together.
5356
5457
Args:
5558
executorch_program: The program to be bundled.
5659
method_test_suites: The testcases for certain methods to be bundled.
5760
"""
61+
if not executorch_program and not pte_file_path:
62+
raise RuntimeError("Either executorch_program or pte_file_path must be provided")
5863

5964
method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
60-
self._assert_valid_bundle(executorch_program, method_test_suites)
65+
if executorch_program:
66+
self._assert_valid_bundle(executorch_program, method_test_suites)
67+
self.executorch_program: Optional[Union[
68+
ExecutorchProgram,
69+
ExecutorchProgramManager,
70+
]] = executorch_program
71+
self._pte_file_path: Optional[str] = pte_file_path
6172

62-
self.executorch_program = executorch_program
6373
self.method_test_suites = method_test_suites
6474

6575
# This is the cache for bundled program in schema type.
@@ -72,7 +82,13 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
7282
if self._bundled_program_in_schema is not None:
7383
return self._bundled_program_in_schema
7484

75-
program = self._extract_program(self.executorch_program)
85+
if self.executorch_program:
86+
program = self._extract_program(self.executorch_program)
87+
else:
88+
with open(none_throws(self._pte_file_path), "rb") as f:
89+
p_bytes = f.read()
90+
program = _deserialize_pte_binary(p_bytes)
91+
7692
bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = []
7793

7894
# Emit data and metadata of bundled tensor

devtools/bundled_program/test/test_bundle_data.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import unittest
1010
from typing import List
11-
11+
import tempfile
1212
import executorch.devtools.bundled_program.schema as bp_schema
1313

1414
import torch
@@ -73,6 +73,43 @@ def test_bundled_program(self) -> None:
7373
bundled_program.serialize_to_schema().program,
7474
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
7575
)
76+
77+
def test_bundled_program_from_pte(self) -> None:
78+
executorch_program, method_test_suites = get_common_executorch_program()
79+
80+
with tempfile.TemporaryDirectory() as tmp_dir:
81+
executorch_model_path = f"{tmp_dir}/executorch_model.pte"
82+
with open(executorch_model_path, "wb") as f:
83+
f.write(executorch_program.buffer)
84+
85+
bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path)
86+
87+
method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)
88+
89+
for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
90+
bundled_plan_test = (
91+
bundled_program.serialize_to_schema().method_test_suites[plan_id]
92+
)
93+
method_test_suite = method_test_suites[plan_id]
94+
95+
self.assertEqual(
96+
len(bundled_plan_test.test_cases), len(method_test_suite.test_cases)
97+
)
98+
for bundled_program_ioset, method_test_case in zip(
99+
bundled_plan_test.test_cases, method_test_suite.test_cases
100+
):
101+
self.assertIOsetDataEqual(
102+
bundled_program_ioset.inputs, method_test_case.inputs
103+
)
104+
self.assertIOsetDataEqual(
105+
bundled_program_ioset.expected_outputs,
106+
method_test_case.expected_outputs,
107+
)
108+
109+
self.assertEqual(
110+
bundled_program.serialize_to_schema().program,
111+
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
112+
)
76113

77114
def test_bundled_miss_methods(self) -> None:
78115
executorch_program, method_test_suites = get_common_executorch_program()

0 commit comments

Comments
 (0)