Skip to content

Commit 17ad8d3

Browse files
authored
Fix type handling for output types from TOSA reference model (#6660)
Change-Id: I80953a699e4861b901af4b2fb17d47d3d7efcedd Signed-off-by: Per Åstrand <[email protected]>
1 parent 03b1ef2 commit 17ad8d3

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

backends/arm/test/runner_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,16 +448,21 @@ def run_tosa_ref_model(
448448
), "There are no quantization parameters, check output parameters"
449449
tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
450450

451+
if tosa_ref_output.dtype == np.double:
452+
tosa_ref_output = tosa_ref_output.astype("float32")
453+
451454
# tosa_output is a numpy array, convert to torch tensor for comparison
452-
tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))
455+
tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output))
453456

454457
return tosa_ref_outputs
455458

456459

457460
def prep_data_for_save(
458461
data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
459462
):
460-
data_np = np.array(data.detach(), order="C").astype(np.float32)
463+
data_np = np.array(data.detach(), order="C").astype(
464+
f"{data.dtype}".replace("torch.", "")
465+
)
461466

462467
if is_quantized:
463468
assert quant_param.node_name in input_name, (

0 commit comments

Comments
 (0)