7
7
import os
8
8
import tempfile
9
9
import zipfile
10
- from typing import Any , Optional , Tuple
10
+ from collections import defaultdict
11
+ from typing import Optional , Tuple
11
12
12
13
import torch
13
14
14
15
16
+ def flatten_args (args ) -> tuple | list :
17
+ flattened_args : list = []
18
+ if isinstance (args , torch .Tensor ):
19
+ return [args ]
20
+
21
+ for arg in args :
22
+ if isinstance (arg , (tuple , list )):
23
+ flattened_args .extend (arg )
24
+ else :
25
+ flattened_args .append (arg )
26
+
27
+ return tuple (flattened_args )
28
+
29
+
15
30
class GenericModelEvaluator :
16
31
def __init__ (
17
32
self ,
@@ -32,31 +47,34 @@ def __init__(
32
47
else :
33
48
self .tosa_output_path = None
34
49
35
- def get_model_error (self ) -> tuple [ float , float , float , float ] :
50
+ def get_model_error (self ) -> defaultdict :
36
51
"""
37
- Returns the following metrics between the outputs of the FP32 and INT8 model:
52
+ Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
38
53
- Maximum error
39
54
- Maximum absolute error
40
55
- Maximum percentage error
41
56
- Mean absolute error
42
57
"""
43
- fp32_output = self .fp32_model (* self .example_input )
44
- int8_output = self .int8_model (* self .example_input )
45
-
46
- difference = fp32_output - int8_output
47
- percentage_error = torch .div (difference , fp32_output ) * 100
48
-
49
- max_error = torch .max (difference ).item ()
50
- max_absolute_error = torch .max (torch .abs (difference )).item ()
51
- max_percentage_error = torch .max (percentage_error ).item ()
52
- mean_absolute_error = torch .mean (torch .abs (difference ).float ()).item ()
53
-
54
- return (
55
- float (max_error ),
56
- float (max_absolute_error ),
57
- float (max_percentage_error ),
58
- float (mean_absolute_error ),
59
- )
58
+ fp32_outputs = flatten_args (self .fp32_model (* self .example_input ))
59
+ int8_outputs = flatten_args (self .int8_model (* self .example_input ))
60
+
61
+ model_error_dict = defaultdict (list )
62
+
63
+ for fp32_output , int8_output in zip (fp32_outputs , int8_outputs ):
64
+ difference = fp32_output - int8_output
65
+ percentage_error = torch .div (difference , fp32_output ) * 100
66
+ model_error_dict ["max_error" ].append (torch .max (difference ).item ())
67
+ model_error_dict ["max_absolute_error" ].append (
68
+ torch .max (torch .abs (difference )).item ()
69
+ )
70
+ model_error_dict ["max_percentage_error" ].append (
71
+ torch .max (percentage_error ).item ()
72
+ )
73
+ model_error_dict ["mean_absolute_error" ].append (
74
+ torch .mean (torch .abs (difference ).float ()).item ()
75
+ )
76
+
77
+ return model_error_dict
60
78
61
79
def get_compression_ratio (self ) -> float :
62
80
"""Compute the compression ratio of the outputted TOSA flatbuffer."""
@@ -72,19 +90,10 @@ def get_compression_ratio(self) -> float:
72
90
73
91
return compression_ratio
74
92
75
- def evaluate (self ) -> dict [str , Any ]:
76
- max_error , max_absolute_error , max_percent_error , mean_absolute_error = (
77
- self .get_model_error ()
78
- )
79
- output_metrics = {
80
- "name" : self .model_name ,
81
- "metrics" : {
82
- "max_error" : max_error ,
83
- "max_absolute_error" : max_absolute_error ,
84
- "max_percentage_error" : max_percent_error ,
85
- "mean_absolute_error" : mean_absolute_error ,
86
- },
87
- }
93
+ def evaluate (self ) -> dict [any ]:
94
+ model_error_dict = self .get_model_error ()
95
+
96
+ output_metrics = {"name" : self .model_name , "metrics" : dict (model_error_dict )}
88
97
89
98
if self .tosa_output_path :
90
99
# We know output_metrics["metrics"] is list since we just defined it, safe to ignore.
0 commit comments