4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import logging
7
+ from typing import Tuple
7
8
8
9
import torch
9
10
from executorch .backends .arm .test import common
10
- from executorch .backends .arm .test .tester .arm_tester import ArmTester
11
- from executorch .backends .arm .tosa_partitioner import TOSAPartitioner
11
+ from executorch .backends .arm .test .tester .test_pipeline import TosaPipelineMI
12
12
from executorch .exir .backend .operator_support import (
13
13
DontPartition ,
14
14
DontPartitionModule ,
15
15
DontPartitionName ,
16
16
)
17
17
from executorch .exir .dialects ._ops import ops as exir_ops
18
18
19
+ input_t1 = Tuple [torch .Tensor , torch .Tensor ] # Input x, y
20
+
19
21
20
22
class CustomPartitioning (torch .nn .Module ):
21
- inputs = (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 ))
23
+ inputs = {
24
+ "randn" : (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 )),
25
+ }
22
26
23
27
def forward (self , x : torch .Tensor , y : torch .Tensor ):
24
28
z = x + y
@@ -27,7 +31,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
27
31
28
32
29
33
class NestedModule (torch .nn .Module ):
30
- inputs = (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 ))
34
+ inputs = {
35
+ "randn" : (torch .randn (10 , 4 , 5 ), torch .randn (10 , 4 , 5 )),
36
+ }
31
37
32
38
def __init__ (self ):
33
39
super ().__init__ ()
@@ -39,192 +45,139 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
39
45
return self .nested (a , b )
40
46
41
47
42
- def test_single_reject (caplog ):
48
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
49
+ def test_single_reject (caplog , test_data : input_t1 ):
43
50
caplog .set_level (logging .INFO )
44
51
45
52
module = CustomPartitioning ()
46
- inputs = module .inputs
47
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
53
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
48
54
check = DontPartition (exir_ops .edge .aten .sigmoid .default )
49
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
50
- (
51
- ArmTester (
52
- module ,
53
- example_inputs = inputs ,
54
- compile_spec = compile_spec ,
55
- )
56
- .export ()
57
- .to_edge_transform_and_lower (partitioners = [partitioner ])
58
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
59
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 2 })
60
- .to_executorch ()
61
- .run_method_and_compare_outputs (inputs = inputs )
55
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
56
+ pipeline .change_args (
57
+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
58
+ )
59
+ pipeline .change_args (
60
+ "check_count.exir" ,
61
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
62
62
)
63
+ pipeline .run ()
63
64
assert check .has_rejected_node ()
64
65
assert "Rejected by DontPartition" in caplog .text
65
66
66
67
67
- def test_multiple_reject ():
68
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
69
+ def test_multiple_reject (test_data : input_t1 ):
68
70
module = CustomPartitioning ()
69
- inputs = module .inputs
70
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
71
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
71
72
check = DontPartition (
72
73
exir_ops .edge .aten .sigmoid .default , exir_ops .edge .aten .mul .Tensor
73
74
)
74
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
75
- (
76
- ArmTester (
77
- module ,
78
- example_inputs = inputs ,
79
- compile_spec = compile_spec ,
80
- )
81
- .export ()
82
- .to_edge_transform_and_lower (partitioners = [partitioner ])
83
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
84
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
85
- .to_executorch ()
86
- .run_method_and_compare_outputs (inputs = inputs )
75
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
76
+ pipeline .change_args (
77
+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
78
+ )
79
+ pipeline .change_args (
80
+ "check_count.exir" ,
81
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
87
82
)
83
+ pipeline .run ()
88
84
assert check .has_rejected_node ()
89
85
90
86
91
- def test_torch_op_reject (caplog ):
87
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
88
+ def test_torch_op_reject (caplog , test_data : input_t1 ):
92
89
caplog .set_level (logging .INFO )
93
90
94
91
module = CustomPartitioning ()
95
- inputs = module .inputs
96
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
97
92
check = DontPartition (torch .ops .aten .sigmoid .default )
98
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
99
- (
100
- ArmTester (
101
- module ,
102
- example_inputs = inputs ,
103
- compile_spec = compile_spec ,
104
- )
105
- .export ()
106
- .to_edge_transform_and_lower (partitioners = [partitioner ])
107
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
108
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 2 })
109
- .to_executorch ()
110
- .run_method_and_compare_outputs (inputs = inputs )
93
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
94
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
95
+ pipeline .change_args (
96
+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
111
97
)
98
+ pipeline .change_args (
99
+ "check_count.exir" ,
100
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
101
+ )
102
+ pipeline .run ()
112
103
assert check .has_rejected_node ()
113
104
assert "Rejected by DontPartition" in caplog .text
114
105
115
106
116
- def test_string_op_reject ():
107
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
108
+ def test_string_op_reject (test_data : input_t1 ):
117
109
module = CustomPartitioning ()
118
- inputs = module .inputs
119
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
120
110
check = DontPartition ("aten.sigmoid.default" )
121
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
122
- (
123
- ArmTester (
124
- module ,
125
- example_inputs = inputs ,
126
- compile_spec = compile_spec ,
127
- )
128
- .export ()
129
- .to_edge_transform_and_lower (partitioners = [partitioner ])
130
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
131
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 2 })
132
- .to_executorch ()
133
- .run_method_and_compare_outputs (inputs = inputs )
111
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
112
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
113
+ pipeline .change_args (
114
+ "check_count.exir" , {"torch.ops.higher_order.executorch_call_delegate" : 2 }
134
115
)
135
-
116
+ pipeline .change_args (
117
+ "check_count.exir" ,
118
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
119
+ )
120
+ pipeline .run ()
136
121
assert check .has_rejected_node ()
137
122
138
123
139
- def test_name_reject (caplog ):
124
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
125
+ def test_name_reject (caplog , test_data : input_t1 ):
140
126
caplog .set_level (logging .INFO )
141
127
142
128
module = CustomPartitioning ()
143
- inputs = module .inputs
144
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
145
129
check = DontPartitionName ("mul" , "sigmoid" , exact = False )
146
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
147
- (
148
- ArmTester (
149
- module ,
150
- example_inputs = inputs ,
151
- compile_spec = compile_spec ,
152
- )
153
- .export ()
154
- .to_edge_transform_and_lower (partitioners = [partitioner ])
155
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
156
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
157
- .to_executorch ()
158
- .run_method_and_compare_outputs (inputs = inputs )
130
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
131
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
132
+ pipeline .change_args (
133
+ "check_count.exir" ,
134
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
159
135
)
136
+ pipeline .run ()
160
137
assert check .has_rejected_node ()
161
138
assert "Rejected by DontPartitionName" in caplog .text
162
139
163
140
164
- def test_module_reject ():
141
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
142
+ def test_module_reject (test_data : input_t1 ):
165
143
module = NestedModule ()
166
- inputs = module .inputs
167
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
168
144
check = DontPartitionModule (module_name = "CustomPartitioning" )
169
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
170
- (
171
- ArmTester (
172
- module ,
173
- example_inputs = inputs ,
174
- compile_spec = compile_spec ,
175
- )
176
- .export ()
177
- .to_edge_transform_and_lower (partitioners = [partitioner ])
178
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
179
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
180
- .to_executorch ()
181
- .run_method_and_compare_outputs (inputs = inputs )
145
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
146
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
147
+ pipeline .change_args (
148
+ "check_count.exir" ,
149
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
182
150
)
151
+ pipeline .run ()
183
152
assert check .has_rejected_node ()
184
153
185
154
186
- def test_inexact_module_reject (caplog ):
155
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
156
+ def test_inexact_module_reject (caplog , test_data : input_t1 ):
187
157
caplog .set_level (logging .INFO )
188
158
189
159
module = NestedModule ()
190
- inputs = module .inputs
191
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
192
160
check = DontPartitionModule (module_name = "Custom" , exact = False )
193
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
194
- (
195
- ArmTester (
196
- module ,
197
- example_inputs = inputs ,
198
- compile_spec = compile_spec ,
199
- )
200
- .export ()
201
- .to_edge_transform_and_lower (partitioners = [partitioner ])
202
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
203
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
204
- .to_executorch ()
205
- .run_method_and_compare_outputs (inputs = inputs )
161
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
162
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
163
+ pipeline .change_args (
164
+ "check_count.exir" ,
165
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
206
166
)
167
+ pipeline .run ()
207
168
assert check .has_rejected_node ()
208
169
assert "Rejected by DontPartitionModule" in caplog .text
209
170
210
171
211
- def test_module_instance_reject ():
172
+ @common .parametrize ("test_data" , CustomPartitioning .inputs )
173
+ def test_module_instance_reject (test_data : input_t1 ):
212
174
module = NestedModule ()
213
- inputs = module .inputs
214
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" )
215
175
check = DontPartitionModule (instance_name = "nested" )
216
- partitioner = TOSAPartitioner (compile_spec , additional_checks = [check ])
217
- (
218
- ArmTester (
219
- module ,
220
- example_inputs = inputs ,
221
- compile_spec = compile_spec ,
222
- )
223
- .export ()
224
- .to_edge_transform_and_lower (partitioners = [partitioner ])
225
- .check (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
226
- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
227
- .to_executorch ()
228
- .run_method_and_compare_outputs (inputs = inputs )
176
+ pipeline = TosaPipelineMI [input_t1 ](module , test_data , [], exir_op = [])
177
+ pipeline .change_args ("to_edge_transform_and_lower" , additional_checks = [check ])
178
+ pipeline .change_args (
179
+ "check_count.exir" ,
180
+ {"executorch_exir_dialects_edge__ops_aten_sigmoid_default" : 1 },
229
181
)
182
+ pipeline .run ()
230
183
assert check .has_rejected_node ()
0 commit comments