Skip to content

Commit 9e6f4a9

Browse files
guangy10facebook-github-bot
authored andcommitted
Fix export tests
Differential Revision: D55382428
1 parent cde514c commit 9e6f4a9

File tree

2 files changed

+60
-57
lines changed

2 files changed

+60
-57
lines changed

examples/portable/test/test_export.py

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,27 @@
66

77
import unittest
88

9-
from typing import Any, Callable
10-
119
import torch
1210
from executorch.examples.models import MODEL_NAME_TO_MODEL
1311
from executorch.examples.models.model_factory import EagerModelFactory
1412

15-
from executorch.examples.portable.utils import export_to_edge
16-
1713
from executorch.extension.pybindings.portable_lib import ( # @manual
1814
_load_for_executorch_from_buffer,
1915
)
2016

17+
from ..utils import export_to_edge
18+
2119

2220
class ExportTest(unittest.TestCase):
23-
def _assert_eager_lowered_same_result(
21+
def collect_executorch_and_eager_outputs(
2422
self,
2523
eager_model: torch.nn.Module,
2624
example_inputs,
27-
validation_fn: Callable[[Any, Any], bool],
2825
):
2926
"""
30-
Asserts that the given model has the same result as the eager mode
31-
lowered model, with example_inputs, validated by validation_fn, which
32-
takes the eager mode output and ET output, and returns True if they
33-
match.
27+
Compares the output of the given eager mode PyTorch model with the output
28+
of the equivalent executorch model, both provided with example inputs.
29+
Returns a tuple containing the outputs of the eager mode model and the executorch mode model.
3430
"""
3531
eager_model = eager_model.eval()
3632
model = torch._export.capture_pre_autograd_graph(eager_model, example_inputs)
@@ -45,100 +41,105 @@ def _assert_eager_lowered_same_result(
4541
with torch.no_grad():
4642
executorch_output = pte_model.run_method("forward", example_inputs)
4743

48-
self.assertTrue(validation_fn(eager_output, executorch_output))
44+
return (eager_output, executorch_output)
4945

50-
@staticmethod
51-
def validate_tensor_allclose(eager_output, executorch_output, rtol=1e-5, atol=1e-5):
52-
result = torch.allclose(
53-
eager_output,
54-
executorch_output[0],
55-
rtol=rtol,
56-
atol=atol,
57-
)
58-
if not result:
59-
print(f"eager output: {eager_output}")
60-
print(f"executorch output: {executorch_output}")
61-
return result
46+
def validate_tensor_allclose(
47+
self, eager_output, executorch_output, rtol=1e-5, atol=1e-5
48+
):
49+
self.assertTrue(
50+
isinstance(eager_output, type(executorch_output)),
51+
f"Outputs are not of the same type: eager type: {type(eager_output)}, executorch type: {type(executorch_output)}",
52+
)
53+
self.assertTrue(
54+
len(eager_output) == len(executorch_output),
55+
f"len(eager_output)={len(eager_output)}, len(executorch_output)={len(executorch_output)}",
56+
)
57+
for i in range(len(eager_output)):
58+
result = torch.allclose(
59+
eager_output[i],
60+
executorch_output[i],
61+
rtol=rtol,
62+
atol=atol,
63+
)
64+
if not result:
65+
print(f"eager output[{i}]: {eager_output[i]}")
66+
print(f"executorch output[{i}]: {executorch_output[i]}")
67+
return self.assertTrue(result)
6268

6369
def test_mv3_export_to_executorch(self):
6470
eager_model, example_inputs, _ = EagerModelFactory.create_model(
6571
*MODEL_NAME_TO_MODEL["mv3"]
6672
)
67-
eager_model = eager_model.eval()
68-
73+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
74+
eager_model, example_inputs
75+
)
6976
# TODO(T166083470): Fix accuracy issue
70-
self._assert_eager_lowered_same_result(
71-
eager_model,
72-
example_inputs,
73-
lambda x, y: self.validate_tensor_allclose(x, y, rtol=1e-3, atol=1e-5),
77+
self.validate_tensor_allclose(
78+
eager_output, executorch_output[0], rtol=1e-3, atol=1e-5
7479
)
7580

7681
def test_mv2_export_to_executorch(self):
7782
eager_model, example_inputs, _ = EagerModelFactory.create_model(
7883
*MODEL_NAME_TO_MODEL["mv2"]
7984
)
80-
eager_model = eager_model.eval()
81-
82-
self._assert_eager_lowered_same_result(
83-
eager_model, example_inputs, self.validate_tensor_allclose
85+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
86+
eager_model, example_inputs
8487
)
88+
self.validate_tensor_allclose(eager_output, executorch_output[0])
8589

8690
def test_vit_export_to_executorch(self):
8791
eager_model, example_inputs, _ = EagerModelFactory.create_model(
8892
*MODEL_NAME_TO_MODEL["vit"]
8993
)
90-
eager_model = eager_model.eval()
91-
92-
self._assert_eager_lowered_same_result(
93-
eager_model, example_inputs, self.validate_tensor_allclose
94+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
95+
eager_model, example_inputs
9496
)
97+
self.validate_tensor_allclose(eager_output, executorch_output[0])
9598

9699
def test_w2l_export_to_executorch(self):
97100
eager_model, example_inputs, _ = EagerModelFactory.create_model(
98101
*MODEL_NAME_TO_MODEL["w2l"]
99102
)
100-
eager_model = eager_model.eval()
101-
102-
self._assert_eager_lowered_same_result(
103-
eager_model, example_inputs, self.validate_tensor_allclose
103+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
104+
eager_model, example_inputs
104105
)
106+
self.validate_tensor_allclose(eager_output, executorch_output[0])
105107

106108
def test_ic3_export_to_executorch(self):
107109
eager_model, example_inputs, _ = EagerModelFactory.create_model(
108110
*MODEL_NAME_TO_MODEL["ic3"]
109111
)
110-
eager_model = eager_model.eval()
111-
112-
self._assert_eager_lowered_same_result(
113-
eager_model, example_inputs, self.validate_tensor_allclose
112+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
113+
eager_model, example_inputs
114+
)
115+
# TODO(T166083470): Fix accuracy issue
116+
self.validate_tensor_allclose(
117+
eager_output, executorch_output[0], rtol=1e-3, atol=1e-5
114118
)
115119

116120
def test_resnet18_export_to_executorch(self):
117121
eager_model, example_inputs, _ = EagerModelFactory.create_model(
118122
*MODEL_NAME_TO_MODEL["resnet18"]
119123
)
120-
eager_model = eager_model.eval()
121-
122-
self._assert_eager_lowered_same_result(
123-
eager_model, example_inputs, self.validate_tensor_allclose
124+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
125+
eager_model, example_inputs
124126
)
127+
self.validate_tensor_allclose(eager_output, executorch_output[0])
125128

126129
def test_resnet50_export_to_executorch(self):
127130
eager_model, example_inputs, _ = EagerModelFactory.create_model(
128131
*MODEL_NAME_TO_MODEL["resnet50"]
129132
)
130-
eager_model = eager_model.eval()
131-
132-
self._assert_eager_lowered_same_result(
133-
eager_model, example_inputs, self.validate_tensor_allclose
133+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
134+
eager_model, example_inputs
134135
)
136+
self.validate_tensor_allclose(eager_output, executorch_output[0])
135137

136138
def test_dl3_export_to_executorch(self):
137139
eager_model, example_inputs, _ = EagerModelFactory.create_model(
138140
*MODEL_NAME_TO_MODEL["dl3"]
139141
)
140-
eager_model = eager_model.eval()
141-
142-
self._assert_eager_lowered_same_result(
143-
eager_model, example_inputs, self.validate_tensor_allclose
142+
eager_output, executorch_output = self.collect_executorch_and_eager_outputs(
143+
eager_model, example_inputs
144144
)
145+
self.validate_tensor_allclose(list(eager_output.values()), executorch_output)

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ addopts =
3838
backends/arm/test
3939
# test
4040
test/end2end/test_end2end.py
41+
# examples/
42+
examples/portable/test/test_export.py
4143

4244
# run the same tests multiple times to determine their
4345
# flakiness status. Default to 50 re-runs

0 commit comments

Comments
 (0)