@@ -36,39 +36,45 @@ def forward(self, arg1, arg2, arg3, arg4, arg5):
36
36
return x + x # Quantize by propagation.
37
37
38
38
def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
    """Export, lower, and run a module containing a single aten.cat.

    Exercises both lowering flows: the legacy to_edge()/partition() pair
    and the newer to_edge_transform_and_lower(), verifying that cat is
    delegated to XNNPACK and that runtime outputs match eager mode.

    Args:
        module: eager module under test; expected to contain exactly one cat.
        inputs: example inputs used for export and for the runtime comparison.
        cat_num: number of tensors concatenated (one Q/DQ pair per input).
        quant: when True, quantize the module before export.
        quant_ops: number of quantized ops besides the inputs — for most
            tests there are two (cat and add).
    """
    # Run the whole pipeline once per lowering flow.
    for legacy_mode in (True, False):
        tester = Tester(module, inputs)

        if quant:
            tester.quantize()

        tester.export().check_count({"torch.ops.aten.cat": 1})
        tester.dump_artifact()

        if quant:
            # Expect multiple quantize ops - one per input, cat, and add.
            tester.check_node_count(
                {
                    # Q/DQ pair for each input and quantized op. For most tests, there are
                    # two quantized ops - cat and add.
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: (
                        cat_num + quant_ops
                    )
                }
            )

        # Legacy flow lowers in two explicit steps; the new flow does both at once.
        if legacy_mode:
            tester.to_edge()
            tester.partition()
        else:
            tester.to_edge_transform_and_lower()

        if quant:
            # After lowering, all Q/DQ ops should have been absorbed by the delegate.
            tester.check_not(["torch.ops.quantized_decomposed"])

        # Exactly one delegate call, no residual edge-dialect cat; then run
        # end-to-end and compare against eager outputs.
        (
            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
78
def test_fp16_cat2 (self ):
73
79
"""
74
80
Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
@@ -155,6 +161,26 @@ def test_fp32_cat_unsupported(self):
155
161
.check_count ({"executorch_exir_dialects_edge__ops_aten_cat" : 1 })
156
162
)
157
163
164
def test_fp32_cat_unsupported_legacy_mode(self):
    """
    XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
    """
    # Five tensors — one more than XNNPACK's concatenation limit. All share
    # trailing dims (2, 3) so they are valid to cat along dim 0.
    inputs = tuple(torch.randn(leading, 2, 3) for leading in (1, 3, 2, 5, 1))

    tester = Tester(self.Cat5(), inputs)
    tester.export()
    tester.check_count({"torch.ops.aten.cat": 1})
    # Legacy lowering flow: edge conversion and partitioning as separate steps.
    tester.to_edge()
    tester.partition()
    # cat must survive un-delegated in the edge graph.
    tester.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
158
184
class CatNegativeDim (torch .nn .Module ):
159
185
def __init__ (self ):
160
186
super ().__init__ ()
0 commit comments