Skip to content

Commit e3730f8

Browse files
committed
Fix asserts between PyTorch MPS backend and ExecuTorch MPS delegate (pytorch#16)
* Fix asserts between PyTorch MPS backend and ExecuTorch MPS delegate * Fix lint
1 parent a03dfb5 commit e3730f8

File tree

2 files changed

+60
-29
lines changed

2 files changed

+60
-29
lines changed

examples/apple/mps/scripts/bench_utils.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,47 @@
55

66
import logging
77
import time
8+
import unittest
9+
from typing import Tuple
810

911
import torch
1012
from torch.export.exported_program import ExportedProgram
1113

1214

13-
def assert_outputs_equal(model_output, ref_output):
14-
"""
15-
Helper testing function that asserts that the model output and the reference output
16-
are equal with some tolerance. Due to numerical differences between eager mode and
17-
the MPS's backend, we relax the detal such that absolute tolerance is 1e-3. and
18-
relative tolerance is 1e-3.
19-
"""
20-
21-
# Compare the result from executor and eager mode direclty
22-
if isinstance(ref_output, tuple) or isinstance(ref_output, list):
23-
# Multiple outputs executor always returns tuple, even if there is one output
24-
assert len(ref_output) == len(
25-
model_output
26-
), "Length of outputs is not matching!"
27-
for i in range(len(ref_output)):
28-
assert torch.allclose(
29-
model_output[i], ref_output[i], atol=1e-03, rtol=1e-03
30-
)
31-
else:
32-
# If one output, eager returns tensor while executor tuple of size 1
33-
assert torch.allclose(
34-
model_output[0], ref_output, atol=1e-03, rtol=1e-03
35-
), "Outputs are not matching!"
15+
class TestModule(unittest.TestCase):
16+
def assert_outputs_equal(self, model_output, ref_output, use_fp16: bool = False):
17+
"""
18+
Helper testing function that asserts that the model output and the reference output
19+
are equal with some tolerance. Due to numerical differences between eager mode and
20+
the MPS's backend, we relax the detal such that absolute tolerance is 1e-3. and
21+
relative tolerance is 1e-3.
22+
"""
23+
# Compare the result from executor and eager mode direclty
24+
if isinstance(ref_output, tuple) or isinstance(ref_output, list):
25+
# Multiple outputs executor always returns tuple, even if there is one output
26+
assert len(ref_output) == len(
27+
model_output
28+
), "Length of outputs is not matching!"
29+
for i in range(len(ref_output)):
30+
res_output = model_output[i].cpu()
31+
ref_output = ref_output[i].cpu()
32+
if use_fp16 and ref_output.dtype == torch.float16:
33+
# cast back from fp16 to fp32 (ExecuTorch results are in FP32 by default)
34+
ref_output = ref_output.to(torch.float32)
35+
36+
mean_err = ((res_output - ref_output).abs() / ref_output).mean()
37+
logging.info(f"mean err = {mean_err}")
38+
self.assertLess(mean_err, 0.05)
39+
else:
40+
# If one output, eager returns tensor while executor tuple of size 1
41+
if use_fp16 and ref_output.dtype == torch.float16:
42+
# cast back from fp16 to fp32 (ExecuTorch results are in FP32 by default)
43+
ref_output = ref_output.to(torch.float32)
44+
ref_output = ref_output.cpu()
45+
res_output = model_output[0].cpu()
46+
mean_err = ((res_output - ref_output).abs() / ref_output).mean()
47+
logging.info(f"mean err = {mean_err}")
48+
self.assertLess(mean_err, 0.05)
3649

3750

3851
def bench_forward(func, *args):
@@ -101,17 +114,31 @@ def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name):
101114
)
102115

103116

104-
def compare_outputs(executorch_program: ExportedProgram, model, inputs, model_name):
117+
def compare_outputs(
118+
executorch_program: ExportedProgram,
119+
model: torch.nn.Module,
120+
inputs: Tuple[torch.tensor],
121+
model_name: str,
122+
use_fp16: bool,
123+
):
124+
test_module = TestModule()
105125
inputs_copy = []
126+
if use_fp16:
127+
model = model.to(torch.float16)
128+
model = model
106129
for t in inputs:
107-
inputs_copy.append(t.detach().clone())
130+
tensor = t.detach().clone()
131+
if use_fp16 and tensor.dtype == torch.float32:
132+
tensor = tensor.to(torch.float16)
133+
inputs_copy.append(tensor)
108134
inputs_copy = tuple(inputs_copy)
109135

110-
pytorch_results = model(*inputs)
136+
pytorch_results = model(*inputs_copy)
137+
111138
executorch_model = get_executorch_model(executorch_program)
112139
if executorch_model is not None:
113-
executorch_results = executorch_model.forward(inputs_copy)
114-
assert_outputs_equal(executorch_results, pytorch_results)
140+
executorch_results = executorch_model.forward(inputs)
141+
test_module.assert_outputs_equal(executorch_results, pytorch_results, use_fp16)
115142
logging.info(
116143
f"Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass for {model_name} are matching!"
117144
)

examples/apple/mps/scripts/mps_example.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def get_model_config(args):
155155
model, example_inputs, _ = EagerModelFactory.create_model(**model_config)
156156

157157
model = model.eval()
158+
159+
# Deep copy the model inputs to check against PyTorch forward pass
158160
if args.check_correctness or args.bench_pytorch:
159161
model_copy = copy.deepcopy(model)
160162
inputs_copy = []
@@ -228,4 +230,6 @@ def get_model_config(args):
228230
bench_torch(executorch_program, model_copy, example_inputs, model_name)
229231

230232
if args.check_correctness:
231-
compare_outputs(executorch_program, model_copy, inputs_copy, model_name)
233+
compare_outputs(
234+
executorch_program, model_copy, inputs_copy, model_name, args.use_fp16
235+
)

0 commit comments

Comments
 (0)