6
6
7
7
import unittest
8
8
9
- from typing import Any , Callable
10
-
11
9
import torch
12
10
from executorch .examples .models import MODEL_NAME_TO_MODEL
13
11
from executorch .examples .models .model_factory import EagerModelFactory
14
12
15
- from executorch .examples .portable .utils import export_to_edge
16
-
17
13
from executorch .extension .pybindings .portable_lib import ( # @manual
18
14
_load_for_executorch_from_buffer ,
19
15
)
20
16
17
+ from ..utils import export_to_edge
18
+
21
19
22
20
class ExportTest (unittest .TestCase ):
23
- def _assert_eager_lowered_same_result (
21
+ def collect_executorch_and_eager_outputs (
24
22
self ,
25
23
eager_model : torch .nn .Module ,
26
24
example_inputs ,
27
- validation_fn : Callable [[Any , Any ], bool ],
28
25
):
29
26
"""
30
- Asserts that the given model has the same result as the eager mode
31
- lowered model, with example_inputs, validated by validation_fn, which
32
- takes the eager mode output and ET output, and returns True if they
33
- match.
27
+ Compares the output of the given eager mode PyTorch model with the output
28
+ of the equivalent executorch model, both provided with example inputs.
29
+ Returns a tuple containing the outputs of the eager mode model and the executorch mode model.
34
30
"""
35
31
eager_model = eager_model .eval ()
36
32
model = torch ._export .capture_pre_autograd_graph (eager_model , example_inputs )
@@ -45,100 +41,105 @@ def _assert_eager_lowered_same_result(
45
41
with torch .no_grad ():
46
42
executorch_output = pte_model .run_method ("forward" , example_inputs )
47
43
48
- self . assertTrue ( validation_fn ( eager_output , executorch_output ) )
44
+ return ( eager_output , executorch_output )
49
45
50
- @staticmethod
51
- def validate_tensor_allclose (eager_output , executorch_output , rtol = 1e-5 , atol = 1e-5 ):
52
- result = torch .allclose (
53
- eager_output ,
54
- executorch_output [0 ],
55
- rtol = rtol ,
56
- atol = atol ,
57
- )
58
- if not result :
59
- print (f"eager output: { eager_output } " )
60
- print (f"executorch output: { executorch_output } " )
61
- return result
46
+ def validate_tensor_allclose (
47
+ self , eager_output , executorch_output , rtol = 1e-5 , atol = 1e-5
48
+ ):
49
+ self .assertTrue (
50
+ isinstance (eager_output , type (executorch_output )),
51
+ f"Outputs are not of the same type: eager type: { type (eager_output )} , executorch type: { type (executorch_output )} " ,
52
+ )
53
+ self .assertTrue (
54
+ len (eager_output ) == len (executorch_output ),
55
+ f"len(eager_output)={ len (eager_output )} , len(executorch_output)={ len (executorch_output )} " ,
56
+ )
57
+ for i in range (len (eager_output )):
58
+ result = torch .allclose (
59
+ eager_output [i ],
60
+ executorch_output [i ],
61
+ rtol = rtol ,
62
+ atol = atol ,
63
+ )
64
+ if not result :
65
+ print (f"eager output[{ i } ]: { eager_output [i ]} " )
66
+ print (f"executorch output[{ i } ]: { executorch_output [i ]} " )
67
+ return self .assertTrue (result )
62
68
63
69
def test_mv3_export_to_executorch (self ):
64
70
eager_model , example_inputs , _ = EagerModelFactory .create_model (
65
71
* MODEL_NAME_TO_MODEL ["mv3" ]
66
72
)
67
- eager_model = eager_model .eval ()
68
-
73
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
74
+ eager_model , example_inputs
75
+ )
69
76
# TODO(T166083470): Fix accuracy issue
70
- self ._assert_eager_lowered_same_result (
71
- eager_model ,
72
- example_inputs ,
73
- lambda x , y : self .validate_tensor_allclose (x , y , rtol = 1e-3 , atol = 1e-5 ),
77
+ self .validate_tensor_allclose (
78
+ eager_output , executorch_output [0 ], rtol = 1e-3 , atol = 1e-5
74
79
)
75
80
76
81
def test_mv2_export_to_executorch (self ):
77
82
eager_model , example_inputs , _ = EagerModelFactory .create_model (
78
83
* MODEL_NAME_TO_MODEL ["mv2" ]
79
84
)
80
- eager_model = eager_model .eval ()
81
-
82
- self ._assert_eager_lowered_same_result (
83
- eager_model , example_inputs , self .validate_tensor_allclose
85
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
86
+ eager_model , example_inputs
84
87
)
88
+ self .validate_tensor_allclose (eager_output , executorch_output [0 ])
85
89
86
90
def test_vit_export_to_executorch (self ):
87
91
eager_model , example_inputs , _ = EagerModelFactory .create_model (
88
92
* MODEL_NAME_TO_MODEL ["vit" ]
89
93
)
90
- eager_model = eager_model .eval ()
91
-
92
- self ._assert_eager_lowered_same_result (
93
- eager_model , example_inputs , self .validate_tensor_allclose
94
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
95
+ eager_model , example_inputs
94
96
)
97
+ self .validate_tensor_allclose (eager_output , executorch_output [0 ])
95
98
96
99
def test_w2l_export_to_executorch (self ):
97
100
eager_model , example_inputs , _ = EagerModelFactory .create_model (
98
101
* MODEL_NAME_TO_MODEL ["w2l" ]
99
102
)
100
- eager_model = eager_model .eval ()
101
-
102
- self ._assert_eager_lowered_same_result (
103
- eager_model , example_inputs , self .validate_tensor_allclose
103
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
104
+ eager_model , example_inputs
104
105
)
106
+ self .validate_tensor_allclose (eager_output , executorch_output [0 ])
105
107
106
108
def test_ic3_export_to_executorch (self ):
107
109
eager_model , example_inputs , _ = EagerModelFactory .create_model (
108
110
* MODEL_NAME_TO_MODEL ["ic3" ]
109
111
)
110
- eager_model = eager_model .eval ()
111
-
112
- self ._assert_eager_lowered_same_result (
113
- eager_model , example_inputs , self .validate_tensor_allclose
112
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
113
+ eager_model , example_inputs
114
+ )
115
+ # TODO(T166083470): Fix accuracy issue
116
+ self .validate_tensor_allclose (
117
+ eager_output , executorch_output [0 ], rtol = 1e-3 , atol = 1e-5
114
118
)
115
119
116
120
def test_resnet18_export_to_executorch (self ):
117
121
eager_model , example_inputs , _ = EagerModelFactory .create_model (
118
122
* MODEL_NAME_TO_MODEL ["resnet18" ]
119
123
)
120
- eager_model = eager_model .eval ()
121
-
122
- self ._assert_eager_lowered_same_result (
123
- eager_model , example_inputs , self .validate_tensor_allclose
124
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
125
+ eager_model , example_inputs
124
126
)
127
+ self .validate_tensor_allclose (eager_output , executorch_output [0 ])
125
128
126
129
def test_resnet50_export_to_executorch (self ):
127
130
eager_model , example_inputs , _ = EagerModelFactory .create_model (
128
131
* MODEL_NAME_TO_MODEL ["resnet50" ]
129
132
)
130
- eager_model = eager_model .eval ()
131
-
132
- self ._assert_eager_lowered_same_result (
133
- eager_model , example_inputs , self .validate_tensor_allclose
133
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
134
+ eager_model , example_inputs
134
135
)
136
+ self .validate_tensor_allclose (eager_output , executorch_output [0 ])
135
137
136
138
def test_dl3_export_to_executorch (self ):
137
139
eager_model , example_inputs , _ = EagerModelFactory .create_model (
138
140
* MODEL_NAME_TO_MODEL ["dl3" ]
139
141
)
140
- eager_model = eager_model .eval ()
141
-
142
- self ._assert_eager_lowered_same_result (
143
- eager_model , example_inputs , self .validate_tensor_allclose
142
+ eager_output , executorch_output = self .collect_executorch_and_eager_outputs (
143
+ eager_model , example_inputs
144
144
)
145
+ self .validate_tensor_allclose (list (eager_output .values ()), executorch_output )
0 commit comments