Commit e91357d

tom-arm authored and facebook-github-bot committed
Adding model stats to aot_arm_compiler (#5816)
Summary:
* Adds GenericModelEvaluator, which gathers metrics applicable to all models
* Adds --evaluate option to enable gathering quantization metrics

Signed-off-by: Tom Allsop <[email protected]>
Change-Id: Ia9b591841f188870fa5e62d0568169498301393d
Pull Request resolved: #5816
Reviewed By: mergennachin
Differential Revision: D64047185
Pulled By: digantdesai
fbshipit-source-id: b443b616e9092ee5c39ce9ec07bd0c0ef2aca04a
1 parent 97a1965 commit e91357d

File tree

3 files changed: +221 −6 lines
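As a rough sketch of how the new flow might be driven from the command line (the model name, target value and directory are illustrative; --evaluate, --quantize and --intermediates come from the diff below, while the --model_name and --target spellings are inferred from args.model_name and args.target and may differ):

python3 examples/arm/aot_arm_compiler.py --model_name=softmax --target=ethos-u55-64 --quantize --intermediates=intermediates/ --evaluate

If the run succeeds, the quantization metrics are written to quant_metrics.json inside the --intermediates directory (see evaluate_model below).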
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random
import tempfile
import unittest

import torch
from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator

random.seed(0)

# Create an input that is hard to compress
COMPRESSION_RATIO_TEST = bytearray(random.getrandbits(8) for _ in range(1000000))


def mocked_model_1(input: torch.Tensor) -> torch.Tensor:
    return torch.tensor([1.0, 2.0, 3.0, 4.0])


def mocked_model_2(input: torch.Tensor) -> torch.Tensor:
    return torch.tensor([1.0, 2.0, 3.0, 3.0])


class TestGenericModelEvaluator(unittest.TestCase):
    """Tests the GenericModelEvaluator class."""

    def test_get_model_error(self):
        example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
        evaluator = GenericModelEvaluator(
            "dummy_model",
            mocked_model_1,
            mocked_model_2,
            example_input,
            "tmp/output_tag0.tosa",
        )
        max_error, max_absolute_error, max_percentage_error, mae = (
            evaluator.get_model_error()
        )

        self.assertEqual(max_error, 1.0)
        self.assertEqual(max_absolute_error, 1.0)
        self.assertEqual(max_percentage_error, 25.0)
        self.assertEqual(mae, 0.25)

    def test_get_compression_ratio(self):
        with tempfile.NamedTemporaryFile(delete=True) as temp_bin:
            temp_bin.write(COMPRESSION_RATIO_TEST)

            # As the size of the file is quite small we need to call flush()
            temp_bin.flush()
            temp_bin_name = temp_bin.name

            example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
            evaluator = GenericModelEvaluator(
                "dummy_model",
                mocked_model_1,
                mocked_model_2,
                example_input,
                temp_bin_name,
            )

            ratio = evaluator.get_compression_ratio()
            self.assertAlmostEqual(ratio, 1.0, places=2)
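For context on the numbers asserted above: the two mocked models differ only in the last output element, and the same arithmetic as get_model_error, written out with plain torch ops, gives exactly those values. The compression-ratio test uses random bytes precisely because they barely compress, so the zipped size stays close to the original and the ratio lands near 1.0. A small standalone sketch of the error arithmetic:

import torch

fp32_out = torch.tensor([1.0, 2.0, 3.0, 4.0])  # mocked_model_1 output
int8_out = torch.tensor([1.0, 2.0, 3.0, 3.0])  # mocked_model_2 output

diff = fp32_out - int8_out                       # tensor([0., 0., 0., 1.])
print(torch.max(diff).item())                    # 1.0  -> max_error
print(torch.max(torch.abs(diff)).item())         # 1.0  -> max_absolute_error
print(torch.max(diff / fp32_out * 100).item())   # 25.0 -> max_percentage_error
print(torch.mean(torch.abs(diff)).item())        # 0.25 -> mean_absolute_error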
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import zipfile
from typing import Optional, Tuple

import torch


class GenericModelEvaluator:
    def __init__(
        self,
        model_name: str,
        fp32_model: torch.nn.Module,
        int8_model: torch.nn.Module,
        example_input: Tuple[torch.Tensor],
        tosa_output_path: Optional[str],
    ) -> None:
        self.model_name = model_name

        self.fp32_model = fp32_model
        self.int8_model = int8_model
        self.example_input = example_input

        if tosa_output_path:
            self.tosa_output_path = tosa_output_path
        else:
            self.tosa_output_path = None

    def get_model_error(self) -> Tuple[float, float, float, float]:
        """
        Returns the following metrics between the outputs of the FP32 and INT8 model:
        - Maximum error
        - Maximum absolute error
        - Maximum percentage error
        - Mean absolute error
        """
        fp32_output = self.fp32_model(*self.example_input)
        int8_output = self.int8_model(*self.example_input)

        difference = fp32_output - int8_output
        percentage_error = torch.div(difference, fp32_output) * 100

        max_error = torch.max(difference).item()
        max_absolute_error = torch.max(torch.abs(difference)).item()
        max_percentage_error = torch.max(percentage_error).item()
        mean_absolute_error = torch.mean(torch.abs(difference).float()).item()

        return max_error, max_absolute_error, max_percentage_error, mean_absolute_error

    def get_compression_ratio(self) -> float:
        """Compute the compression ratio of the outputted TOSA flatbuffer."""
        with tempfile.NamedTemporaryFile(delete=True, suffix=".zip") as temp_zip:
            with zipfile.ZipFile(
                temp_zip.name, "w", compression=zipfile.ZIP_DEFLATED
            ) as f:
                f.write(self.tosa_output_path)

            compression_ratio = os.path.getsize(
                self.tosa_output_path
            ) / os.path.getsize(temp_zip.name)

        return compression_ratio

    def evaluate(self) -> dict:
        max_error, max_absolute_error, max_percent_error, mean_absolute_error = (
            self.get_model_error()
        )
        output_metrics = {
            "name": self.model_name,
            "metrics": {
                "max_error": max_error,
                "max_absolute_error": max_absolute_error,
                "max_percentage_error": max_percent_error,
                "mean_absolute_error": mean_absolute_error,
            },
        }

        if self.tosa_output_path:
            output_metrics["metrics"][
                "compression_ratio"
            ] = self.get_compression_ratio()

        return output_metrics
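For reference, evaluate() returns a nested dict with the keys shown above; a hedged illustration of its shape (the numeric values here are invented, not from a real run):

example_metrics = {
    "name": "dummy_model",
    "metrics": {
        "max_error": 0.02,
        "max_absolute_error": 0.02,
        "max_percentage_error": 1.5,
        "mean_absolute_error": 0.004,
        # Present only when a TOSA output path was supplied.
        "compression_ratio": 2.3,
    },
}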

examples/arm/aot_arm_compiler.py

Lines changed: 65 additions & 6 deletions
@@ -8,18 +8,21 @@
 # Example script for exporting simple models to flatbuffer
 
 import argparse
+import json
 import logging
 import os
-from typing import Optional
 
-import torch
+from pathlib import Path
+from typing import Optional, Tuple
 
+import torch
 from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
 from executorch.backends.arm.arm_partitioner import ArmPartitioner
 from executorch.backends.arm.quantizer.arm_quantizer import (
     ArmQuantizer,
     get_symmetric_quantization_config,
 )
+from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator
 
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
@@ -151,6 +154,8 @@ def forward(self, x):
     "softmax": SoftmaxModule,
 }
 
+evaluators = {}
+
 targets = [
     "ethos-u55-32",
     "ethos-u55-64",
@@ -202,6 +207,37 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
     return spec_builder.build()
 
 
+def get_evaluator(model_name: str) -> GenericModelEvaluator:
+    if model_name not in evaluators:
+        return GenericModelEvaluator
+    else:
+        return evaluators[model_name]
+
+
+def evaluate_model(
+    model_name: str,
+    intermediates: str,
+    model_fp32: torch.nn.Module,
+    model_int8: torch.nn.Module,
+    example_inputs: Tuple[torch.Tensor],
+):
+    evaluator = get_evaluator(model_name)
+
+    # Get the path of the TOSA flatbuffer that is dumped
+    intermediates_path = Path(intermediates)
+    tosa_paths = list(intermediates_path.glob("*.tosa"))
+
+    init_evaluator = evaluator(
+        model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
+    )
+
+    quant_metrics = init_evaluator.evaluate()
+    output_json_path = intermediates_path / "quant_metrics.json"
+
+    with output_json_path.open("w") as json_file:
+        json.dump(quant_metrics, json_file)
+
+
 def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None):
     graph_module = edge.exported_program().graph_module
     delegation_info = get_delegation_info(graph_module)
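The module-level `evaluators` dict acts as a registry: `get_evaluator` falls back to `GenericModelEvaluator` for any model name without a dedicated entry. A hedged sketch of how a model-specific evaluator might be plugged in later (the subclass, the extra metric and the model key are hypothetical, not part of this commit):

class MyModelEvaluator(GenericModelEvaluator):  # hypothetical subclass
    def evaluate(self) -> dict:
        metrics = super().evaluate()
        # Illustrative extra metric; a real evaluator could add e.g. top-1 accuracy.
        metrics["metrics"]["extra_metric"] = 0.0
        return metrics


# Registered against a (hypothetical) model name from the `models` dict,
# so get_evaluator("my_model") would return MyModelEvaluator.
evaluators = {"my_model": MyModelEvaluator}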
@@ -242,6 +278,14 @@ def get_args():
         choices=targets,
         help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}",
     )
+    parser.add_argument(
+        "-e",
+        "--evaluate",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Flag for running evaluation of the model.",
+    )
     parser.add_argument(
         "-q",
         "--quantize",
@@ -275,11 +319,11 @@ def get_args():
         help="Location for outputs, if not the default of cwd.",
     )
     args = parser.parse_args()
-    return args
 
-
-if __name__ == "__main__":
-    args = get_args()
+    if args.evaluate and (args.quantize is None or args.intermediates is None):
+        raise RuntimeError(
+            "--evaluate requires --quantize and --intermediates to be enabled."
+        )
 
     if args.debug:
         logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
@@ -302,16 +346,26 @@ def get_args():
     ):
         raise RuntimeError(f"Model {args.model_name} cannot be delegated.")
 
+    return args
+
+
+if __name__ == "__main__":
+    args = get_args()
+
     # Pick model from one of the supported lists
     model, example_inputs = get_model_and_inputs_from_name(args.model_name)
     model = model.eval()
 
+    model_fp32 = model
+
     # pre-autograd export. eventually this will become torch.export
     model = torch.export.export_for_training(model, example_inputs).module()
 
     # Quantize if required
+    model_int8 = None
     if args.quantize:
         model = quantize(model, example_inputs)
+        model_int8 = model
 
     edge = export_to_edge(
         model,
@@ -361,3 +415,8 @@ def get_args():
         output_name = os.path.join(args.output, output_name)
 
     save_pte_program(exec_prog, output_name)
+
+    if args.evaluate:
+        evaluate_model(
+            args.model_name, args.intermediates, model_fp32, model_int8, example_inputs
+        )
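After a run with --evaluate, the metrics file the script writes can be inspected with a few lines of Python (the intermediates path here is an example):

import json
from pathlib import Path

with (Path("intermediates") / "quant_metrics.json").open() as f:
    print(json.dumps(json.load(f), indent=2))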
