Skip to content

Commit 3da4b5d

Browse files
SaoirseARM authored and freddan80 committed
Update model evaluator to check multiple outputs
1 parent ee14ad0 commit 3da4b5d

File tree

3 files changed

+55
-42
lines changed

3 files changed

+55
-42
lines changed

backends/arm/test/misc/test_model_evaluator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,13 @@ def test_get_model_error(self):
3737
example_input,
3838
"tmp/output_tag0.tosa",
3939
)
40-
max_error, max_absolute_error, max_percentage_error, mae = (
41-
evaluator.get_model_error()
42-
)
4340

44-
self.assertEqual(max_error, 1.0)
45-
self.assertEqual(max_absolute_error, 1.0)
46-
self.assertEqual(max_percentage_error, 25.0)
47-
self.assertEqual(mae, 0.25)
41+
model_error_dict = evaluator.get_model_error()
42+
43+
self.assertEqual(model_error_dict["max_error"], [1.0])
44+
self.assertEqual(model_error_dict["max_absolute_error"], [1.0])
45+
self.assertEqual(model_error_dict["max_percentage_error"], [25.0])
46+
self.assertEqual(model_error_dict["mean_absolute_error"], [0.25])
4847

4948
def test_get_compression_ratio(self):
5049
with tempfile.NamedTemporaryFile(delete=True) as temp_bin:

backends/arm/util/arm_model_evaluator.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,26 @@
77
import os
88
import tempfile
99
import zipfile
10-
from typing import Any, Optional, Tuple
10+
from collections import defaultdict
11+
from typing import Optional, Tuple
1112

1213
import torch
1314

1415

16+
def flatten_args(args) -> tuple | list:
17+
flattened_args: list = []
18+
if isinstance(args, torch.Tensor):
19+
return [args]
20+
21+
for arg in args:
22+
if isinstance(arg, (tuple, list)):
23+
flattened_args.extend(arg)
24+
else:
25+
flattened_args.append(arg)
26+
27+
return tuple(flattened_args)
28+
29+
1530
class GenericModelEvaluator:
1631
def __init__(
1732
self,
@@ -32,31 +47,34 @@ def __init__(
3247
else:
3348
self.tosa_output_path = None
3449

35-
def get_model_error(self) -> tuple[float, float, float, float]:
50+
def get_model_error(self) -> defaultdict:
3651
"""
37-
Returns the following metrics between the outputs of the FP32 and INT8 model:
52+
Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
3853
- Maximum error
3954
- Maximum absolute error
4055
- Maximum percentage error
4156
- Mean absolute error
4257
"""
43-
fp32_output = self.fp32_model(*self.example_input)
44-
int8_output = self.int8_model(*self.example_input)
45-
46-
difference = fp32_output - int8_output
47-
percentage_error = torch.div(difference, fp32_output) * 100
48-
49-
max_error = torch.max(difference).item()
50-
max_absolute_error = torch.max(torch.abs(difference)).item()
51-
max_percentage_error = torch.max(percentage_error).item()
52-
mean_absolute_error = torch.mean(torch.abs(difference).float()).item()
53-
54-
return (
55-
float(max_error),
56-
float(max_absolute_error),
57-
float(max_percentage_error),
58-
float(mean_absolute_error),
59-
)
58+
fp32_outputs = flatten_args(self.fp32_model(*self.example_input))
59+
int8_outputs = flatten_args(self.int8_model(*self.example_input))
60+
61+
model_error_dict = defaultdict(list)
62+
63+
for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
64+
difference = fp32_output - int8_output
65+
percentage_error = torch.div(difference, fp32_output) * 100
66+
model_error_dict["max_error"].append(torch.max(difference).item())
67+
model_error_dict["max_absolute_error"].append(
68+
torch.max(torch.abs(difference)).item()
69+
)
70+
model_error_dict["max_percentage_error"].append(
71+
torch.max(percentage_error).item()
72+
)
73+
model_error_dict["mean_absolute_error"].append(
74+
torch.mean(torch.abs(difference).float()).item()
75+
)
76+
77+
return model_error_dict
6078

6179
def get_compression_ratio(self) -> float:
6280
"""Compute the compression ratio of the outputted TOSA flatbuffer."""
@@ -72,19 +90,10 @@ def get_compression_ratio(self) -> float:
7290

7391
return compression_ratio
7492

75-
def evaluate(self) -> dict[str, Any]:
76-
max_error, max_absolute_error, max_percent_error, mean_absolute_error = (
77-
self.get_model_error()
78-
)
79-
output_metrics = {
80-
"name": self.model_name,
81-
"metrics": {
82-
"max_error": max_error,
83-
"max_absolute_error": max_absolute_error,
84-
"max_percentage_error": max_percent_error,
85-
"mean_absolute_error": mean_absolute_error,
86-
},
87-
}
93+
def evaluate(self) -> dict[any]:
94+
model_error_dict = self.get_model_error()
95+
96+
output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}
8897

8998
if self.tosa_output_path:
9099
# We know output_metrics["metrics"] is list since we just defined it, safe to ignore.

examples/arm/aot_arm_compiler.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,11 @@ def get_args():
328328
)
329329
args = parser.parse_args()
330330

331-
if args.evaluate and (args.quantize is None or args.intermediates is None):
331+
if args.evaluate and (
332+
args.quantize is None or args.intermediates is None or (not args.delegate)
333+
):
332334
raise RuntimeError(
333-
"--evaluate requires --quantize and --intermediates to be enabled."
335+
"--evaluate requires --quantize, --intermediates and --delegate to be enabled."
334336
)
335337

336338
if args.debug:
@@ -378,6 +380,9 @@ def get_args():
378380
# Wrap quantized model back into an exported_program
379381
exported_program = torch.export.export_for_training(model, example_inputs)
380382

383+
if args.intermediates:
384+
os.makedirs(args.intermediates, exist_ok=True)
385+
381386
if args.delegate:
382387
# As we can target multiple output encodings from ArmBackend, one must
383388
# be specified.

0 commit comments

Comments (0)