Skip to content

Add debug feature to deserialize TOSA fb on dump_artifact #2560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions backends/arm/test/misc/test_debug_feats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import tempfile
import unittest

import torch
from executorch.backends.arm.test.test_models import TosaProfile
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Linear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int = 3,
bias: bool = True,
):
super().__init__()
self.inputs = (torch.ones(5, 10, 25, in_features),)
self.fc = torch.nn.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
)

def get_inputs(self):
return self.inputs

def forward(self, x):
return self.fc(x)


class TestDumpPartitionedArtifact(unittest.TestCase):
def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None):
(
ArmTester(
module,
inputs=module.get_inputs(),
profile=TosaProfile.MI,
backend=ArmBackendSelector.TOSA,
)
.export()
.to_edge()
.partition()
.dump_artifact(dump_file)
.dump_artifact()
)

def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None):
(
ArmTester(
module,
inputs=module.get_inputs(),
profile=TosaProfile.BI,
backend=ArmBackendSelector.TOSA,
)
.quantize()
.export()
.to_edge()
.partition()
.dump_artifact(dump_file)
.dump_artifact()
)

def _is_tosa_marker_in_file(self, tmp_file):
for line in open(tmp_file).readlines():
if "'name': 'main'" in line:
return True
return False

def test_MI_artifact(self):
model = Linear(20, 30)
tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_MI.txt")
self._tosa_MI_pipeline(model, dump_file=tmp_file)
assert os.path.exists(tmp_file), f"File {tmp_file} was not created"
if self._is_tosa_marker_in_file(tmp_file):
return # Implicit pass test
self.fail("File does not contain TOSA dump!")

def test_BI_artifact(self):
model = Linear(20, 30)
tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_BI.txt")
self._tosa_BI_pipeline(model, dump_file=tmp_file)
assert os.path.exists(tmp_file), f"File {tmp_file} was not created"
if self._is_tosa_marker_in_file(tmp_file):
return # Implicit pass test
self.fail("File does not contain TOSA dump!")
27 changes: 27 additions & 0 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ class ArmBackendSelector(Enum):
ETHOS_U55 = "ethos-u55"


class Partition(Partition):
def dump_artifact(self, path_to_dump: Optional[str]):
super().dump_artifact(path_to_dump)
from pprint import pformat

to_print = None
for spec in self.graph_module.lowered_module_0.compile_specs:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an API to or other 'nicer' way to access

graph_module.lowered_module_0.compile_specs

and

graph_module.lowered_module_0.processed_bytes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that I am aware of, cc @cccclai

if spec.key == "output_format":
if spec.value == b"tosa":
tosa_fb = self.graph_module.lowered_module_0.processed_bytes
to_print = TosaTestUtils.dbg_tosa_fb_to_json(tosa_fb)
to_print = pformat(to_print, compact=True, indent=1)
to_print = f"\n TOSA deserialized: \n{to_print}"
elif spec.value == b"vela":
vela_cmd_stream = self.graph_module.lowered_module_0.processed_bytes
to_print = str(vela_cmd_stream)
to_print = f"\n Vela command stream: \n{to_print}"
break
assert to_print is not None, "No TOSA nor Vela compile spec found"

if path_to_dump:
with open(path_to_dump, "a") as fp:
fp.write(to_print)
else:
print(to_print)


class ArmTester(Tester):
def __init__(
self,
Expand Down
49 changes: 22 additions & 27 deletions backends/arm/test/tosautil/tosa_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import subprocess
import tempfile

from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -52,49 +52,43 @@ def __init__(
self.intermediate_path
), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"

def dbg_dump_readble_tosa_file(self) -> None:
@staticmethod
def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
"""
This function is used to dump the TOSA buffer to a human readable
format, using flatc.
It requires the following files to be present on disk:
1) output.tosa (in self.intermediate_path, produced by arm_backend.py)
2) ./backends/arm/third-party/serialization_lib/schema/tosa.fbs.

It is used for debugging purposes.

Output from this is a file called output.json, located in
self.intermediate_path.

Todo:
* I'd prefer if this function didn't use files on disk...
* Check if we can move this function to dump_artificat() thingy...
This function is used to dump the TOSA flatbuffer to a human readable
format, using flatc. It is used for debugging purposes.
"""

tosa_input_file = self.intermediate_path + "/output.tosa"
tmp = tempfile.mkdtemp()
tosa_input_file = os.path.join(tmp, "output.tosa")
with open(tosa_input_file, "wb") as f:
f.write(tosa_fb)

tosa_schema_file = (
"./backends/arm/third-party/serialization_lib/schema/tosa.fbs"
)

assert os.path.exists(
tosa_schema_file
), f"tosa_schema_file: {tosa_schema_file} does not exist"
assert os.path.exists(
tosa_input_file
), f"tosa_input_file: {tosa_input_file} does not exist"
assert shutil.which("flatc") is not None

assert shutil.which("flatc") is not None
cmd_flatc = [
"flatc",
"--json",
"--strict-json",
"-o",
self.intermediate_path,
tmp,
"--raw-binary",
"-t",
tosa_schema_file,
"--",
tosa_input_file,
]
self._run_cmd(cmd_flatc)
return
TosaTestUtils._run_cmd(cmd_flatc)
with open(os.path.join(tmp, "output.json"), "r") as f:
json_out = json.load(f)

return json_out

def run_tosa_ref_model(
self,
Expand Down Expand Up @@ -191,7 +185,7 @@ def run_tosa_ref_model(
shutil.which(self.tosa_ref_model_path) is not None
), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
cmd_ref_model = [self.tosa_ref_model_path, "--test_desc", desc_file_path]
self._run_cmd(cmd_ref_model)
TosaTestUtils._run_cmd(cmd_ref_model)

# Load desc.json, just to get the name of the output file above
with open(desc_file_path) as f:
Expand All @@ -214,7 +208,8 @@ def run_tosa_ref_model(

return tosa_ref_output

def _run_cmd(self, cmd: List[str]) -> None:
@staticmethod
def _run_cmd(cmd: List[str]) -> None:
"""
Run a command and check for errors.

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/third-party/serialization_lib
Submodule serialization_lib updated from bd8c52 to 187af0