@@ -115,50 +115,53 @@ def _get_input_quantization_params(
115
115
return quant_params
116
116
117
117
118
- def _get_output_node (program : ExportedProgram ) -> Node :
118
+ def _get_output_nodes (program : ExportedProgram ) -> list [ Node ] :
119
119
"""
120
120
Get output node to this model.
121
121
122
122
Args:
123
- program (ExportedProgram): The program to get output node from.
123
+ program (ExportedProgram): The program to get the output nodes from.
124
124
Returns:
125
- The node that is the output of 'program'.
125
+ The nodes that are the outputs of the 'program'.
126
126
"""
127
-
127
+ output_nodes = []
128
128
for node in program .graph .nodes :
129
129
if node .op == "output" :
130
- return node
131
- raise RuntimeError ("No output node found." )
130
+ for output in node .args [0 ]:
131
+ output_nodes .append (output )
132
+ if len (output_nodes ) == 0 :
133
+ raise RuntimeError ("No output nodes found." )
134
+ else :
135
+ return output_nodes
132
136
133
137
134
138
def _get_output_quantization_params (
135
- program : ExportedProgram , output_node : Node
136
- ) -> Optional [QuantizationParams ]:
139
+ output_nodes : list [ Node ],
140
+ ) -> List [QuantizationParams ]:
137
141
"""
138
142
Get output QuantizationParams from a program.
139
143
Args:
140
- program (ExportedProgram): The program to get output quantization parameters from.
144
+ output_nodes (list(Node)): A list of output nodes to get output quantization parameters from.
141
145
Returns:
142
146
QuantizationParams: The found quantization parameters.
143
147
Raises:
144
148
RuntimeError if no output quantization parameters are found.
145
149
"""
146
-
147
- quant_params = None
148
- for node in program .graph .nodes :
149
- if (
150
- node .target == torch .ops .quantized_decomposed .dequantize_per_tensor .default
151
- and node == output_node .args [0 ][0 ]
152
- ):
153
- quant_params = QuantizationParams (
154
- node_name = node .args [0 ].name ,
155
- scale = node .args [1 ],
156
- zp = node .args [2 ],
157
- qmin = node .args [3 ],
158
- qmax = node .args [4 ],
159
- dtype = node .args [5 ],
150
+ quant_params = []
151
+ for node in output_nodes :
152
+ if node .target == torch .ops .quantized_decomposed .dequantize_per_tensor .default :
153
+ quant_params .append (
154
+ QuantizationParams (
155
+ node_name = node .args [0 ].name ,
156
+ scale = node .args [1 ],
157
+ zp = node .args [2 ],
158
+ qmin = node .args [3 ],
159
+ qmax = node .args [4 ],
160
+ dtype = node .args [5 ],
161
+ )
160
162
)
161
- break # break early, there's only one output node
163
+ if len (quant_params ) == 0 :
164
+ raise RuntimeError ("No Quantization parameters not found in exported model." )
162
165
return quant_params
163
166
164
167
@@ -211,7 +214,7 @@ def __init__(
211
214
self .input_names : list [str ] = None
212
215
self .output_name : str = None
213
216
self .qp_input : list [QuantizationParams ] = None
214
- self .qp_output : QuantizationParams = None
217
+ self .qp_output : list [ QuantizationParams ] = None
215
218
self .timeout = 480
216
219
self .target_board : str = None
217
220
@@ -226,19 +229,17 @@ def init_run(
226
229
):
227
230
228
231
self .input_names = _get_input_names (edge_program )
229
- self .output_node = _get_output_node (exported_program )
230
- self . output_name = self . output_node . name
232
+ self .output_nodes = _get_output_nodes (exported_program )
233
+
231
234
self .is_quantized = is_quantized
232
235
self .target_board = target_board
233
236
234
237
if is_quantized :
235
238
self .qp_input = _get_input_quantization_params (exported_program )
236
- self .qp_output = _get_output_quantization_params (
237
- exported_program , self .output_node
238
- )
239
+ self .qp_output = _get_output_quantization_params (self .output_nodes )
239
240
else :
240
241
self .qp_input = [None ] * len (self .input_names )
241
- self .qp_output = None
242
+ self .qp_output = [ None ] * len ( self . output_nodes )
242
243
243
244
self ._has_init_run = True
244
245
@@ -265,7 +266,7 @@ def run_corstone(
265
266
save_bytes (self .intermediate_path , data , False , input_name , quant_param )
266
267
267
268
out_path = os .path .join (self .intermediate_path , "out" )
268
- out_path_with_suffix = out_path + "-0.bin"
269
+
269
270
input_paths = []
270
271
for name in self .input_names :
271
272
input_paths .append (
@@ -281,6 +282,7 @@ def run_corstone(
281
282
), f"Did not find build arm_executor_runner in path { elf_path } , run setup_testing.sh?"
282
283
283
284
cmd_line = f"executor_runner -m { pte_path } -o { out_path } "
285
+
284
286
for input_path in input_paths :
285
287
cmd_line += f" -i { input_path } "
286
288
@@ -362,11 +364,14 @@ def run_corstone(
362
364
raise RuntimeError (
363
365
f"Corstone simulation failed:\n cmd: { command_args [self .target_board ]} \n , log: \n { result_stdout } \n { result .stderr .decode ()} "
364
366
)
365
-
366
- tosa_ref_output = np .fromfile (out_path_with_suffix , dtype = np .float32 )
367
- output_shape = self .output_node .args [0 ][0 ].meta ["val" ].shape
368
- tosa_ref_output = torch .from_numpy (tosa_ref_output ).reshape (output_shape )
369
- return tosa_ref_output
367
+ output_np = []
368
+ for i , node in enumerate (self .output_nodes ):
369
+ tosa_ref_output = np .fromfile (
370
+ os .path .join (self .intermediate_path , f"out-{ i } .bin" ), dtype = np .float32
371
+ )
372
+ output_shape = node .meta ["val" ].shape
373
+ output_np .append (torch .from_numpy (tosa_ref_output ).reshape (output_shape ))
374
+ return tuple (output_np )
370
375
371
376
def run_tosa_graph (
372
377
self , graph : TosaGraph , inputs : list [np .ndarray ] | list [torch .Tensor ]
0 commit comments