Skip to content

Commit eae61f7

Browse files
perfreddan80
authored andcommitted
Allow TOSA tests to not have quant info
Quantized models might have output without quantization parameters attached to them. The assert for parameters not being None are removed and handled in order to allow for that case. numpy transpose is removed in favor of torch.permute to keep the type of the output after the operation. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I0e404062154cefa39f18b5706d72d19cac0e6d73
1 parent 224902c commit eae61f7

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

backends/arm/test/runner_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _get_output_node(program: ExportedProgram) -> Node:
127127

128128
def _get_output_quantization_params(
129129
program: ExportedProgram, output_node: Node
130-
) -> QuantizationParams:
130+
) -> Optional[QuantizationParams]:
131131
"""
132132
Get output QuantizationParams from a program.
133133
Args:
@@ -153,8 +153,6 @@ def _get_output_quantization_params(
153153
dtype=node.args[5],
154154
)
155155
break # break early, there's only one output node
156-
if quant_params is None:
157-
raise RuntimeError("No Quantization parameters not found in exported model.")
158156
return quant_params
159157

160158

@@ -485,13 +483,17 @@ def run_tosa_ref_model(
485483
if tosa_ref_output.dtype == np.int8:
486484
tosa_ref_output = tosa_ref_output.astype(np.int32)
487485
quant_param = self.qp_output
488-
assert (
489-
quant_param is not None
490-
), "There are no quantization parameters, check output parameters"
491-
tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
486+
if quant_param is not None:
487+
# I.e. bool output is possible for quantized models
488+
tosa_ref_output = (
489+
tosa_ref_output - quant_param.zp
490+
) * quant_param.scale
492491

493492
if tosa_ref_output.dtype == np.double:
494493
tosa_ref_output = tosa_ref_output.astype("float32")
494+
elif tosa_ref_output.dtype == bool:
495+
# retain the bool output though for boolean related comparisons
496+
tosa_ref_output = tosa_ref_output.astype("bool")
495497

496498
# tosa_output is a numpy array, convert to torch tensor for comparison
497499
tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output))

backends/arm/test/tester/arm_tester.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import executorch.backends.xnnpack.test.tester.tester as tester
1414

15-
import numpy as np
1615
import serializer.tosa_serializer as ts
1716

1817
import torch.fx
@@ -319,12 +318,15 @@ def run_method_and_compare_outputs(
319318
target_board,
320319
)
321320

321+
quantization_scale = None
322322
if is_quantized:
323323
reference_stage = self.stages[self.stage_name(tester.Quantize)]
324-
quantization_scale = self.runner_util.qp_output.scale
324+
# bool output is quantized with none quantized output so allow
325+
# self.runner_util.qp_output to be none
326+
if self.runner_util.qp_output is not None:
327+
quantization_scale = self.runner_util.qp_output.scale
325328
else:
326329
reference_stage = self.stages[self.stage_name(InitialModel)]
327-
quantization_scale = None
328330

329331
logger.info(
330332
f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
@@ -504,7 +506,7 @@ def transpose_data_format(
504506
inputs_transposed = list(data)
505507
for i in range(len(data)):
506508
if hasattr(data[i], "shape") and len(data[i].shape) == 4:
507-
inputs_transposed[i] = np.transpose(data[i], dim_order)
509+
inputs_transposed[i] = torch.permute(data[i], dim_order)
508510
return tuple(inputs_transposed)
509511

510512
def _compare_outputs(

0 commit comments

Comments
 (0)