 
 import logging
 import time
+import unittest
+from typing import Tuple
 
 import torch
 from torch.export.exported_program import ExportedProgram
 
 
-def assert_outputs_equal(model_output, ref_output):
-    """
-    Helper testing function that asserts that the model output and the reference output
-    are equal with some tolerance. Due to numerical differences between eager mode and
-    the MPS's backend, we relax the detal such that absolute tolerance is 1e-3. and
-    relative tolerance is 1e-3.
-    """
-
-    # Compare the result from executor and eager mode direclty
-    if isinstance(ref_output, tuple) or isinstance(ref_output, list):
-        # Multiple outputs executor always returns tuple, even if there is one output
-        assert len(ref_output) == len(
-            model_output
-        ), "Length of outputs is not matching!"
-        for i in range(len(ref_output)):
-            assert torch.allclose(
-                model_output[i], ref_output[i], atol=1e-03, rtol=1e-03
-            )
-    else:
-        # If one output, eager returns tensor while executor tuple of size 1
-        assert torch.allclose(
-            model_output[0], ref_output, atol=1e-03, rtol=1e-03
-        ), "Outputs are not matching!"
+class TestModule(unittest.TestCase):
+    def assert_outputs_equal(self, model_output, ref_output, use_fp16: bool = False):
+        """
+        Helper testing function that asserts that the model output and the
+        reference output are equal up to some tolerance. Due to numerical
+        differences between eager mode and the MPS backend, the check passes
+        when the mean relative error is below 5%.
+        """
+        # Compare the results from the executor and eager mode directly
+        if isinstance(ref_output, (tuple, list)):
+            # Multiple outputs: the executor always returns a tuple, even if
+            # there is only one output
+            assert len(ref_output) == len(
+                model_output
+            ), "Length of outputs is not matching!"
+            for i in range(len(ref_output)):
+                res = model_output[i].cpu()
+                ref = ref_output[i].cpu()
+                if use_fp16 and ref.dtype == torch.float16:
+                    # cast back from fp16 to fp32 (ExecuTorch results are in FP32 by default)
+                    ref = ref.to(torch.float32)
+
+                mean_err = ((res - ref).abs() / ref).mean()
+                logging.info(f"mean err = {mean_err}")
+                self.assertLess(mean_err, 0.05)
+        else:
+            # Single output: eager returns a tensor, while the executor
+            # returns a tuple of size 1
+            if use_fp16 and ref_output.dtype == torch.float16:
+                # cast back from fp16 to fp32 (ExecuTorch results are in FP32 by default)
+                ref_output = ref_output.to(torch.float32)
+            ref_output = ref_output.cpu()
+            res_output = model_output[0].cpu()
+            mean_err = ((res_output - ref_output).abs() / ref_output).mean()
+            logging.info(f"mean err = {mean_err}")
+            self.assertLess(mean_err, 0.05)
 
 
 def bench_forward(func, *args):
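The replacement above swaps the strict `torch.allclose(atol=1e-3, rtol=1e-3)` assertion for a bound on the mean relative error. A minimal, self-contained sketch of that relaxed check (the helper name `mean_relative_error` is ours, not part of the file):

```python
import torch

# Relaxed check as used above: pass when the mean relative error between
# the backend result and the eager reference is below 5%.
def mean_relative_error(res: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    return ((res - ref).abs() / ref).mean()

ref = torch.rand(4, 8) + 0.5  # strictly positive, so the division is well-behaved
res = ref * 1.01              # uniform 1% relative perturbation

assert mean_relative_error(res, ref) < 0.05  # passes under the relaxed bound
```

Note that the check divides by `ref` rather than `ref.abs()`, so the per-element terms are signed when the reference contains negative values and can partially cancel in the mean.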
@@ -101,17 +114,31 @@ def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name):
     )
 
 
-def compare_outputs(executorch_program: ExportedProgram, model, inputs, model_name):
+def compare_outputs(
+    executorch_program: ExportedProgram,
+    model: torch.nn.Module,
+    inputs: Tuple[torch.Tensor, ...],
+    model_name: str,
+    use_fp16: bool,
+):
+    test_module = TestModule()
     inputs_copy = []
+    if use_fp16:
+        model = model.to(torch.float16)
     for t in inputs:
-        inputs_copy.append(t.detach().clone())
+        tensor = t.detach().clone()
+        if use_fp16 and tensor.dtype == torch.float32:
+            tensor = tensor.to(torch.float16)
+        inputs_copy.append(tensor)
     inputs_copy = tuple(inputs_copy)
 
-    pytorch_results = model(*inputs)
+    pytorch_results = model(*inputs_copy)
+
     executorch_model = get_executorch_model(executorch_program)
     if executorch_model is not None:
-        executorch_results = executorch_model.forward(inputs_copy)
-        assert_outputs_equal(executorch_results, pytorch_results)
+        executorch_results = executorch_model.forward(inputs)
+        test_module.assert_outputs_equal(executorch_results, pytorch_results, use_fp16)
         logging.info(
             f"Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass for {model_name} are matching!"
         )