Skip to content

Commit c337bef

Browse files
Add util function for pretty printing of output diffs (#7302)
This is to make run_method_and_compare_outputs less complex since lintrunner was complaining. Additionally moves out previous info dumps in compare_output into a new callback function to handle all error handling in the same way.
1 parent c1e137b commit c337bef

File tree

2 files changed

+292
-37
lines changed

2 files changed

+292
-37
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
import tempfile
8+
9+
import torch
10+
from executorch.backends.arm.test.runner_utils import (
11+
_get_input_quantization_params,
12+
_get_output_node,
13+
_get_output_quantization_params,
14+
)
15+
16+
from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
22+
23+
output_str = ""
24+
for c in range(C):
25+
if channels_close[c]:
26+
continue
27+
28+
max_diff = torch.max(torch.abs(reference - result))
29+
exp = f"{max_diff:2e}"[-3:]
30+
output_str += f"channel {c} (e{exp})\n"
31+
32+
for y in range(H):
33+
res = "["
34+
for x in range(W):
35+
if torch.allclose(reference[c, y, x], result[c, y, x], rtol, atol):
36+
res += " . "
37+
else:
38+
diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp))
39+
res += f"{diff: .2f} "
40+
41+
# Break early for large widths
42+
if x == 16:
43+
res += "..."
44+
break
45+
46+
res += "]\n"
47+
output_str += res
48+
49+
return output_str
50+
51+
52+
def _print_elements(result, reference, C, H, W, rtol, atol):
53+
output_str = ""
54+
for y in range(H):
55+
res = "["
56+
for x in range(W):
57+
result_channels = result[:, y, x]
58+
reference_channels = reference[:, y, x]
59+
60+
n_errors = 0
61+
for a, b in zip(result_channels, reference_channels):
62+
if not torch.allclose(a, b, rtol, atol):
63+
n_errors = n_errors + 1
64+
65+
if n_errors == 0:
66+
res += ". "
67+
else:
68+
res += f"{n_errors} "
69+
70+
# Break early for large widths
71+
if x == 16:
72+
res += "..."
73+
break
74+
75+
res += "]\n"
76+
output_str += res
77+
78+
return output_str
79+
80+
81+
def print_error_diffs(
82+
tester,
83+
result: torch.Tensor | tuple,
84+
reference: torch.Tensor | tuple,
85+
quantization_scale=None,
86+
atol=1e-03,
87+
rtol=1e-03,
88+
qtol=0,
89+
):
90+
"""
91+
Prints the error difference between a result tensor and a reference tensor in NCHW format.
92+
Certain formatting rules are applied to clarify errors:
93+
94+
- Batches are only expanded if they contain errors.
95+
-> Shows if errors are related to batch handling
96+
- If errors appear in all channels, only the number of errors in each HW element are printed.
97+
-> Shows if errors are related to HW handling
98+
- If at least one channel is free from errors, or if C==1, errors are printed channel by channel
99+
-> Shows if errors are related to channel handling or single errors such as rounding/quantization errors
100+
101+
Example output of shape (3,3,2,2):
102+
103+
############################ ERROR DIFFERENCE #############################
104+
BATCH 0
105+
.
106+
BATCH 1
107+
[. . ]
108+
[. 3 ]
109+
BATCH 2
110+
channel 1 (e-03)
111+
[ 1.85 . ]
112+
[ . 9.32 ]
113+
114+
MEAN MEDIAN MAX MIN (error as % of reference output range)
115+
60.02% 55.73% 100.17% 19.91%
116+
###########################################################################
117+
118+
119+
"""
120+
121+
if isinstance(reference, tuple):
122+
reference = reference[0]
123+
if isinstance(result, tuple):
124+
result = result[0]
125+
126+
if not result.shape == reference.shape:
127+
raise ValueError("Output needs to be of same shape")
128+
shape = result.shape
129+
130+
match len(shape):
131+
case 4:
132+
N, C, H, W = (shape[0], shape[1], shape[2], shape[3])
133+
case 3:
134+
N, C, H, W = (1, shape[0], shape[1], shape[2])
135+
case 2:
136+
N, C, H, W = (1, 1, shape[0], shape[1])
137+
case 1:
138+
N, C, H, W = (1, 1, 1, shape[0])
139+
case _:
140+
raise ValueError("Invalid tensor rank")
141+
142+
if quantization_scale is not None:
143+
atol += quantization_scale * qtol
144+
145+
# Reshape tensors to 4D NCHW format
146+
result = torch.reshape(result, (N, C, H, W))
147+
reference = torch.reshape(reference, (N, C, H, W))
148+
149+
output_str = ""
150+
for n in range(N):
151+
output_str += f"BATCH {n}\n"
152+
result_batch = result[n, :, :, :]
153+
reference_batch = reference[n, :, :, :]
154+
is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
155+
if is_close:
156+
output_str += ".\n"
157+
else:
158+
channels_close = [None] * C
159+
for c in range(C):
160+
result_hw = result[n, c, :, :]
161+
reference_hw = reference[n, c, :, :]
162+
163+
channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol)
164+
165+
if any(channels_close) or len(channels_close) == 1:
166+
output_str += _print_channels(
167+
result[n, :, :, :],
168+
reference[n, :, :, :],
169+
channels_close,
170+
C,
171+
H,
172+
W,
173+
rtol,
174+
atol,
175+
)
176+
else:
177+
output_str += _print_elements(
178+
result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
179+
)
180+
181+
reference_range = torch.max(reference) - torch.min(reference)
182+
diff = torch.abs(reference - result).flatten()
183+
diff = diff[diff.nonzero()]
184+
if not len(diff) == 0:
185+
diff_percent = diff / reference_range
186+
output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n"
187+
output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n"
188+
189+
# Over-engineer separators to match output width
190+
lines = output_str.split("\n")
191+
line_length = [len(line) for line in lines]
192+
longest_line = max(line_length)
193+
title = "# ERROR DIFFERENCE #"
194+
separator_length = max(longest_line, len(title))
195+
196+
pre_title_length = max(0, ((separator_length - len(title)) // 2))
197+
post_title_length = max(0, ((separator_length - len(title) + 1) // 2))
198+
start_separator = (
199+
"\n" + "#" * pre_title_length + title + "#" * post_title_length + "\n"
200+
)
201+
output_str = start_separator + output_str
202+
end_separator = "#" * separator_length + "\n"
203+
output_str += end_separator
204+
205+
logger.error(output_str)
206+
207+
208+
def dump_error_output(
209+
tester,
210+
reference_output,
211+
stage_output,
212+
quantization_scale=None,
213+
atol=1e-03,
214+
rtol=1e-03,
215+
qtol=0,
216+
):
217+
"""
218+
Prints Quantization info and error tolerances, and saves the differing tensors to disc.
219+
"""
220+
# Capture assertion error and print more info
221+
banner = "=" * 40 + "TOSA debug info" + "=" * 40
222+
logger.error(banner)
223+
path_to_tosa_files = tester.runner_util.intermediate_path
224+
225+
if path_to_tosa_files is None:
226+
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
227+
228+
export_stage = tester.stages.get(tester.stage_name(Export), None)
229+
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
230+
if export_stage is not None and quantize_stage is not None:
231+
output_node = _get_output_node(export_stage.artifact)
232+
qp_input = _get_input_quantization_params(export_stage.artifact)
233+
qp_output = _get_output_quantization_params(export_stage.artifact, output_node)
234+
logger.error(f"Input QuantArgs: {qp_input}")
235+
logger.error(f"Output QuantArgs: {qp_output}")
236+
237+
logger.error(f"{path_to_tosa_files=}")
238+
import os
239+
240+
torch.save(
241+
stage_output,
242+
os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
243+
)
244+
torch.save(
245+
reference_output,
246+
os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
247+
)
248+
logger.error(f"{atol=}, {rtol=}, {qtol=}")
249+
250+
251+
if __name__ == "__main__":
252+
import sys
253+
254+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
255+
256+
""" This is expected to produce the example output of print_diff"""
257+
torch.manual_seed(0)
258+
a = torch.rand(3, 3, 2, 2) * 0.01
259+
b = a.clone().detach()
260+
logger.info(b)
261+
262+
# Errors in all channels in element (1,1)
263+
a[1, :, 1, 1] = 0
264+
# Errors in (0,0) and (1,1) in channel 1
265+
a[2, 1, 1, 1] = 0
266+
a[2, 1, 0, 0] = 0
267+
268+
print_error_diffs(a, b)

backends/arm/test/tester/arm_tester.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
import tempfile
87

98
from collections import Counter
109
from pprint import pformat
@@ -25,12 +24,10 @@
2524
)
2625
from executorch.backends.arm.test.common import get_target_board
2726

28-
from executorch.backends.arm.test.runner_utils import (
29-
_get_input_quantization_params,
30-
_get_output_node,
31-
_get_output_quantization_params,
32-
dbg_tosa_fb_to_json,
33-
RunnerUtil,
27+
from executorch.backends.arm.test.runner_utils import dbg_tosa_fb_to_json, RunnerUtil
28+
from executorch.backends.arm.test.tester.analyze_output_utils import (
29+
dump_error_output,
30+
print_error_diffs,
3431
)
3532
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
3633

@@ -278,6 +275,7 @@ def run_method_and_compare_outputs(
278275
atol=1e-03,
279276
rtol=1e-03,
280277
qtol=0,
278+
error_callbacks=None,
281279
):
282280
"""
283281
Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -366,7 +364,13 @@ def run_method_and_compare_outputs(
366364
test_output = self.transpose_data_format(test_output, "NCHW")
367365

368366
self._compare_outputs(
369-
reference_output, test_output, quantization_scale, atol, rtol, qtol
367+
reference_output,
368+
test_output,
369+
quantization_scale,
370+
atol,
371+
rtol,
372+
qtol,
373+
error_callbacks,
370374
)
371375

372376
return self
@@ -515,42 +519,25 @@ def _compare_outputs(
515519
atol=1e-03,
516520
rtol=1e-03,
517521
qtol=0,
522+
error_callbacks=None,
518523
):
519524
try:
520525
super()._compare_outputs(
521526
reference_output, stage_output, quantization_scale, atol, rtol, qtol
522527
)
523528
except AssertionError as e:
524-
# Capture assertion error and print more info
525-
banner = "=" * 40 + "TOSA debug info" + "=" * 40
526-
logger.error(banner)
527-
path_to_tosa_files = self.runner_util.intermediate_path
528-
if path_to_tosa_files is None:
529-
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
530-
531-
export_stage = self.stages.get(self.stage_name(tester.Export), None)
532-
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
533-
if export_stage is not None and quantize_stage is not None:
534-
output_node = _get_output_node(export_stage.artifact)
535-
qp_input = _get_input_quantization_params(export_stage.artifact)
536-
qp_output = _get_output_quantization_params(
537-
export_stage.artifact, output_node
529+
if error_callbacks is None:
530+
error_callbacks = [print_error_diffs, dump_error_output]
531+
for callback in error_callbacks:
532+
callback(
533+
self,
534+
reference_output,
535+
stage_output,
536+
quantization_scale=None,
537+
atol=1e-03,
538+
rtol=1e-03,
539+
qtol=0,
538540
)
539-
logger.error(f"Input QuantArgs: {qp_input}")
540-
logger.error(f"Output QuantArgs: {qp_output}")
541-
542-
logger.error(f"{path_to_tosa_files=}")
543-
import os
544-
545-
torch.save(
546-
stage_output,
547-
os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
548-
)
549-
torch.save(
550-
reference_output,
551-
os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
552-
)
553-
logger.error(f"{atol=}, {rtol=}, {qtol=}")
554541
raise e
555542

556543

0 commit comments

Comments
 (0)