8
8
import typing
9
9
from typing import Dict , List , Type
10
10
11
+ import executorch .bundled_program .schema as bp_schema
12
+ import executorch .exir .schema as core_schema
13
+
11
14
import torch
12
15
import torch .fx
13
16
from executorch .bundled_program .config import (
14
17
BundledConfig ,
15
18
ConfigExecutionPlanTest ,
16
19
ConfigValue ,
17
20
)
18
- from executorch .bundled_program .schema import (
19
- BundledBool ,
20
- BundledDouble ,
21
- BundledExecutionPlanTest ,
22
- BundledInt ,
23
- BundledIOSet ,
24
- BundledProgram ,
25
- BundledTensor ,
26
- BundledValue ,
27
- )
21
+
28
22
from executorch .bundled_program .version import BUNDLED_PROGRAM_SCHEMA_VERSION
29
23
from executorch .exir ._serialize import _serialize_pte_binary
30
- from executorch .exir .schema import (
31
- Bool ,
32
- Double ,
33
- ExecutionPlan ,
34
- Int ,
35
- KernelTypes ,
36
- Program ,
37
- Tensor ,
38
- )
24
+
39
25
from executorch .exir .tensor import get_scalar_type , scalar_type_enum , TensorSpec
40
26
41
27
# pyre-ignore
42
- supported_program_type_table : Dict [Type [KernelTypes ], ConfigValue ] = {
43
- Tensor : torch .Tensor ,
44
- Int : int ,
45
- Double : float ,
46
- Bool : bool ,
28
+ supported_program_type_table : Dict [Type [core_schema . KernelTypes ], ConfigValue ] = {
29
+ core_schema . Tensor : torch .Tensor ,
30
+ core_schema . Int : int ,
31
+ core_schema . Double : float ,
32
+ core_schema . Bool : bool ,
47
33
}
48
34
49
35
50
- def emit_bundled_tensor (spec : TensorSpec , bundled_values : List [BundledValue ]) -> None :
36
+ def emit_bundled_tensor (
37
+ spec : TensorSpec , bundled_values : List [bp_schema .Value ]
38
+ ) -> None :
51
39
# QuantizedSchema in tensor has deprecated and may not be used anymore.
52
40
# So here we don't emit it.
53
41
@@ -64,8 +52,8 @@ def emit_bundled_tensor(spec: TensorSpec, bundled_values: List[BundledValue]) ->
64
52
tensor_data : bytes = bytes (spec_array )
65
53
66
54
bundled_values .append (
67
- BundledValue (
68
- val = BundledTensor (
55
+ bp_schema . Value (
56
+ val = bp_schema . Tensor (
69
57
scalar_type = scalar_type_enum (spec .dtype ),
70
58
sizes = spec .shape ,
71
59
data = tensor_data ,
@@ -75,55 +63,67 @@ def emit_bundled_tensor(spec: TensorSpec, bundled_values: List[BundledValue]) ->
75
63
)
76
64
77
65
78
- def emit_prim (val : ConfigValue , bundled_values : List [BundledValue ]):
66
+ def emit_prim (val : ConfigValue , bundled_values : List [bp_schema . Value ]):
79
67
if type (val ) == int :
80
- bundled_values .append (BundledValue (val = BundledInt (int_val = val )))
68
+ bundled_values .append (bp_schema . Value (val = bp_schema . Int (int_val = val )))
81
69
elif type (val ) == bool :
82
- bundled_values .append (BundledValue (val = BundledBool (bool_val = val )))
70
+ bundled_values .append (bp_schema . Value (val = bp_schema . Bool (bool_val = val )))
83
71
elif type (val ) == float :
84
- bundled_values .append (BundledValue (val = BundledDouble (double_val = val )))
72
+ bundled_values .append (bp_schema . Value (val = bp_schema . Double (double_val = val )))
85
73
else :
86
74
assert 0 , "Unsupported primitive type received."
87
75
88
76
89
- def get_program_input (program : Program , plan_idx : int , input_idx : int ) -> KernelTypes :
77
+ def get_program_input (
78
+ program : core_schema .Program , plan_idx : int , input_idx : int
79
+ ) -> core_schema .KernelTypes :
90
80
return (
91
81
program .execution_plan [plan_idx ]
92
82
.values [program .execution_plan [plan_idx ].inputs [input_idx ]]
93
83
.val
94
84
)
95
85
96
86
97
- def get_program_output (program : Program , plan_idx : int , output_idx : int ) -> KernelTypes :
87
+ def get_program_output (
88
+ program : core_schema .Program , plan_idx : int , output_idx : int
89
+ ) -> core_schema .KernelTypes :
98
90
return (
99
91
program .execution_plan [plan_idx ]
100
92
.values [program .execution_plan [plan_idx ].outputs [output_idx ]]
101
93
.val
102
94
)
103
95
104
96
105
- def get_input_dtype (program : Program , plan_idx : int , input_idx : int ) -> torch .dtype :
97
+ def get_input_dtype (
98
+ program : core_schema .Program , plan_idx : int , input_idx : int
99
+ ) -> torch .dtype :
106
100
# pyre-fixme[16]: now assert all input and outputs is in tenor type. Support multuple datatypes in the future.
107
101
return get_scalar_type (get_program_input (program , plan_idx , input_idx ).scalar_type )
108
102
109
103
110
- def get_input_type (program : Program , plan_idx : int , input_idx : int ) -> type :
111
- type_lookup = {Int : int , Bool : bool , Double : float }
104
+ def get_input_type (program : core_schema .Program , plan_idx : int , input_idx : int ) -> type :
105
+ type_lookup = {
106
+ core_schema .Int : int ,
107
+ core_schema .Bool : bool ,
108
+ core_schema .Double : float ,
109
+ }
112
110
# pyre-fixme[6]: Incompatible parameter type [6]: In call `dict.__getitem__`, for 1st positional only parameter
113
- # expected `Type[Union[Bool, Double, Int]]` but got `Type[Union[Bool, Double, Int, Tensor, BoolList, DoubleList,
111
+ # expected `Type[Union[core_schema. Bool, core_schema. Double, core_schema. Int]]` but got `Type[Union[core_schema. Bool, core_schema. Double, core_schema. Int, core_schema. Tensor, BoolList, DoubleList,
114
112
# IntList, Null, OptionalTensorList, String, TensorList]]`.
115
113
return type_lookup [type (get_program_input (program , plan_idx , input_idx ))]
116
114
117
115
118
- def get_output_dtype (program : Program , plan_idx : int , output_idx : int ) -> torch .dtype :
116
+ def get_output_dtype (
117
+ program : core_schema .Program , plan_idx : int , output_idx : int
118
+ ) -> torch .dtype :
119
119
return get_scalar_type (
120
120
# pyre-ignore[16]: now assert all outputs is in tensor type.
121
121
get_program_output (program , plan_idx , output_idx ).scalar_type
122
122
)
123
123
124
124
125
125
def assert_valid_bundle (
126
- program : Program ,
126
+ program : core_schema . Program ,
127
127
bundled_config : BundledConfig ,
128
128
) -> None :
129
129
"""Check if the program and BundledConfig matches each other.
@@ -163,7 +163,7 @@ def assert_valid_bundle(
163
163
plan_test : ConfigExecutionPlanTest = bundled_config .execution_plan_tests [
164
164
bp_plan_id
165
165
]
166
- plan : ExecutionPlan = program .execution_plan [program_plan_id ]
166
+ plan : core_schema . ExecutionPlan = program .execution_plan [program_plan_id ]
167
167
168
168
# User does not provide testcases for current plan, skip it
169
169
if plan_test .method_name > plan .name :
@@ -185,7 +185,8 @@ def assert_valid_bundle(
185
185
# Check if the type of Program's output is supported
186
186
for index in range (len (plan .outputs )):
187
187
assert (
188
- type (get_program_output (program , program_plan_id , index )) == Tensor
188
+ type (get_program_output (program , program_plan_id , index ))
189
+ == core_schema .Tensor
189
190
), "Only supports program with output in Tensor type."
190
191
191
192
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -257,10 +258,10 @@ def assert_valid_bundle(
257
258
258
259
259
260
def create_bundled_program (
260
- program : Program ,
261
+ program : core_schema . Program ,
261
262
bundled_config : BundledConfig ,
262
- ) -> BundledProgram :
263
- """Create BundledProgram by bundling the given program and bundled_config together.
263
+ ) -> bp_schema . BundledProgram :
264
+ """Create bp_schema. BundledProgram by bundling the given program and bundled_config together.
264
265
265
266
Args:
266
267
program: The program to be bundled.
@@ -269,16 +270,16 @@ def create_bundled_program(
269
270
270
271
assert_valid_bundle (program , bundled_config )
271
272
272
- execution_plan_tests : List [BundledExecutionPlanTest ] = []
273
+ execution_plan_tests : List [bp_schema . BundledExecutionPlanTest ] = []
273
274
274
275
# Emit data and metadata of bundled tensor
275
276
for plan_test in bundled_config .execution_plan_tests :
276
- test_sets : List [BundledIOSet ] = []
277
+ test_sets : List [bp_schema . BundledIOSet ] = []
277
278
278
279
# emit I/O sets for each execution plan test
279
280
for i in range (len (plan_test .test_sets )):
280
- inputs : List [BundledValue ] = []
281
- expected_outputs : List [BundledValue ] = []
281
+ inputs : List [bp_schema . Value ] = []
282
+ expected_outputs : List [bp_schema . Value ] = []
282
283
283
284
cur_plan_test_inputs = plan_test .test_sets [i ].inputs
284
285
cur_plan_test_expected_outputs = plan_test .test_sets [i ].expected_outputs
@@ -303,19 +304,19 @@ def create_bundled_program(
303
304
expected_outputs ,
304
305
)
305
306
test_sets .append (
306
- BundledIOSet (inputs = inputs , expected_outputs = expected_outputs )
307
+ bp_schema . BundledIOSet (inputs = inputs , expected_outputs = expected_outputs )
307
308
)
308
309
309
310
# emit the whole execution plan test
310
311
execution_plan_tests .append (
311
- BundledExecutionPlanTest (
312
+ bp_schema . BundledExecutionPlanTest (
312
313
method_name = plan_test .method_name , test_sets = test_sets
313
314
)
314
315
)
315
316
316
317
program_bytes : bytes = _serialize_pte_binary (program )
317
318
318
- return BundledProgram (
319
+ return bp_schema . BundledProgram (
319
320
version = BUNDLED_PROGRAM_SCHEMA_VERSION ,
320
321
execution_plan_tests = execution_plan_tests ,
321
322
program = program_bytes ,
0 commit comments