Skip to content

Commit 4bbe994

Browse files
authored
Run tosa_reference_model using python binding (#6658)
This change makes it uneccessary to dump intermediates by default for running the reference_model
1 parent 713d8a1 commit 4bbe994

File tree

8 files changed

+96
-59
lines changed

8 files changed

+96
-59
lines changed

backends/arm/arm_backend.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import logging
1515
import os
16-
from typing import final, List, Optional
16+
from typing import cast, final, List, Optional
1717

1818
import serializer.tosa_serializer as ts
1919
from executorch.backends.arm.arm_vela import vela_compile
@@ -31,6 +31,7 @@
3131
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3232
from executorch.exir.backend.compile_spec_schema import CompileSpec
3333
from torch.export.exported_program import ExportedProgram
34+
from torch.fx import Node
3435

3536
# TOSA backend debug functionality
3637
logger = logging.getLogger(__name__)
@@ -225,6 +226,7 @@ def preprocess( # noqa: C901
225226
node_visitors = get_node_visitors(edge_program)
226227

227228
for node in graph_module.graph.nodes:
229+
node = cast(Node, node)
228230
if node.op == "call_function":
229231
process_call_function(node, tosa_graph, node_visitors)
230232
elif node.op == "placeholder":
@@ -236,9 +238,6 @@ def preprocess( # noqa: C901
236238
# any checking of compatibility.
237239
dbg_fail(node, tosa_graph, artifact_path)
238240

239-
# TODO: It would be awesome if this dump could somehow be done on top level and not here.
240-
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
241-
# access from top level.
242241
if artifact_path:
243242
tag = _get_first_delegation_tag(graph_module)
244243
dbg_tosa_dump(
@@ -259,6 +258,4 @@ def preprocess( # noqa: C901
259258
else:
260259
raise RuntimeError(f"Unknown format {output_format}")
261260

262-
# Continueing from above. Can I put tosa_graph into this function?
263-
# debug_handle_map = ...
264261
return PreprocessResult(processed_bytes=binary)

backends/arm/test/common.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,19 +192,15 @@ def get_tosa_compile_spec_unbuilt(
192192
the compile spec before calling .build() to finalize it.
193193
"""
194194
if not custom_path:
195-
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
196-
prefix="arm_tosa_"
197-
)
198-
else:
199-
intermediate_path = custom_path
195+
custom_path = maybe_get_tosa_collate_path()
200196

201-
if not os.path.exists(intermediate_path):
202-
os.makedirs(intermediate_path, exist_ok=True)
197+
if custom_path is not None and not os.path.exists(custom_path):
198+
os.makedirs(custom_path, exist_ok=True)
203199
compile_spec_builder = (
204200
ArmCompileSpecBuilder()
205201
.tosa_compile_spec()
206202
.set_permute_memory_format(permute_memory_to_nhwc)
207-
.dump_intermediate_artifacts_to(intermediate_path)
203+
.dump_intermediate_artifacts_to(custom_path)
208204
)
209205

210206
return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ def test_numerical_diff_prints(self):
107107
ArmTester(
108108
model,
109109
example_inputs=model.get_inputs(),
110-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
110+
compile_spec=common.get_tosa_compile_spec(
111+
permute_memory_to_nhwc=True,
112+
custom_path=tempfile.mkdtemp("diff_print_test"),
113+
),
111114
)
112115
.export()
113116
.to_edge()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
121121
def test_cat_4d_tosa_MI(self):
122122
square = torch.ones((2, 2, 2, 2))
123123
for dim in range(-3, 3):
124-
test_data = ((square, square), dim)
124+
test_data = ((square, square.clone()), dim)
125125
self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)
126126

127127
@parameterized.expand(Cat.test_parameters)

backends/arm/test/ops/test_select.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ def _test_select_tosa_BI_pipeline(
9393
.check(["torch.ops.quantized_decomposed"])
9494
.to_edge()
9595
.partition()
96-
.dump_artifact()
97-
.dump_operator_distribution()
9896
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9997
.to_executorch()
10098
.run_method_and_compare_outputs(inputs=test_data)
@@ -162,12 +160,14 @@ def test_select_int_tosa_MI(self, test_data: test_data_t):
162160
)
163161

164162
@parameterized.expand(test_data_suite)
163+
@unittest.skip
165164
def test_select_copy_tosa_BI(self, test_data: test_data_t):
166165
self._test_select_tosa_BI_pipeline(
167166
self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
168167
)
169168

170169
@parameterized.expand(test_data_suite)
170+
@unittest.skip
171171
def test_select_int_tosa_BI(self, test_data: test_data_t):
172172
self._test_select_tosa_BI_pipeline(
173173
self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"

backends/arm/test/runner_utils.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
import numpy as np
1818
import torch
1919

20+
import tosa_reference_model
21+
2022
from torch.export import ExportedProgram
2123
from torch.fx.node import Node
24+
from tosa import TosaGraph
2225

2326
logger = logging.getLogger(__name__)
24-
logger.setLevel(logging.WARNING)
27+
logger.setLevel(logging.CRITICAL)
2528

2629

2730
class QuantizationParams:
@@ -167,7 +170,7 @@ def __init__(
167170
):
168171
self.intermediate_path = intermediate_path
169172
self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
170-
assert os.path.exists(
173+
assert self.intermediate_path is None or os.path.exists(
171174
self.intermediate_path
172175
), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
173176

@@ -323,7 +326,46 @@ def run_corstone(
323326
tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
324327
output_shape = self.output_node.args[0][0].meta["val"].shape
325328
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
326-
return [tosa_ref_output]
329+
return tosa_ref_output
330+
331+
def run_tosa_graph(
332+
self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
333+
) -> torch.Tensor:
334+
"""Runs the TOSA reference model with inputs and returns the result."""
335+
data_np = [
336+
prep_data_for_save(
337+
input, self.is_quantized, self.input_names[i], self.qp_input[i]
338+
)
339+
for i, input in enumerate(inputs)
340+
]
341+
# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
342+
tosa_profile = 0 if self.is_quantized else 1
343+
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
344+
outputs, status = tosa_reference_model.run(
345+
graph,
346+
data_np,
347+
verbosity=_tosa_refmodel_loglevel(logger.level),
348+
tosa_profile=tosa_profile,
349+
initialize_variable_tensor_from_numpy=1, # True
350+
debug_mode=debug_mode,
351+
)
352+
353+
assert (
354+
status == tosa_reference_model.GraphStatus.TOSA_VALID
355+
), "Non-valid TOSA given to reference model."
356+
357+
outputs_torch = []
358+
for output in outputs:
359+
output = output.astype(np.float32)
360+
if self.is_quantized:
361+
# Need to dequant back to FP32 for comparison with torch output
362+
quant_param = self.qp_output
363+
assert (
364+
quant_param is not None
365+
), "There are no quantization parameters, check output parameters"
366+
output = (output - quant_param.zp) * quant_param.scale
367+
outputs_torch.append(torch.from_numpy(output))
368+
return tuple(outputs_torch)
327369

328370
def run_tosa_ref_model(
329371
self,
@@ -408,21 +450,13 @@ def run_tosa_ref_model(
408450
assert (
409451
shutil.which(self.tosa_ref_model_path) is not None
410452
), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
411-
loglevel_map = {
412-
logging.INFO: "INFO",
413-
logging.CRITICAL: "LOW",
414-
logging.ERROR: "LOW",
415-
logging.WARNING: "MED",
416-
logging.DEBUG: "HIGH",
417-
logging.NOTSET: "MED",
418-
}
419-
clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
453+
420454
cmd_ref_model = [
421455
self.tosa_ref_model_path,
422456
"--test_desc",
423457
desc_file_path,
424458
"-l",
425-
loglevel_map[clamped_logging_level],
459+
_tosa_refmodel_loglevel(logger.level),
426460
]
427461
_run_cmd(cmd_ref_model)
428462

@@ -458,7 +492,10 @@ def run_tosa_ref_model(
458492

459493

460494
def prep_data_for_save(
461-
data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
495+
data: torch.Tensor,
496+
is_quantized: bool,
497+
input_name: str,
498+
quant_param: QuantizationParams,
462499
):
463500
data_np = np.array(data.detach(), order="C").astype(
464501
f"{data.dtype}".replace("torch.", "")
@@ -602,3 +639,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
602639
pass
603640

604641
return json_out
642+
643+
644+
def _tosa_refmodel_loglevel(loglevel: int) -> str:
645+
"""Converts a logging loglevel to tosa_reference_model logginglevel,
646+
returned as string.
647+
"""
648+
loglevel_map = {
649+
logging.INFO: "INFO",
650+
logging.CRITICAL: "LOW",
651+
logging.ERROR: "LOW",
652+
logging.WARNING: "MED",
653+
logging.DEBUG: "HIGH",
654+
logging.NOTSET: "MED",
655+
}
656+
clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
657+
return loglevel_map[clamped_logging_level]

backends/arm/test/tester/arm_tester.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from executorch.backends.xnnpack.test.tester import Tester
4141
from executorch.devtools.backend_debug import get_delegation_info
42-
from executorch.exir import EdgeCompileConfig
42+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
4343
from executorch.exir.backend.compile_spec_schema import CompileSpec
4444

4545
from executorch.exir.lowered_backend_module import LoweredBackendModule
@@ -120,10 +120,15 @@ def __init__(
120120
super().__init__(dynamic_shapes)
121121
self.tosa_test_util = tosa_test_util
122122

123+
def run(self, artifact: EdgeProgramManager, inputs=None):
124+
self.executorch_program = artifact.to_executorch(self.config)
125+
if module := getattr(
126+
artifact.exported_program().graph_module, "lowered_module_0", None
127+
):
128+
self.buffer = module.processed_bytes
129+
123130
def run_artifact(self, inputs):
124-
tosa_output = self.tosa_test_util.run_tosa_ref_model(
125-
inputs=inputs,
126-
)
131+
tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs)
127132
return tosa_output
128133

129134

@@ -316,7 +321,7 @@ def run_method_and_compare_outputs(
316321
logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
317322

318323
reference_output = reference_stage.run_artifact(reference_input)
319-
test_output = tuple(test_stage.run_artifact(test_input))
324+
test_output = test_stage.run_artifact(test_input)
320325
if (
321326
is_nhwc
322327
and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]

examples/arm/setup.sh

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ ethos_u_base_rev="24.08"
8888

8989
# tosa reference model
9090
tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
91-
tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5"
91+
tosa_reference_model_rev="ef31e7222e99cb1c24b2aff9fc52b2d609612283"
9292

9393
########
9494
### Mandatory user args
@@ -227,30 +227,13 @@ function setup_tosa_reference_model() {
227227
cd reference_model
228228
git checkout ${tosa_reference_model_rev}
229229
git submodule update --init --recursive
230-
cd ..
231-
fi
232-
cd reference_model
233-
mkdir -p build
234-
cd build
235-
cmake ..
236-
237-
# make use of half the cores for building
238-
if [[ "${OS}" == "Linux" ]]; then
239-
n=$(( $(nproc) / 2 ))
240-
elif [[ "${OS}" == "Darwin" ]]; then
241-
n=$(( $(sysctl -n hw.logicalcpu) / 2 ))
242-
else
243-
n=1
244230
fi
245231

246-
if [[ "$n" -lt 1 ]]; then
247-
n=1
248-
fi
232+
echo "pip installing reference_model..."
233+
repo_dir="${root_dir}/reference_model"
234+
cd $repo_dir
235+
pip install .
249236

250-
make -j"${n}"
251-
cd reference_model
252-
tosa_bin_path=`pwd`
253-
echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}"
254237
}
255238

256239
function setup_vela() {

0 commit comments

Comments
 (0)