Skip to content

Commit fded4c9

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in test fusion ops. (#11078)
Summary: Pull Request resolved: #11078 Reviewed By: hsharma35 Differential Revision: D75183327
1 parent d8ac866 commit fded4c9

File tree

1 file changed

+162
-133
lines changed

1 file changed

+162
-133
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 162 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
1616
from executorch.backends.cadence.aot.fuse_ops import (
17+
FuseCascadedTransposeOrPermuteOps,
18+
FuseCascadedViewOps,
1719
FuseFullThenReshapePass,
20+
FuseMMWithAdd,
1821
FuseMulScalarIntoDequantPass,
1922
FuseMulTensorIntoDequantPass,
2023
FuseQuantDequantToRequantizePass,
@@ -39,113 +42,133 @@ def check_op_counts(
3942

4043

4144
class TestFusionPasses(TestFusionPassesBase):
42-
def test_addmm_fusion(self):
43-
class AddmmFeasible1(torch.nn.Module):
44-
def forward(self, x, y, z):
45-
t1 = torch.mm(x, y)
46-
return torch.add(t1, z)
47-
48-
x = torch.randn(3, 5)
49-
y = torch.randn(5, 6)
50-
z = torch.randn(6)
51-
52-
graph_module = (
53-
compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
54-
.exported_program()
55-
.graph_module
45+
def test_fuse_mm_with_add(self):
46+
builder = GraphBuilder()
47+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
48+
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
49+
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
50+
mm = builder.call_operator(
51+
op=exir_ops.edge.aten.mm.default,
52+
args=(x, y),
53+
)
54+
output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
55+
builder.output([output])
56+
original_graph = builder.get_graph_module()
57+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
58+
converted_graph.graph.eliminate_dead_code()
59+
self.assertEqual(
60+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
5661
)
57-
graph_module.graph.eliminate_dead_code()
58-
59-
# Assert that mm and add were fused to addmm
60-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
61-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
62-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
63-
64-
class AddmmFeasible2(torch.nn.Module):
65-
def forward(self, x, y, z):
66-
t1 = y.view((8, 6))
67-
t2 = torch.mm(x, t1)
68-
t3 = t2.view((2, 2, 6))
69-
return torch.add(t3, z)
70-
71-
x = torch.randn(4, 8)
72-
y = torch.randn(2, 4, 6)
73-
z = torch.randn(6)
62+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
63+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
7464

75-
graph_module = (
76-
compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
77-
.exported_program()
78-
.graph_module
65+
def test_fuse_view_mm_view_add(self):
66+
builder = GraphBuilder()
67+
x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
68+
y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
69+
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
70+
y_view = builder.call_operator(
71+
op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
7972
)
80-
graph_module.graph.eliminate_dead_code()
81-
# Assert that mm and add were fused to addmm
82-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
83-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
84-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
85-
86-
# Bias is a singleton value, broadcastable to output of mm
87-
class AddmmFeasible3(torch.nn.Module):
88-
def forward(self, x, y):
89-
t1 = torch.mm(x, y)
90-
return torch.add(t1, torch.ones(1))
91-
92-
x = torch.randn(3, 5)
93-
y = torch.randn(5, 6)
94-
95-
graph_module = (
96-
compiler.export_to_cadence(AddmmFeasible3(), (x, y))
97-
.exported_program()
98-
.graph_module
73+
mm = builder.call_operator(
74+
op=exir_ops.edge.aten.mm.default,
75+
args=(x, y_view),
76+
)
77+
mm_view = builder.call_operator(
78+
op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
9979
)
100-
graph_module.graph.eliminate_dead_code()
101-
# Assert that mm and add were fused to addmm
102-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
103-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
104-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
80+
output = builder.call_operator(
81+
op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
82+
)
83+
builder.output([output])
84+
original_graph = builder.get_graph_module()
85+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
86+
converted_graph.graph.eliminate_dead_code()
87+
self.assertEqual(
88+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
89+
)
90+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
91+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
10592

93+
def test_keep_view_mm_view_add(self):
94+
builder = GraphBuilder()
95+
x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
96+
y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
10697
# Bias is not broadcastable to output of mm
107-
class AddmmInfeasible1(torch.nn.Module):
108-
def forward(self, x, y, z):
109-
t1 = y.view((8, 6))
110-
t2 = torch.mm(x, t1)
111-
t3 = t2.view((2, 2, 6))
112-
return torch.add(t3, z)
113-
114-
x = torch.randn(4, 8)
115-
y = torch.randn(2, 4, 6)
116-
z = torch.randn(2, 2, 1)
117-
118-
graph_module = (
119-
compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
120-
.exported_program()
121-
.graph_module
98+
z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
99+
y_view = builder.call_operator(
100+
op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
101+
)
102+
mm = builder.call_operator(
103+
op=exir_ops.edge.aten.mm.default,
104+
args=(x, y_view),
122105
)
123-
graph_module.graph.eliminate_dead_code()
106+
mm_view = builder.call_operator(
107+
op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
108+
)
109+
output = builder.call_operator(
110+
op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
111+
)
112+
builder.output([output])
113+
original_graph = builder.get_graph_module()
114+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
115+
converted_graph.graph.eliminate_dead_code()
124116
# Assert that mm and add were not fused to addmm, since z cannot be
125117
# broadcasted to the out of mm.
126-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)
127-
128-
# The add consuming the output of mm has more than one users.
129-
class AddmmInfeasible2(torch.nn.Module):
130-
def forward(self, x, y, z):
131-
t1 = torch.mm(x, y)
132-
t2 = torch.add(t1, z)
133-
t3 = torch.add(t2, z)
134-
return torch.add(t2, t3)
118+
self.assertEqual(
119+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
120+
)
121+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
122+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)
135123

136-
x = torch.randn(3, 5)
137-
y = torch.randn(5, 6)
138-
z = torch.randn(6)
124+
def test_fuse_mm_add_with_bias(self):
125+
builder = GraphBuilder()
126+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
127+
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
128+
mm = builder.call_operator(
129+
op=exir_ops.edge.aten.mm.default,
130+
args=(x, y),
131+
)
132+
bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
133+
output = builder.call_operator(
134+
op=exir_ops.edge.aten.add.Tensor, args=(mm, bias)
135+
)
136+
builder.output([output])
137+
original_graph = builder.get_graph_module()
138+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
139+
converted_graph.graph.eliminate_dead_code()
140+
self.assertEqual(
141+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
142+
)
143+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
144+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
139145

140-
graph_module = (
141-
compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
142-
.exported_program()
143-
.graph_module
146+
def test_keep_mm_add_with_multiple_users(self):
147+
builder = GraphBuilder()
148+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
149+
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
150+
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
151+
mm = builder.call_operator(
152+
op=exir_ops.edge.aten.mm.default,
153+
args=(x, y),
154+
)
155+
# The add consuming the output of mm has more than one users.
156+
add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
157+
add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
158+
output = builder.call_operator(
159+
op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
144160
)
145-
graph_module.graph.eliminate_dead_code()
161+
builder.output([output])
162+
original_graph = builder.get_graph_module()
163+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
164+
converted_graph.graph.eliminate_dead_code()
146165
# Assert that mm and add were not fused to addmm, since add has multiple
147166
# users.
148-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
167+
self.assertEqual(
168+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
169+
)
170+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
171+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
149172

150173
# TODO(matthiascremon): enable that pass with new flow
151174
@torch.no_grad()
@@ -184,63 +207,69 @@ def forward(self, x):
184207
)
185208

186209
def test_permute_transpose_fusion(self):
187-
class PermuteTranspose(torch.nn.Module):
188-
def forward(self, x):
189-
y = x.permute((0, 2, 4, 1, 3))
190-
return y.transpose(0, 1)
191-
192-
x = torch.randn(3, 1, 3, 1, 4)
193-
graph_module = (
194-
compiler.export_to_cadence(PermuteTranspose(), (x,))
195-
.exported_program()
196-
.graph_module
210+
builder = GraphBuilder()
211+
x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
212+
permute = builder.call_operator(
213+
op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
214+
)
215+
output = builder.call_operator(
216+
op=exir_ops.edge.aten.transpose_copy.int,
217+
args=(permute, 1, 0),
197218
)
198-
graph_module.graph.eliminate_dead_code()
219+
builder.output(output)
220+
original_graph = builder.get_graph_module()
221+
converted_graph = FuseCascadedTransposeOrPermuteOps()(
222+
original_graph
223+
).graph_module
224+
converted_graph.graph.eliminate_dead_code()
199225
# Assert that permute op was fused with transpose op
200226
self.assertEqual(
201-
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
227+
count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
202228
)
203229
self.assertEqual(
204-
count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
230+
count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
205231
)
206232

207233
def test_view_fusion(self):
208-
class ViewFusion(torch.nn.Module):
209-
def forward(self, x):
210-
x = x.view([1, 8, 15])
211-
x = x.view([1, 1, 120])
212-
return x.view([1, 12, 10])
213-
214-
x = torch.randn(8, 5, 3)
215-
graph_module = (
216-
compiler.export_to_cadence(ViewFusion(), (x,))
217-
.exported_program()
218-
.graph_module
234+
builder = GraphBuilder()
235+
x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
236+
view1 = builder.call_operator(
237+
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
238+
)
239+
view2 = builder.call_operator(
240+
op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
241+
)
242+
output = builder.call_operator(
243+
op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
219244
)
220-
graph_module.graph.eliminate_dead_code()
245+
builder.output(output)
246+
original_graph = builder.get_graph_module()
247+
converted_graph = FuseCascadedViewOps()(original_graph).graph_module
248+
converted_graph.graph.eliminate_dead_code()
221249
# Assert that only one view op remains
222250
self.assertEqual(
223-
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
251+
count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
224252
)
225253

226254
def test_view_fusion_branched(self):
227-
class ViewFusion(torch.nn.Module):
228-
def forward(self, x):
229-
y = x.view([1, 8, 15])
230-
z = y.view([1, 1, 120])
231-
t = y.view([120, 1, 1])
232-
return z, t
233-
234-
x = torch.randn(8, 5, 3)
235-
graph_module = (
236-
compiler.export_to_cadence(ViewFusion(), (x,))
237-
.exported_program()
238-
.graph_module
255+
builder = GraphBuilder()
256+
x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
257+
y = builder.call_operator(
258+
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
259+
)
260+
z = builder.call_operator(
261+
op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
239262
)
240-
graph_module.graph.eliminate_dead_code()
263+
t = builder.call_operator(
264+
op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
265+
)
266+
builder.output([z, t])
267+
original_graph = builder.get_graph_module()
268+
converted_graph = FuseCascadedViewOps()(original_graph).graph_module
269+
converted_graph.graph.eliminate_dead_code()
241270
# z and t should be fused and y should be eliminated.
242271
self.assertEqual(
243-
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
272+
count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
244273
)
245274

246275
def test_force_quant_dequant_fusion(self):

0 commit comments

Comments
 (0)