Adding model stats to aot_arm_compiler #5816

Closed · wants to merge 3 commits
67 changes: 67 additions & 0 deletions backends/arm/test/misc/test_model_evaluator.py
@@ -0,0 +1,67 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random
import tempfile
import unittest

import torch
from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator

random.seed(0)

# Create an input that is hard to compress
COMPRESSION_RATIO_TEST = bytearray(random.getrandbits(8) for _ in range(1000000))

Contributor commented:
torch.rand() fp32 dtype -> save it as bytesio?

Collaborator (author) replied:

test_get_compression_ratio tests a file on a filesystem, so I'm not sure BytesIO works here (I understand it's for writing in memory). Unless I'm missing something?

I can look into using torch.rand() instead.
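
For illustration, a minimal sketch (not part of this PR) of the torch.rand()-based variant discussed above; the size and seed are arbitrary, and fp32 values in [0, 1) share exponent bytes, so the raw bytes may compress slightly better than the getrandbits() bytearray:

import tempfile

import torch

torch.manual_seed(0)
# Roughly 1 MB of fp32 noise written as raw bytes
rand_data = torch.rand(250_000, dtype=torch.float32)

with tempfile.NamedTemporaryFile(delete=True) as temp_bin:
    temp_bin.write(rand_data.numpy().tobytes())
    temp_bin.flush()
    # temp_bin.name can then be handed to GenericModelEvaluator, as
    # test_get_compression_ratio does with the bytearray-backed file.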



def mocked_model_1(input: torch.Tensor) -> torch.Tensor:
return torch.tensor([1.0, 2.0, 3.0, 4.0])


def mocked_model_2(input: torch.Tensor) -> torch.Tensor:
return torch.tensor([1.0, 2.0, 3.0, 3.0])


class TestGenericModelEvaluator(unittest.TestCase):
"""Tests the GenericModelEvaluator class."""

def test_get_model_error(self):
example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
evaluator = GenericModelEvaluator(
"dummy_model",
mocked_model_1,
mocked_model_2,
example_input,
"tmp/output_tag0.tosa",
)
max_error, max_absolute_error, max_percentage_error, mae = (
evaluator.get_model_error()
)

self.assertEqual(max_error, 1.0)
self.assertEqual(max_absolute_error, 1.0)
self.assertEqual(max_percentage_error, 25.0)
self.assertEqual(mae, 0.25)

def test_get_compression_ratio(self):
with tempfile.NamedTemporaryFile(delete=True) as temp_bin:
temp_bin.write(COMPRESSION_RATIO_TEST)

# The write may still be buffered, so flush() to ensure the data is on disk
temp_bin.flush()
temp_bin_name = temp_bin.name

example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
evaluator = GenericModelEvaluator(
"dummy_model",
mocked_model_1,
mocked_model_2,
example_input,
temp_bin_name,
)

ratio = evaluator.get_compression_ratio()
self.assertAlmostEqual(ratio, 1.0, places=2)
89 changes: 89 additions & 0 deletions backends/arm/util/arm_model_evaluator.py
@@ -0,0 +1,89 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import zipfile
from typing import Any, Optional, Tuple

import torch


class GenericModelEvaluator:
def __init__(
self,
model_name: str,
fp32_model: torch.nn.Module,
int8_model: torch.nn.Module,
example_input: Tuple[torch.Tensor],
tosa_output_path: Optional[str],
) -> None:
self.model_name = model_name

self.fp32_model = fp32_model
self.int8_model = int8_model
self.example_input = example_input

self.tosa_output_path = tosa_output_path

def get_model_error(self) -> Tuple[float, float, float, float]:
"""
Returns the following metrics between the outputs of the FP32 and INT8 model:
- Maximum error
- Maximum absolute error
- Maximum percentage error
- Mean absolute error
"""
fp32_output = self.fp32_model(*self.example_input)
int8_output = self.int8_model(*self.example_input)

difference = fp32_output - int8_output
percentage_error = torch.div(difference, fp32_output) * 100

max_error = torch.max(difference).item()
max_absolute_error = torch.max(torch.abs(difference)).item()
max_percentage_error = torch.max(percentage_error).item()
mean_absolute_error = torch.mean(torch.abs(difference).float()).item()

return max_error, max_absolute_error, max_percentage_error, mean_absolute_error

def get_compression_ratio(self) -> float:
"""Compute the compression ratio of the outputted TOSA flatbuffer."""
with tempfile.NamedTemporaryFile(delete=True, suffix=".zip") as temp_zip:
with zipfile.ZipFile(
temp_zip.name, "w", compression=zipfile.ZIP_DEFLATED
) as f:
f.write(self.tosa_output_path)

compression_ratio = os.path.getsize(
self.tosa_output_path
) / os.path.getsize(temp_zip.name)

return compression_ratio

def evaluate(self) -> dict[str, Any]:
max_error, max_absolute_error, max_percent_error, mean_absolute_error = (
self.get_model_error()
)
output_metrics = {
"name": self.model_name,
"metrics": {
"max_error": max_error,
"max_absolute_error": max_absolute_error,
"max_percentage_error": max_percent_error,
"mean_absolute_error": mean_absolute_error,
},
}

if self.tosa_output_path:
output_metrics["metrics"][
"compression_ratio"
] = self.get_compression_ratio()

return output_metrics
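
For reference, a minimal usage sketch of the class above (the model here is a placeholder and no TOSA flatbuffer is passed, so the compression-ratio metric is skipped; the real wiring is in examples/arm/aot_arm_compiler.py below):

import torch

from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator

fp32_model = torch.nn.Linear(4, 4).eval()
int8_model = fp32_model  # placeholder; normally the quantized counterpart
example_inputs = (torch.rand(1, 4),)

evaluator = GenericModelEvaluator(
    "linear_example", fp32_model, int8_model, example_inputs, None
)
print(evaluator.evaluate())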
71 changes: 65 additions & 6 deletions examples/arm/aot_arm_compiler.py
@@ -8,18 +8,21 @@
# Example script for exporting simple models to flatbuffer

import argparse
import json
import logging
import os
from typing import Optional

import torch
from pathlib import Path
from typing import Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)
from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator

from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
@@ -151,6 +154,8 @@ def forward(self, x):
"softmax": SoftmaxModule,
}

evaluators = {}

targets = [
"ethos-u55-32",
"ethos-u55-64",
@@ -202,6 +207,37 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
return spec_builder.build()


def get_evaluator(model_name: str) -> type[GenericModelEvaluator]:
if model_name not in evaluators:
return GenericModelEvaluator
else:
return evaluators[model_name]


def evaluate_model(
model_name: str,
intermediates: str,
model_fp32: torch.nn.Module,
model_int8: torch.nn.Module,
example_inputs: Tuple[torch.Tensor],
):
evaluator = get_evaluator(model_name)

# Get the path of the TOSA flatbuffer that is dumped
intermediates_path = Path(intermediates)
tosa_paths = list(intermediates_path.glob("*.tosa"))

init_evaluator = evaluator(
model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
)

quant_metrics = init_evaluator.evaluate()
output_json_path = intermediates_path / "quant_metrics.json"

with output_json_path.open("w") as json_file:
json.dump(quant_metrics, json_file)


def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None):
graph_module = edge.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
@@ -242,6 +278,14 @@ def get_args():
choices=targets,
help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}",
)
parser.add_argument(
"-e",
"--evaluate",
action="store_true",
required=False,
default=False,
help="Flag for running evaluation of the model.",
)
parser.add_argument(
"-q",
"--quantize",
@@ -275,11 +319,11 @@ def get_args():
help="Location for outputs, if not the default of cwd.",
)
args = parser.parse_args()
return args


if __name__ == "__main__":
args = get_args()
if args.evaluate and (args.quantize is None or args.intermediates is None):
raise RuntimeError(
"--evaluate requires --quantize and --intermediates to be enabled."
)

if args.debug:
logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
@@ -302,16 +346,26 @@ def get_args():
):
raise RuntimeError(f"Model {args.model_name} cannot be delegated.")

return args


if __name__ == "__main__":
args = get_args()

# Pick model from one of the supported lists
model, example_inputs = get_model_and_inputs_from_name(args.model_name)
model = model.eval()

model_fp32 = model

# Pre-autograd export. Eventually this will become torch.export.
model = torch.export.export_for_training(model, example_inputs).module()

# Quantize if required
model_int8 = None
if args.quantize:
model = quantize(model, example_inputs)
model_int8 = model

edge = export_to_edge(
model,
@@ -361,3 +415,8 @@ def get_args():
output_name = os.path.join(args.output, output_name)

save_pte_program(exec_prog, output_name)

if args.evaluate:
evaluate_model(
args.model_name, args.intermediates, model_fp32, model_int8, example_inputs
)
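
Based on GenericModelEvaluator.evaluate() above, the quant_metrics.json written into the intermediates directory has the following shape, shown here as the Python dict that gets passed to json.dump (the model name and numbers are placeholders, not measured results):

{
    "name": "softmax",
    "metrics": {
        "max_error": 0.0,
        "max_absolute_error": 0.0,
        "max_percentage_error": 0.0,
        "mean_absolute_error": 0.0,
        # only present when a TOSA flatbuffer was dumped to --intermediates
        "compression_ratio": 1.0,
    },
}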