Skip to content

Commit 4ca3b05

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from quant fusion test
Summary: title Differential Revision: D56264730
1 parent 78cb141 commit 4ca3b05

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

exir/tests/test_quant_fusion_pass.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
from executorch import exir
13-
from executorch.exir import CaptureConfig, EdgeCompileConfig
13+
from executorch.exir import EdgeCompileConfig, to_edge
1414
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
1515
from executorch.exir.tests.common import register_additional_test_aten_ops
1616
from torch.ao.quantization import ( # @manual
@@ -26,6 +26,7 @@
2626
_convert_to_reference_decomposed_fx,
2727
prepare_fx,
2828
)
29+
from torch.export import export
2930
from torch.nn import functional as F
3031

3132
from torch.testing import FileCheck
@@ -56,9 +57,11 @@ def forward(self, x, y):
5657
)
5758
m = _convert_to_reference_decomposed_fx(m)
5859
config = EdgeCompileConfig(_check_ir_validity=False)
59-
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
60+
m = to_edge(export(m, example_inputs), compile_config=config)
6061
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
61-
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
62+
m = m.transform(
63+
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
64+
)
6265
# check that we are using functional variant of q/dq/add
6366
FileCheck().check(
6467
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
@@ -67,12 +70,12 @@ def forward(self, x, y):
6770
).check(
6871
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
6972
).run(
70-
m.exported_program.graph_module.code
73+
m.exported_program().graph_module.code
7174
)
7275
m = m.to_executorch()
7376
# check that we are using out variant of q/dq/add
7477
FileCheck().check("torch.ops.quantized_decomposed.add.out").run(
75-
m.exported_program.graph_module.code
78+
m.exported_program().graph_module.code
7679
)
7780

7881
def test_reshape(self) -> None:
@@ -95,9 +98,11 @@ def forward(self, x, y):
9598
m(*example_inputs)
9699
m = _convert_to_reference_decomposed_fx(m)
97100
config = EdgeCompileConfig(_check_ir_validity=False)
98-
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
101+
m = to_edge(export(m, example_inputs), compile_config=config)
99102
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
100-
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
103+
m = m.transform(
104+
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
105+
)
101106
# check that we are using functional variant of q/dq/add/reshape
102107
# make sure we only have two quant and one dequant since the q/dq around reshape
103108
# should be fused
@@ -114,14 +119,14 @@ def forward(self, x, y):
114119
1,
115120
exactly=True,
116121
).run(
117-
m.exported_program.graph_module.code
122+
m.exported_program().graph_module.code
118123
)
119124

120125
m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False))
121126
# check that we are using out variant of q/dq/add
122127
FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
123128
"torch.ops.aten.view_copy.out"
124-
).run(m.exported_program.graph_module.code)
129+
).run(m.exported_program().graph_module.code)
125130

126131
def test_slice(self) -> None:
127132
"""We don't proactively quantize slice today, but we'll fuse the dq-slice-q
@@ -150,9 +155,11 @@ def forward(self, x, y):
150155
)
151156
m = _convert_to_reference_decomposed_fx(m)
152157
config = EdgeCompileConfig(_check_ir_validity=False)
153-
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
158+
m = to_edge(export(m, example_inputs), compile_config=config)
154159
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
155-
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
160+
m = m.transform(
161+
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
162+
)
156163
# check that we are using functional variant of q/dq/add/slice
157164
# make sure we only have one quant and one dequant since the q/dq around slice
158165
# should be fused
@@ -169,14 +176,14 @@ def forward(self, x, y):
169176
).check(
170177
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
171178
).run(
172-
m.exported_program.graph_module.code
179+
m.exported_program().graph_module.code
173180
)
174181

175182
m = m.to_executorch()
176183
# check that we are using out variant of add and slice_copy
177184
FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
178185
"torch.ops.aten.slice_copy.Tensor_out"
179-
).run(m.dump_graph_module().code)
186+
).run(m.exported_program().graph_module.code)
180187

181188
def test_cat(self) -> None:
182189
class M(torch.nn.Module):
@@ -197,9 +204,9 @@ def forward(self, x, y):
197204
m(*example_inputs)
198205
m = _convert_to_reference_decomposed_fx(m)
199206
config = EdgeCompileConfig(_check_ir_validity=False)
200-
m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config)
207+
m = to_edge(export(m, example_inputs), compile_config=config)
201208
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
202-
m = m.transform(QuantFusionPass())
209+
m = m.transform([QuantFusionPass()], check_ir_validity=False)
203210
# check that we are using functional variant of q/dq/cat
204211
FileCheck().check_count(
205212
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
@@ -210,7 +217,7 @@ def forward(self, x, y):
210217
1,
211218
exactly=True,
212219
).run(
213-
m.exported_program.graph_module.code
220+
m.exported_program().graph_module.code
214221
)
215222

216223
m = m.to_executorch()
@@ -224,7 +231,7 @@ def forward(self, x, y):
224231
).check("torch.ops.aten.cat.out").check_count(
225232
"torch.ops.quantized_decomposed.dequantize_per_tensor.out", 1, exactly=True
226233
).run(
227-
m.dump_graph_module().code
234+
m.exported_program().graph_module.code
228235
)
229236

230237
def test_embedding_byte(self) -> None:
@@ -292,16 +299,18 @@ def forward(self, indices):
292299
_check_ir_validity=False,
293300
_use_edge_ops=True,
294301
)
295-
m = exir.capture(m, example_inputs).to_edge(config=compile_config)
302+
m = to_edge(export(m, example_inputs), compile_config=compile_config)
296303
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
297-
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
304+
m = m.transform(
305+
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
306+
)
298307
# check that we are using functional variant of q/dq/cat
299308
FileCheck().check(
300309
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
301310
).check(
302311
"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
303312
).run(
304-
m.exported_program.graph_module.code
313+
m.exported_program().graph_module.code
305314
)
306315

307316
# TODO: enable after the out variants of quantize_per_channel is supported
@@ -348,17 +357,18 @@ def forward(self, indices):
348357
_check_ir_validity=False,
349358
_use_edge_ops=True,
350359
)
351-
m = exir.capture(m, example_inputs).to_edge(config=compile_config)
360+
m = to_edge(export(m, example_inputs), compile_config=compile_config)
352361
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
353-
m = m.transform(QuantFusionPass(_fix_node_meta_val=True))
354-
m(*example_inputs)
362+
m = m.transform(
363+
[QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False
364+
)
355365
# check that we are using functional variant of q/dq/cat
356366
FileCheck().check(
357367
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
358368
).check(
359369
"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
360370
).run(
361-
m.exported_program.graph_module.code
371+
m.exported_program().graph_module.code
362372
)
363373

364374
# TODO: enable after the out variants of quantize_per_channel is supported

0 commit comments

Comments
 (0)