Skip to content

Commit bde6b53

Browse files
committed
Add debug feature to deserialize TOSA fb on dump_artifact()
The deserialized, human readable, TOSA fb is appended to the GraphModule print that is output with dump_artifact() on the Partition stage. Also, bump the TOSA serialization lib SHA, to avoid verbose warning messages. Change-Id: Ibb3120993d75293824f2ccb5a1b3981db64a2354 Signed-off-by: Fredrik Knutsson <[email protected]>
1 parent a41ac1c commit bde6b53

File tree

4 files changed

+146
-28
lines changed

4 files changed

+146
-28
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
import os
9+
import tempfile
10+
import unittest
11+
12+
import torch
13+
from executorch.backends.arm.test.test_models import TosaProfile
14+
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester
15+
16+
logger = logging.getLogger(__name__)
17+
logger.setLevel(logging.INFO)
18+
19+
20+
class Linear(torch.nn.Module):
21+
def __init__(
22+
self,
23+
in_features: int,
24+
out_features: int = 3,
25+
bias: bool = True,
26+
):
27+
super().__init__()
28+
self.inputs = (torch.ones(5, 10, 25, in_features),)
29+
self.fc = torch.nn.Linear(
30+
in_features=in_features,
31+
out_features=out_features,
32+
bias=bias,
33+
)
34+
35+
def get_inputs(self):
36+
return self.inputs
37+
38+
def forward(self, x):
39+
return self.fc(x)
40+
41+
42+
class TestDumpPartitionedArtifact(unittest.TestCase):
43+
def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None):
44+
(
45+
ArmTester(
46+
module,
47+
inputs=module.get_inputs(),
48+
profile=TosaProfile.MI,
49+
backend=ArmBackendSelector.TOSA,
50+
)
51+
.export()
52+
.to_edge()
53+
.partition()
54+
.dump_artifact(dump_file)
55+
.dump_artifact()
56+
)
57+
58+
def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None):
59+
(
60+
ArmTester(
61+
module,
62+
inputs=module.get_inputs(),
63+
profile=TosaProfile.BI,
64+
backend=ArmBackendSelector.TOSA,
65+
)
66+
.quantize()
67+
.export()
68+
.to_edge()
69+
.partition()
70+
.dump_artifact(dump_file)
71+
.dump_artifact()
72+
)
73+
74+
def _is_tosa_marker_in_file(self, tmp_file):
75+
for line in open(tmp_file).readlines():
76+
if "'name': 'main'" in line:
77+
return True
78+
return False
79+
80+
def test_MI_artifact(self):
81+
model = Linear(20, 30)
82+
tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_MI.txt")
83+
self._tosa_MI_pipeline(model, dump_file=tmp_file)
84+
assert os.path.exists(tmp_file), f"File {tmp_file} was not created"
85+
if self._is_tosa_marker_in_file(tmp_file):
86+
return # Implicit pass test
87+
self.fail("File does not contain TOSA dump!")
88+
89+
def test_BI_artifact(self):
90+
model = Linear(20, 30)
91+
tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_BI.txt")
92+
self._tosa_BI_pipeline(model, dump_file=tmp_file)
93+
assert os.path.exists(tmp_file), f"File {tmp_file} was not created"
94+
if self._is_tosa_marker_in_file(tmp_file):
95+
return # Implicit pass test
96+
self.fail("File does not contain TOSA dump!")

backends/arm/test/tester/arm_tester.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,33 @@ class ArmBackendSelector(Enum):
4141
ETHOS_U55 = "ethos-u55"
4242

4343

44+
class Partition(Partition):
45+
def dump_artifact(self, path_to_dump: Optional[str]):
46+
super().dump_artifact(path_to_dump)
47+
from pprint import pformat
48+
49+
to_print = None
50+
for spec in self.graph_module.lowered_module_0.compile_specs:
51+
if spec.key == "output_format":
52+
if spec.value == b"tosa":
53+
tosa_fb = self.graph_module.lowered_module_0.processed_bytes
54+
to_print = TosaTestUtils.dbg_tosa_fb_to_json(tosa_fb)
55+
to_print = pformat(to_print, compact=True, indent=1)
56+
to_print = f"\n TOSA deserialized: \n{to_print}"
57+
elif spec.value == b"vela":
58+
vela_cmd_stream = self.graph_module.lowered_module_0.processed_bytes
59+
to_print = str(vela_cmd_stream)
60+
to_print = f"\n Vela command stream: \n{to_print}"
61+
break
62+
assert to_print is not None, "No TOSA nor Vela compile spec found"
63+
64+
if path_to_dump:
65+
with open(path_to_dump, "a") as fp:
66+
fp.write(to_print)
67+
else:
68+
print(to_print)
69+
70+
4471
class ArmTester(Tester):
4572
def __init__(
4673
self,

backends/arm/test/tosautil/tosa_test_utils.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import subprocess
1111
import tempfile
1212

13-
from typing import List, Optional, Tuple
13+
from typing import Dict, List, Optional, Tuple
1414

1515
import numpy as np
1616
import torch
@@ -52,49 +52,43 @@ def __init__(
5252
self.intermediate_path
5353
), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
5454

55-
def dbg_dump_readble_tosa_file(self) -> None:
55+
@staticmethod
56+
def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
5657
"""
57-
This function is used to dump the TOSA buffer to a human readable
58-
format, using flatc.
59-
It requires the following files to be present on disk:
60-
1) output.tosa (in self.intermediate_path, produced by arm_backend.py)
61-
2) ./backends/arm/third-party/serialization_lib/schema/tosa.fbs.
62-
63-
It is used for debugging purposes.
64-
65-
Output from this is a file called output.json, located in
66-
self.intermediate_path.
67-
68-
Todo:
69-
* I'd prefer if this function didn't use files on disk...
70-
* Check if we can move this function to dump_artificat() thingy...
58+
This function is used to dump the TOSA flatbuffer to a human readable
59+
format, using flatc. It is used for debugging purposes.
7160
"""
7261

73-
tosa_input_file = self.intermediate_path + "/output.tosa"
62+
tmp = tempfile.mkdtemp()
63+
tosa_input_file = os.path.join(tmp, "output.tosa")
64+
with open(tosa_input_file, "wb") as f:
65+
f.write(tosa_fb)
66+
7467
tosa_schema_file = (
7568
"./backends/arm/third-party/serialization_lib/schema/tosa.fbs"
7669
)
77-
7870
assert os.path.exists(
7971
tosa_schema_file
8072
), f"tosa_schema_file: {tosa_schema_file} does not exist"
81-
assert os.path.exists(
82-
tosa_input_file
83-
), f"tosa_input_file: {tosa_input_file} does not exist"
84-
assert shutil.which("flatc") is not None
8573

74+
assert shutil.which("flatc") is not None
8675
cmd_flatc = [
8776
"flatc",
77+
"--json",
78+
"--strict-json",
8879
"-o",
89-
self.intermediate_path,
80+
tmp,
9081
"--raw-binary",
9182
"-t",
9283
tosa_schema_file,
9384
"--",
9485
tosa_input_file,
9586
]
96-
self._run_cmd(cmd_flatc)
97-
return
87+
TosaTestUtils._run_cmd(cmd_flatc)
88+
with open(os.path.join(tmp, "output.json"), "r") as f:
89+
json_out = json.load(f)
90+
91+
return json_out
9892

9993
def run_tosa_ref_model(
10094
self,
@@ -191,7 +185,7 @@ def run_tosa_ref_model(
191185
shutil.which(self.tosa_ref_model_path) is not None
192186
), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
193187
cmd_ref_model = [self.tosa_ref_model_path, "--test_desc", desc_file_path]
194-
self._run_cmd(cmd_ref_model)
188+
TosaTestUtils._run_cmd(cmd_ref_model)
195189

196190
# Load desc.json, just to get the name of the output file above
197191
with open(desc_file_path) as f:
@@ -214,7 +208,8 @@ def run_tosa_ref_model(
214208

215209
return tosa_ref_output
216210

217-
def _run_cmd(self, cmd: List[str]) -> None:
211+
@staticmethod
212+
def _run_cmd(cmd: List[str]) -> None:
218213
"""
219214
Run a command and check for errors.
220215
Submodule serialization_lib updated from bd8c529 to 187af0d

0 commit comments

Comments
 (0)