 import torch
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+    FuseCascadedViewOps,
     FuseFullThenReshapePass,
+    FuseMMWithAdd,
     FuseMulScalarIntoDequantPass,
     FuseMulTensorIntoDequantPass,
     FuseQuantDequantToRequantizePass,
@@ -39,113 +42,133 @@ def check_op_counts(


 class TestFusionPasses(TestFusionPassesBase):
-    def test_addmm_fusion(self):
-        class AddmmFeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, z)
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_mm_with_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        builder.output([output])
+        original_graph = builder.get_graph_module()
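+        # Run the pass: a mm whose result feeds a single add should fuse into one addmm.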
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
         )
-        graph_module.graph.eliminate_dead_code()
-
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        class AddmmFeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(6)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_view_mm_view_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        # Bias is a singleton value, broadcastable to output of mm
-        class AddmmFeasible3(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, torch.ones(1))
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible3(), (x, y))
-            .exported_program()
-            .graph_module
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
+        )
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
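+        # The views around the mm do not block fusion: mm and add still fold into addmm.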
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

+    def test_keep_view_mm_view_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
         # Bias is not broadcastable to the output of mm
-        class AddmmInfeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(2, 2, 1)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+        z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
+        )
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
         )
-        graph_module.graph.eliminate_dead_code()
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
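+        # Fusion requires the bias to broadcast to the mm output; z's shape (2, 2, 1) cannot.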
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since z cannot be
         # broadcast to the output of mm.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)
-
-        # The add consuming the output of mm has more than one users.
-        class AddmmInfeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                t2 = torch.add(t1, z)
-                t3 = torch.add(t2, z)
-                return torch.add(t2, t3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)

-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
+    def test_fuse_mm_add_with_bias(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
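+        # full materializes a singleton bias, which is broadcastable to the mm output.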
+        bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm, bias)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_keep_mm_add_with_multiple_users(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        # The add consuming the output of mm has more than one user.
+        add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
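+        # add1 has a second user (add2) besides output, so the pass must leave mm alone.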
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since add has multiple
         # users.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)

     # TODO(matthiascremon): enable that pass with new flow
     @torch.no_grad()
@@ -184,63 +207,69 @@ def forward(self, x):
         )

     def test_permute_transpose_fusion(self):
-        class PermuteTranspose(torch.nn.Module):
-            def forward(self, x):
-                y = x.permute((0, 2, 4, 1, 3))
-                return y.transpose(0, 1)
-
-        x = torch.randn(3, 1, 3, 1, 4)
-        graph_module = (
-            compiler.export_to_cadence(PermuteTranspose(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.transpose_copy.int,
+            args=(permute, 1, 0),
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
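+        # A transpose is just a two-axis permute, so the cascaded pair composes into one permute.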
+        converted_graph = FuseCascadedTransposeOrPermuteOps()(
+            original_graph
+        ).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that permute op was fused with transpose op
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
         )
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
+            count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
         )

     def test_view_fusion(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                x = x.view([1, 8, 15])
-                x = x.view([1, 1, 120])
-                return x.view([1, 12, 10])
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        view1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        view2 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
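+        # Intermediate shapes are irrelevant: the cascade collapses into one view of the final shape.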
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that only one view op remains
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
         )

     def test_view_fusion_branched(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                y = x.view([1, 8, 15])
-                z = y.view([1, 1, 120])
-                t = y.view([120, 1, 1])
-                return z, t
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        y = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        z = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
         )
-        graph_module.graph.eliminate_dead_code()
+        t = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
+        )
+        builder.output([z, t])
+        original_graph = builder.get_graph_module()
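+        # Each output view composes with the shared intermediate y, leaving y as dead code.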
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # z and t should be fused and y should be eliminated.
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
         )

     def test_force_quant_dequant_fusion(self):