8
8
import typing
9
9
from typing import Dict , List , Type
10
10
11
+ import executorch .bundled_program .schema as bp_schema
12
+ import executorch .exir .schema as core_schema
13
+
11
14
import torch
12
15
import torch .fx
13
16
from executorch .bundled_program .config import (
14
17
BundledConfig ,
15
18
ConfigExecutionPlanTest ,
16
19
ConfigValue ,
17
20
)
18
- from executorch .bundled_program .schema import (
19
- BundledBool ,
20
- BundledDouble ,
21
- BundledExecutionPlanTest ,
22
- BundledInt ,
23
- BundledIOSet ,
24
- BundledProgram ,
25
- BundledTensor ,
26
- BundledValue ,
27
- )
21
+
28
22
from executorch .bundled_program .version import BUNDLED_PROGRAM_SCHEMA_VERSION
29
23
from executorch .exir ._serialize import _serialize_pte_binary
30
- from executorch .exir .schema import (
31
- Bool ,
32
- Double ,
33
- ExecutionPlan ,
34
- Int ,
35
- KernelTypes ,
36
- Program ,
37
- Tensor ,
38
- )
24
+
39
25
from executorch .exir .tensor import get_scalar_type , scalar_type_enum , TensorSpec
40
26
41
27
# pyre-ignore
42
- supported_program_type_table : Dict [Type [KernelTypes ], ConfigValue ] = {
43
- Tensor : torch .Tensor ,
44
- Int : int ,
45
- Double : float ,
46
- Bool : bool ,
28
+ supported_program_type_table : Dict [Type [core_schema . KernelTypes ], ConfigValue ] = {
29
+ core_schema . Tensor : torch .Tensor ,
30
+ core_schema . Int : int ,
31
+ core_schema . Double : float ,
32
+ core_schema . Bool : bool ,
47
33
}
48
34
49
35
50
- def emit_bundled_tensor (spec : TensorSpec , bundled_values : List [BundledValue ]) -> None :
36
+ def emit_bundled_tensor (
37
+ spec : TensorSpec , bundled_values : List [bp_schema .Value ]
38
+ ) -> None :
51
39
# QuantizedSchema in tensor has deprecated and may not be used anymore.
52
40
# So here we don't emit it.
53
41
@@ -64,8 +52,8 @@ def emit_bundled_tensor(spec: TensorSpec, bundled_values: List[BundledValue]) ->
64
52
tensor_data : bytes = bytes (spec_array )
65
53
66
54
bundled_values .append (
67
- BundledValue (
68
- val = BundledTensor (
55
+ bp_schema . Value (
56
+ val = bp_schema . Tensor (
69
57
scalar_type = scalar_type_enum (spec .dtype ),
70
58
sizes = spec .shape ,
71
59
data = tensor_data ,
@@ -75,55 +63,67 @@ def emit_bundled_tensor(spec: TensorSpec, bundled_values: List[BundledValue]) ->
75
63
)
76
64
77
65
78
- def emit_prim (val : ConfigValue , bundled_values : List [BundledValue ]):
66
+ def emit_prim (val : ConfigValue , bundled_values : List [bp_schema . Value ]):
79
67
if type (val ) == int :
80
- bundled_values .append (BundledValue (val = BundledInt (int_val = val )))
68
+ bundled_values .append (bp_schema . Value (val = bp_schema . Int (int_val = val )))
81
69
elif type (val ) == bool :
82
- bundled_values .append (BundledValue (val = BundledBool (bool_val = val )))
70
+ bundled_values .append (bp_schema . Value (val = bp_schema . Bool (bool_val = val )))
83
71
elif type (val ) == float :
84
- bundled_values .append (BundledValue (val = BundledDouble (double_val = val )))
72
+ bundled_values .append (bp_schema . Value (val = bp_schema . Double (double_val = val )))
85
73
else :
86
74
assert 0 , "Unsupported primitive type received."
87
75
88
76
89
- def get_program_input (program : Program , plan_idx : int , input_idx : int ) -> KernelTypes :
77
+ def get_program_input (
78
+ program : core_schema .Program , plan_idx : int , input_idx : int
79
+ ) -> core_schema .KernelTypes :
90
80
return (
91
81
program .execution_plan [plan_idx ]
92
82
.values [program .execution_plan [plan_idx ].inputs [input_idx ]]
93
83
.val
94
84
)
95
85
96
86
97
- def get_program_output (program : Program , plan_idx : int , output_idx : int ) -> KernelTypes :
87
+ def get_program_output (
88
+ program : core_schema .Program , plan_idx : int , output_idx : int
89
+ ) -> core_schema .KernelTypes :
98
90
return (
99
91
program .execution_plan [plan_idx ]
100
92
.values [program .execution_plan [plan_idx ].outputs [output_idx ]]
101
93
.val
102
94
)
103
95
104
96
105
- def get_input_dtype (program : Program , plan_idx : int , input_idx : int ) -> torch .dtype :
97
+ def get_input_dtype (
98
+ program : core_schema .Program , plan_idx : int , input_idx : int
99
+ ) -> torch .dtype :
106
100
# pyre-fixme[16]: now assert all input and outputs is in tenor type. Support multuple datatypes in the future.
107
101
return get_scalar_type (get_program_input (program , plan_idx , input_idx ).scalar_type )
108
102
109
103
110
- def get_input_type (program : Program , plan_idx : int , input_idx : int ) -> type :
111
- type_lookup = {Int : int , Bool : bool , Double : float }
104
+ def get_input_type (program : core_schema .Program , plan_idx : int , input_idx : int ) -> type :
105
+ type_lookup = {
106
+ core_schema .Int : int ,
107
+ core_schema .Bool : bool ,
108
+ core_schema .Double : float ,
109
+ }
112
110
# pyre-fixme[6]: Incompatible parameter type [6]: In call `dict.__getitem__`, for 1st positional only parameter
113
- # expected `Type[Union[Bool, Double, Int]]` but got `Type[Union[Bool, Double, Int, Tensor, BoolList, DoubleList,
111
+ # expected `Type[Union[core_schema. Bool, core_schema. Double, core_schema. Int]]` but got `Type[Union[core_schema. Bool, core_schema. Double, core_schema. Int, core_schema. Tensor, BoolList, DoubleList,
114
112
# IntList, Null, OptionalTensorList, String, TensorList]]`.
115
113
return type_lookup [type (get_program_input (program , plan_idx , input_idx ))]
116
114
117
115
118
- def get_output_dtype (program : Program , plan_idx : int , output_idx : int ) -> torch .dtype :
116
+ def get_output_dtype (
117
+ program : core_schema .Program , plan_idx : int , output_idx : int
118
+ ) -> torch .dtype :
119
119
return get_scalar_type (
120
120
# pyre-ignore[16]: now assert all outputs is in tensor type.
121
121
get_program_output (program , plan_idx , output_idx ).scalar_type
122
122
)
123
123
124
124
125
125
def assert_valid_bundle (
126
- program : Program ,
126
+ program : core_schema . Program ,
127
127
bundled_config : BundledConfig ,
128
128
) -> None :
129
129
"""Check if the program and BundledConfig matches each other.
@@ -145,7 +145,7 @@ def assert_valid_bundle(
145
145
plan_test : ConfigExecutionPlanTest = bundled_config .execution_plan_tests [
146
146
bp_plan_id
147
147
]
148
- plan : ExecutionPlan = program .execution_plan [program_plan_id ]
148
+ plan : core_schema . ExecutionPlan = program .execution_plan [program_plan_id ]
149
149
150
150
# User does not provide testcases for current plan, skip it
151
151
if plan_test .method_name < plan .name :
@@ -167,7 +167,8 @@ def assert_valid_bundle(
167
167
# Check if the type of Program's output is supported
168
168
for index in range (len (plan .outputs )):
169
169
assert (
170
- type (get_program_output (program , program_plan_id , index )) == Tensor
170
+ type (get_program_output (program , program_plan_id , index ))
171
+ == core_schema .Tensor
171
172
), "Only supports program with output in Tensor type."
172
173
173
174
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -239,10 +240,10 @@ def assert_valid_bundle(
239
240
240
241
241
242
def create_bundled_program (
242
- program : Program ,
243
+ program : core_schema . Program ,
243
244
bundled_config : BundledConfig ,
244
- ) -> BundledProgram :
245
- """Create BundledProgram by bundling the given program and bundled_config together.
245
+ ) -> bp_schema . BundledProgram :
246
+ """Create bp_schema. BundledProgram by bundling the given program and bundled_config together.
246
247
247
248
Args:
248
249
program: The program to be bundled.
@@ -251,16 +252,16 @@ def create_bundled_program(
251
252
252
253
assert_valid_bundle (program , bundled_config )
253
254
254
- execution_plan_tests : List [BundledExecutionPlanTest ] = []
255
+ execution_plan_tests : List [bp_schema . BundledExecutionPlanTest ] = []
255
256
256
257
# Emit data and metadata of bundled tensor
257
258
for plan_test in bundled_config .execution_plan_tests :
258
- test_sets : List [BundledIOSet ] = []
259
+ test_sets : List [bp_schema . BundledIOSet ] = []
259
260
260
261
# emit I/O sets for each execution plan test
261
262
for i in range (len (plan_test .test_sets )):
262
- inputs : List [BundledValue ] = []
263
- expected_outputs : List [BundledValue ] = []
263
+ inputs : List [bp_schema . Value ] = []
264
+ expected_outputs : List [bp_schema . Value ] = []
264
265
265
266
cur_plan_test_inputs = plan_test .test_sets [i ].inputs
266
267
cur_plan_test_expected_outputs = plan_test .test_sets [i ].expected_outputs
@@ -285,19 +286,19 @@ def create_bundled_program(
285
286
expected_outputs ,
286
287
)
287
288
test_sets .append (
288
- BundledIOSet (inputs = inputs , expected_outputs = expected_outputs )
289
+ bp_schema . BundledIOSet (inputs = inputs , expected_outputs = expected_outputs )
289
290
)
290
291
291
292
# emit the whole execution plan test
292
293
execution_plan_tests .append (
293
- BundledExecutionPlanTest (
294
+ bp_schema . BundledExecutionPlanTest (
294
295
method_name = plan_test .method_name , test_sets = test_sets
295
296
)
296
297
)
297
298
298
299
program_bytes : bytes = _serialize_pte_binary (program )
299
300
300
- return BundledProgram (
301
+ return bp_schema . BundledProgram (
301
302
version = BUNDLED_PROGRAM_SCHEMA_VERSION ,
302
303
execution_plan_tests = execution_plan_tests ,
303
304
program = program_bytes ,
0 commit comments