Skip to content

Commit 7aa39b4

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
Rewrite etdump debug data value comparision (#4152)
Summary: Pull Request resolved: #4152 The original comparison actually compares the metadata of 2 tensors, not the tensors themselves, and would fail when the 2 tensors are written at different locations in the buffer (because of different [offsets](https://www.internalfb.com/code/fbsource/[02da17b6e421d91ada2fd690e9f9ecfdb4bedfc1]/fbcode/executorch/sdk/etdump/schema_flatcc.py?lines=26)), even if their values are the same. Therefore, change to the new compaision logic which compares the actual values. Differential Revision: D59350018
1 parent 46b10a7 commit 7aa39b4

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

sdk/inspector/_inspector.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
gen_graphs_from_etrecord,
4747
inflate_runtime_output,
4848
is_debug_output,
49+
is_inference_output_equal,
4950
ProgramOutput,
5051
RESERVED_FRAMEWORK_EVENT_NAMES,
5152
TIME_SCALE_DICT,
@@ -571,8 +572,10 @@ def _populate_debugging_related_fields(
571572
debug_data = [debug_event.debug_entry for debug_event in debug_events]
572573
else:
573574
for debug_event, value in zip(debug_events, debug_data):
574-
assert (
575-
debug_event.debug_entry == value
575+
v1 = inflate_runtime_output(debug_event.debug_entry, output_buffer)
576+
v2 = inflate_runtime_output(value, output_buffer)
577+
assert is_inference_output_equal(
578+
v1, v2
576579
), """Corresponding debug events in multiple iterations of the model
577580
must have the same debug entry values. This is not the case for the
578581
intermediate data present in this ETDump and indicates potential issues

sdk/inspector/_inspector_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ class TimeScale(Enum):
7070
ProgramOutput: TypeAlias = List[InferenceOutput]
7171

7272

73+
# Compare whether two InferenceOutputs are equal
74+
def is_inference_output_equal(
75+
output1: InferenceOutput, output2: InferenceOutput
76+
) -> bool:
77+
if isinstance(output1, torch.Tensor) and isinstance(output2, torch.Tensor):
78+
return torch.equal(output1, output2)
79+
elif isinstance(output1, List) and isinstance(output2, List):
80+
return all(torch.equal(t1, t2) for t1, t2 in zip(output1, output2))
81+
elif output1 == output2:
82+
return True
83+
else:
84+
return False
85+
86+
7387
# Given a ETDump Tensor object and offset, extract into a torch.Tensor
7488
def _parse_tensor_value(
7589
tensor: Optional[Tensor], output_buffer: Optional[bytes]

sdk/inspector/tests/inspector_utils_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import unittest
99
from typing import Dict, Tuple
1010

11+
import torch
12+
1113
from executorch.sdk import generate_etrecord, parse_etrecord
1214

1315
from executorch.sdk.debug_format.base_schema import (
@@ -25,6 +27,7 @@
2527
EDGE_DIALECT_GRAPH_KEY,
2628
find_populated_event,
2729
gen_graphs_from_etrecord,
30+
is_inference_output_equal,
2831
)
2932

3033

@@ -126,6 +129,30 @@ def test_find_populated_event(self):
126129
)
127130
self.assertEqual(find_populated_event(event), profile_event)
128131

132+
def test_is_inference_output_equal(self):
133+
# Compare tensors. Not equal because of different values
134+
self.assertFalse(
135+
is_inference_output_equal(
136+
torch.tensor([[2, 1], [4, 3]]),
137+
torch.tensor([[5, 6], [7, 8]]),
138+
)
139+
)
140+
141+
# Compare tensor lists
142+
tensor_list_1 = (
143+
[
144+
torch.tensor([[1, 2], [3, 4]]),
145+
torch.tensor([[1, 2], [3, 4]]),
146+
torch.tensor([[1, 2], [3, 4]]),
147+
],
148+
)
149+
tensor_list_2 = [
150+
torch.tensor([[1, 2], [3, 4]]),
151+
torch.tensor([[1, 2], [3, 4]]),
152+
]
153+
# Not equal because of different number of tensors
154+
self.assertFalse(is_inference_output_equal(tensor_list_1, tensor_list_2))
155+
129156

130157
def gen_mock_operator_graph_with_expected_map() -> (
131158
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)