Skip to content

Commit 0b1c1e5

Browse files
authored
Fix errors in comparison functions when dtype is uint8
Differential Revision: D67163600 Pull Request resolved: #7313
1 parent 6941d46 commit 0b1c1e5

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str):
372372

373373
def calculate_mse(ref_values: ProgramOutput, values: ProgramOutput):
374374
def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
375-
return round((torch.pow((a - b).to(torch.float32), 2)).mean().item(), 2)
375+
return round((torch.pow((a - b), 2)).mean().item(), 2)
376376

377377
results = []
378378
for ref_value, value in zip(ref_values, values):
379379
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
380380
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
381-
results.append(mean_squared_error(ref_value, value))
381+
results.append(
382+
mean_squared_error(ref_value.to(torch.float32), value.to(torch.float32))
383+
)
382384
else:
383385
results.append(None)
384386

@@ -387,8 +389,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
387389

388390
def calculate_snr(ref_values: ProgramOutput, values: ProgramOutput):
389391
def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
390-
signal = signal.type(torch.float32)
391-
noise = noise.type(torch.float32)
392392
signal_power = torch.mean(torch.pow(signal, 2))
393393
noise_power = torch.mean(torch.pow(noise, 2))
394394
snr = 10 * torch.log10(signal_power / noise_power)
@@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
398398
for ref_value, value in zip(ref_values, values):
399399
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
400400
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
401-
diff = ref_value - value
402-
snr = signal_to_noise(ref_value, diff)
401+
ref_value_fp = ref_value.to(torch.float32)
402+
value_fp = value.to(torch.float32)
403+
diff = ref_value_fp - value_fp
404+
snr = signal_to_noise(ref_value_fp, diff)
403405
results.append(snr)
404406
else:
405407
results.append(None)
@@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
429431
for ref_value, value in zip(ref_values, values):
430432
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
431433
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
432-
results.append(cosine_similarity(ref_value, value))
434+
results.append(
435+
cosine_similarity(ref_value.to(torch.float32), value.to(torch.float32))
436+
)
433437
else:
434438
results.append(None)
435439

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525

2626
from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
2727
from executorch.devtools.inspector._inspector_utils import (
28+
calculate_cosine_similarity,
29+
calculate_mse,
30+
calculate_snr,
2831
calculate_time_scale_factor,
2932
create_debug_handle_to_op_node_mapping,
3033
EDGE_DIALECT_GRAPH_KEY,
@@ -188,6 +191,32 @@ def test_calculate_time_scale_factor_cycles(self):
188191
calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1
189192
)
190193

194+
def test_compare_results(self):
195+
a = torch.rand(4, 4)
196+
197+
# Create tensor b which has very close value to tensor a
198+
b = a.clone()
199+
b[0, 0] += 1e-2
200+
b[1, 0] += 1e-2
201+
b[1, 3] -= 1e-2
202+
203+
self.assertLess(calculate_mse([a], [b])[0], 0.5)
204+
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
205+
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
206+
207+
def test_compare_results_uint8(self):
208+
a = torch.randint(0, 255, (4, 4), dtype=torch.uint8)
209+
210+
# Create tensor b which has very close value to tensor a
211+
b = a.clone()
212+
b[0, 0] += 1
213+
b[1, 0] += 1
214+
b[1, 3] -= 1
215+
216+
self.assertLess(calculate_mse([a], [b])[0], 0.5)
217+
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
218+
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
219+
191220

192221
def gen_mock_operator_graph_with_expected_map() -> (
193222
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)