3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
- import unittest
6
+ from typing import Tuple
7
+
8
+ import pytest
7
9
8
10
import torch
9
- from executorch .backends .arm .test import common , conftest
10
11
11
- from executorch .backends .arm .test .tester .arm_tester import ArmTester
12
+ from executorch .backends .arm .test import common
13
+ from executorch .backends .arm .test .tester .test_pipeline import (
14
+ EthosU55PipelineBI ,
15
+ EthosU85PipelineBI ,
16
+ TosaPipelineBI ,
17
+ TosaPipelineMI ,
18
+ )
12
19
13
20
from torchaudio .models import Conformer
14
21
22
+ input_t = Tuple [torch .Tensor , torch .IntTensor ] # Input x, y
23
+
15
24
16
25
def get_test_inputs (dim , lengths , num_examples ):
17
26
return (torch .rand (num_examples , int (lengths .max ()), dim ), lengths )
18
27
19
28
20
- class TestConformer ( unittest . TestCase ) :
29
+ class TestConformer :
21
30
"""Tests Torchaudio Conformer"""
22
31
23
32
# Adjust nbr below as we increase op support. Note: most of the delegates
24
33
# calls are directly consecutive to each other in the .pte. The reason
25
34
# for that is some assert ops are removed by passes in the
26
35
# .to_executorch step, i.e. after Arm partitioner.
27
- ops_after_partitioner = {
28
- "executorch_exir_dialects_edge__ops_aten_max_default" : 1 ,
29
- "torch.ops.aten._assert_scalar.default" : 7 ,
30
- "torch.ops.aten._local_scalar_dense.default" : 1 ,
31
- }
36
+ aten_ops = ["torch.ops.aten._assert_scalar.default" ]
32
37
33
38
dim = 16
34
39
num_examples = 10
@@ -43,96 +48,87 @@ class TestConformer(unittest.TestCase):
43
48
)
44
49
conformer = conformer .eval ()
45
50
46
- def test_conformer_tosa_MI (self ):
47
- (
48
- ArmTester (
49
- self .conformer ,
50
- example_inputs = self .model_example_inputs ,
51
- compile_spec = common .get_tosa_compile_spec (tosa_spec = "TOSA-0.80+MI" ),
52
- )
53
- .export ()
54
- .to_edge_transform_and_lower ()
55
- .dump_operator_distribution ()
56
- .check_count (self .ops_after_partitioner )
57
- .to_executorch ()
58
- # TODO(MLETORCH-632): Fix numerical errors
59
- .run_method_and_compare_outputs (
60
- rtol = 1.0 ,
61
- atol = 5.0 ,
62
- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
63
- )
64
- )
65
-
66
- def test_conformer_tosa_BI (self ):
67
- (
68
- ArmTester (
69
- self .conformer ,
70
- example_inputs = self .model_example_inputs ,
71
- compile_spec = common .get_tosa_compile_spec (tosa_spec = "TOSA-0.80+BI" ),
72
- )
73
- .quantize ()
74
- .export ()
75
- .to_edge_transform_and_lower ()
76
- .to_executorch ()
77
- .run_method_and_compare_outputs (
78
- qtol = 1.0 ,
79
- rtol = 1.0 ,
80
- atol = 5.0 ,
81
- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
82
- )
83
- )
84
-
85
- def test_conformer_u55_BI (self ):
86
- tester = (
87
- ArmTester (
88
- self .conformer ,
89
- example_inputs = self .model_example_inputs ,
90
- compile_spec = common .get_u55_compile_spec (),
91
- )
92
- .quantize ()
93
- .export ()
94
- .to_edge_transform_and_lower ()
95
- .to_executorch ()
96
- .serialize ()
97
- )
98
-
99
- if conftest .is_option_enabled ("corstone_fvp" ):
100
- try :
101
- tester .run_method_and_compare_outputs (
102
- qtol = 1.0 ,
103
- rtol = 1.0 ,
104
- atol = 5.0 ,
105
- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
106
- )
107
- self .fail (
108
- "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
109
- )
110
- except Exception :
111
- pass
112
-
113
- def test_conformer_u85_BI (self ):
114
- tester = (
115
- ArmTester (
116
- self .conformer ,
117
- example_inputs = self .model_example_inputs ,
118
- compile_spec = common .get_u85_compile_spec (),
119
- )
120
- .quantize ()
121
- .export ()
122
- .to_edge_transform_and_lower ()
123
- .to_executorch ()
124
- .serialize ()
125
- )
126
- if conftest .is_option_enabled ("corstone_fvp" ):
127
- try :
128
- tester .run_method_and_compare_outputs (
129
- qtol = 1.0 ,
130
- rtol = 1.0 ,
131
- atol = 5.0 ,
132
- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
133
- )
134
- self .fail (
135
- "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
136
- )
137
- except Exception :
138
- pass
51
+
52
+ def test_conformer_tosa_MI ():
53
+ pipeline = TosaPipelineMI [input_t ](
54
+ TestConformer .conformer ,
55
+ TestConformer .model_example_inputs ,
56
+ aten_op = TestConformer .aten_ops ,
57
+ exir_op = [],
58
+ use_to_edge_transform_and_lower = True ,
59
+ )
60
+ pipeline .change_args (
61
+ "run_method_and_compare_outputs" ,
62
+ get_test_inputs (
63
+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
64
+ ),
65
+ rtol = 1.0 ,
66
+ atol = 5.0 ,
67
+ )
68
+ pipeline .run ()
69
+
70
+
71
+ def test_conformer_tosa_BI ():
72
+ pipeline = TosaPipelineBI [input_t ](
73
+ TestConformer .conformer ,
74
+ TestConformer .model_example_inputs ,
75
+ aten_op = TestConformer .aten_ops ,
76
+ exir_op = [],
77
+ use_to_edge_transform_and_lower = True ,
78
+ )
79
+ pipeline .pop_stage ("check_count.exir" )
80
+ pipeline .change_args (
81
+ "run_method_and_compare_outputs" ,
82
+ get_test_inputs (
83
+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
84
+ ),
85
+ rtol = 1.0 ,
86
+ atol = 5.0 ,
87
+ )
88
+ pipeline .run ()
89
+
90
+
91
+ @common .XfailIfNoCorstone300
92
+ @pytest .mark .xfail (
93
+ reason = "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
94
+ )
95
+ def test_conformer_u55_BI ():
96
+ pipeline = EthosU55PipelineBI [input_t ](
97
+ TestConformer .conformer ,
98
+ TestConformer .model_example_inputs ,
99
+ aten_ops = TestConformer .aten_ops ,
100
+ exir_ops = [],
101
+ use_to_edge_transform_and_lower = True ,
102
+ run_on_fvp = True ,
103
+ )
104
+ pipeline .change_args (
105
+ "run_method_and_compare_outputs" ,
106
+ get_test_inputs (
107
+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
108
+ ),
109
+ rtol = 1.0 ,
110
+ atol = 5.0 ,
111
+ )
112
+ pipeline .run ()
113
+
114
+
115
+ @common .XfailIfNoCorstone320
116
+ @pytest .mark .xfail (reason = "All IO needs to have the same data type (MLETORCH-635)" )
117
+ def test_conformer_u85_BI ():
118
+ pipeline = EthosU85PipelineBI [input_t ](
119
+ TestConformer .conformer ,
120
+ TestConformer .model_example_inputs ,
121
+ aten_ops = TestConformer .aten_ops ,
122
+ exir_ops = [],
123
+ use_to_edge_transform_and_lower = True ,
124
+ run_on_fvp = True ,
125
+ )
126
+ pipeline .change_args (
127
+ "run_method_and_compare_outputs" ,
128
+ get_test_inputs (
129
+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
130
+ ),
131
+ rtol = 1.0 ,
132
+ atol = 5.0 ,
133
+ )
134
+ pipeline .run ()
0 commit comments