9
9
from typing import Dict , List , Optional , Sequence , Type , Union
10
10
11
11
import executorch .devtools .bundled_program .schema as bp_schema
12
+ from pyre_extensions import none_throws
12
13
13
14
import executorch .exir .schema as core_schema
14
15
19
20
from executorch .devtools .bundled_program .version import BUNDLED_PROGRAM_SCHEMA_VERSION
20
21
21
22
from executorch .exir import ExecutorchProgram , ExecutorchProgramManager
22
- from executorch .exir ._serialize import _serialize_pte_binary
23
+ from executorch .exir ._serialize import _deserialize_pte_binary , _serialize_pte_binary
24
+ from executorch .exir .schema import Program
23
25
from executorch .exir .tensor import get_scalar_type , scalar_type_enum , TensorSpec
24
26
25
27
# pyre-ignore
@@ -43,23 +45,31 @@ class BundledProgram:
43
45
44
46
def __init__ (
45
47
self ,
46
- executorch_program : Union [
48
+ executorch_program : Optional [ Union [
47
49
ExecutorchProgram ,
48
50
ExecutorchProgramManager ,
49
- ],
51
+ ]] ,
50
52
method_test_suites : Sequence [MethodTestSuite ],
53
+ pte_file_path : Optional [str ] = None ,
51
54
):
52
55
"""Create BundledProgram by bundling the given program and method_test_suites together.
53
56
54
57
Args:
55
58
executorch_program: The program to be bundled.
56
59
method_test_suites: The testcases for certain methods to be bundled.
57
60
"""
61
+ if not executorch_program and not pte_file_path :
62
+ raise RuntimeError ("Either executorch_program or pte_file_path must be provided" )
58
63
59
64
method_test_suites = sorted (method_test_suites , key = lambda x : x .method_name )
60
- self ._assert_valid_bundle (executorch_program , method_test_suites )
65
+ if executorch_program :
66
+ self ._assert_valid_bundle (executorch_program , method_test_suites )
67
+ self .executorch_program : Optional [Union [
68
+ ExecutorchProgram ,
69
+ ExecutorchProgramManager ,
70
+ ]] = executorch_program
71
+ self ._pte_file_path : Optional [str ] = pte_file_path
61
72
62
- self .executorch_program = executorch_program
63
73
self .method_test_suites = method_test_suites
64
74
65
75
# This is the cache for bundled program in schema type.
@@ -72,7 +82,13 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
72
82
if self ._bundled_program_in_schema is not None :
73
83
return self ._bundled_program_in_schema
74
84
75
- program = self ._extract_program (self .executorch_program )
85
+ if self .executorch_program :
86
+ program = self ._extract_program (self .executorch_program )
87
+ else :
88
+ with open (none_throws (self ._pte_file_path ), "rb" ) as f :
89
+ p_bytes = f .read ()
90
+ program = _deserialize_pte_binary (p_bytes )
91
+
76
92
bundled_method_test_suites : List [bp_schema .BundledMethodTestSuite ] = []
77
93
78
94
# Emit data and metadata of bundled tensor
0 commit comments