Skip to content

Commit 447b248

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from a bunch of tests (#2719)
Summary: exir.capture is deprecated. use export instead Reviewed By: zhxchen17 Differential Revision: D55436633
1 parent 65f3e18 commit 447b248

File tree

7 files changed

+77
-432
lines changed

7 files changed

+77
-432
lines changed

exir/backend/test/test_utils.py

Lines changed: 42 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
from executorch import exir
13-
from executorch.exir import CaptureConfig, to_edge
13+
from executorch.exir import to_edge
1414
from executorch.exir.backend.backend_api import to_backend
1515
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
1616
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
@@ -21,31 +21,14 @@
2121
get_non_lowered_nodes,
2222
is_identical_graph,
2323
print_delegated_graph,
24-
remove_first_quant_and_last_dequant,
25-
replace_quantized_partition_with_op,
2624
)
2725

2826
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
2927
from pandas.testing import assert_frame_equal
30-
from torch.ao.quantization import get_default_qconfig # @manual
31-
from torch.ao.quantization.backend_config.executorch import (
32-
get_executorch_backend_config,
33-
)
34-
from torch.ao.quantization.qconfig_mapping import (
35-
_get_symmetric_qnnpack_qconfig_mapping,
36-
QConfigMapping,
37-
)
38-
from torch.ao.quantization.quantize_fx import (
39-
_convert_to_reference_decomposed_fx,
40-
prepare_fx,
41-
)
42-
from torch.export import ExportedProgram
28+
from torch.export import export, ExportedProgram
4329
from torch.fx import symbolic_trace
44-
from torch.fx.passes.utils.fuser_utils import legalize_graph
4530
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
46-
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
4731
from torch.library import Library
48-
from torch.testing import FileCheck
4932

5033
T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
5134
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
@@ -115,22 +98,24 @@ def forward(self, x, y):
11598
return x + 1, y + 2
11699

117100
graph_module_1: torch.fx.GraphModule = (
118-
exir.capture(
119-
MyModule1(),
120-
(torch.rand(3, 4), torch.rand(3, 4)),
121-
CaptureConfig(),
101+
to_edge(
102+
export(
103+
MyModule1(),
104+
(torch.rand(3, 4), torch.rand(3, 4)),
105+
)
122106
)
123-
.to_edge()
124-
.exported_program.graph_module
107+
.exported_program()
108+
.graph_module
125109
)
126110
graph_module_2: torch.fx.GraphModule = (
127-
exir.capture(
128-
MyModule2(),
129-
(torch.rand(3, 4), torch.rand(3, 4)),
130-
CaptureConfig(),
111+
to_edge(
112+
export(
113+
MyModule2(),
114+
(torch.rand(3, 4), torch.rand(3, 4)),
115+
)
131116
)
132-
.to_edge()
133-
.exported_program.graph_module
117+
.exported_program()
118+
.graph_module
134119
)
135120
is_matched = is_identical_graph(graph_module_1, graph_module_2)
136121
self.assertFalse(is_matched)
@@ -149,40 +134,25 @@ def forward(self, x):
149134

150135
inputs = (torch.ones(3, 3),)
151136

152-
# Large model graph:
153-
# opcode name target args kwargs
154-
# ------------- ----------------- ------------------ ------------------------------------------- --------
155-
# placeholder ph_0 ph_0 () {}
156-
# get_attr _param_constant0 _param_constant0 () {}
157-
# call_function add_tensor aten.add.Tensor (ph_0, _param_constant0) {}
158-
# get_attr _param_constant1 _param_constant1 () {}
159-
# get_attr _tensor_constant0 _tensor_constant0 () {}
160-
# call_function addmm_default aten.addmm.default (_param_constant1, ph_0, _tensor_constant0) {}
161-
# output output output ([add_tensor, addmm_default],) {}
162-
163137
large_model = (
164-
exir.capture(
165-
LargeModel(),
166-
inputs,
167-
CaptureConfig(),
138+
to_edge(
139+
export(
140+
LargeModel(),
141+
inputs,
142+
),
143+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
168144
)
169-
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
170-
.exported_program.graph_module
145+
.exported_program()
146+
.graph_module
171147
)
172148

173-
# Pattern graph:
174-
# opcode name target args kwargs
175-
# ------------- ----------------- ------------------ ------------------------------------------- --------
176-
# placeholder ph_0 ph_0 () {}
177-
# get_attr _param_constant0 _param_constant0 () {}
178-
# get_attr _tensor_constant0 _tensor_constant0 () {}
179-
# call_function addmm_default aten.addmm.default (_param_constant0, ph_0, _tensor_constant0) {}
180-
# output output output ([addmm_default],) {}
181-
182149
pattern = (
183-
exir.capture(torch.nn.Linear(3, 3), inputs, CaptureConfig())
184-
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
185-
.exported_program.graph_module.graph
150+
to_edge(
151+
export(torch.nn.Linear(3, 3), inputs),
152+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
153+
)
154+
.exported_program()
155+
.graph_module.graph
186156
)
187157

188158
subgraph_matcher = SubgraphMatcher(pattern)
@@ -191,65 +161,6 @@ def forward(self, x):
191161
# Should find exact one match
192162
self.assertEqual(len(match_result), 1)
193163

194-
def test_remove_first_quant_and_last_dequant(self):
195-
qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
196-
linear = torch.nn.Linear(3, 4).eval()
197-
198-
example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
199-
prepared_linear = prepare_fx(
200-
linear,
201-
qconfig_mapping,
202-
example_inputs,
203-
backend_config=get_executorch_backend_config(),
204-
)
205-
206-
converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
207-
prepared_linear,
208-
)
209-
210-
actual_static_quant_linear = (
211-
exir.capture(
212-
converted_linear,
213-
example_inputs,
214-
CaptureConfig(
215-
enable_functionalization=False,
216-
),
217-
)
218-
.to_edge(
219-
exir.EdgeCompileConfig(
220-
_check_ir_validity=False,
221-
)
222-
)
223-
.exported_program.graph_module
224-
)
225-
226-
# Original graph has exactly 3 dequantize ops and 3 quantize ops
227-
FileCheck().check_count(
228-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
229-
3,
230-
exactly=True,
231-
).run(actual_static_quant_linear.code)
232-
FileCheck().check_count(
233-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
234-
3,
235-
exactly=True,
236-
).run(actual_static_quant_linear.code)
237-
238-
# Remove first and last dequant in static quant
239-
remove_first_quant_and_last_dequant(actual_static_quant_linear)
240-
241-
# Original graph has exactly 2 dequantize ops and 2 quantize ops
242-
FileCheck().check_count(
243-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
244-
2,
245-
exactly=True,
246-
).run(actual_static_quant_linear.code)
247-
FileCheck().check_count(
248-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
249-
2,
250-
exactly=True,
251-
).run(actual_static_quant_linear.code)
252-
253164
def test_invalid_partitioner_without_partitioner(self):
254165
"""
255166
Tests replacing literals with placeholders in the case there are
@@ -272,13 +183,10 @@ def partition(
272183
tagged_exported_program=edge_exported_program, partition_tags=None
273184
)
274185

275-
exported_program = exir.capture(
276-
torch.nn.Linear(3, 3),
277-
(torch.randn(3, 3),),
278-
CaptureConfig(),
279-
).to_edge(
280-
exir.EdgeCompileConfig(
281-
_check_ir_validity=False,
186+
exported_program = to_edge(
187+
export(
188+
torch.nn.Linear(3, 3),
189+
(torch.randn(3, 3),),
282190
)
283191
)
284192

@@ -287,7 +195,7 @@ def partition(
287195
AssertionError,
288196
error_msg,
289197
):
290-
_ = to_backend(exported_program.exported_program, InvalidPartitioner())
198+
_ = to_backend(exported_program.exported_program(), InvalidPartitioner())
291199

292200
test_lib = Library("test_lib", "DEF")
293201

@@ -298,74 +206,6 @@ def partition(
298206
def q_linear(x, weight, bias):
299207
return x
300208

301-
def test_replace_quantized_partition_with_op(self):
302-
class LinearModel(torch.nn.Module):
303-
def __init__(self):
304-
super().__init__()
305-
self.linear = torch.nn.Linear(3, 4)
306-
307-
def forward(self, input):
308-
return self.linear(input)
309-
310-
linear_model = LinearModel()
311-
example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
312-
prepared_linear = prepare_fx(
313-
linear_model,
314-
QConfigMapping().set_object_type(
315-
torch.nn.Linear,
316-
get_default_qconfig("qnnpack"),
317-
),
318-
example_inputs,
319-
backend_config=get_executorch_backend_config(),
320-
)
321-
322-
converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
323-
prepared_linear,
324-
)
325-
326-
actual_static_quant_linear = (
327-
exir.capture(
328-
converted_linear,
329-
example_inputs,
330-
CaptureConfig(
331-
enable_functionalization=False,
332-
),
333-
)
334-
.to_edge(
335-
exir.EdgeCompileConfig(
336-
_check_ir_validity=False,
337-
),
338-
)
339-
.exported_program.graph_module
340-
)
341-
342-
source_partitions_by_module = get_source_partitions(
343-
actual_static_quant_linear.graph,
344-
[torch.ao.nn.quantized.reference.modules.linear.Linear],
345-
)
346-
347-
replace_quantized_partition_with_op(
348-
actual_static_quant_linear,
349-
list(source_partitions_by_module.values())[0][0],
350-
torch.ops.test_lib.test_q_linear,
351-
)
352-
353-
legalize_graph(actual_static_quant_linear)
354-
355-
FileCheck().check_count(
356-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
357-
1,
358-
exactly=True,
359-
).run(actual_static_quant_linear.code)
360-
FileCheck().check_count(
361-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
362-
1,
363-
exactly=True,
364-
).run(actual_static_quant_linear.code)
365-
FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
366-
actual_static_quant_linear.code
367-
)
368-
369209
def test_get_non_lowered_nodes(self):
370210
class Model(torch.nn.Module):
371211
def __init__(self):
@@ -381,12 +221,9 @@ def forward(self, a, x, b):
381221

382222
m = Model()
383223
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
384-
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
385-
edge.exported_program = to_backend(
386-
edge.exported_program, AddMulPartitionerDemo()
387-
)
388-
edge.dump()
389-
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
224+
edge = to_edge(export(m, inputs))
225+
edge = edge.to_backend(AddMulPartitionerDemo())
226+
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
390227
# Only sub is not not lowerable
391228
self.assertEqual(len(number_of_cpu_nodes), 1)
392229

@@ -405,11 +242,9 @@ def forward(self, a, x, b):
405242

406243
m = Model()
407244
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
408-
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
409-
edge.exported_program = to_backend(
410-
edge.exported_program, AddMulPartitionerDemo()
411-
)
412-
number_of_delegates = get_delegates(edge.exported_program.graph)
245+
edge = to_edge(export(m, inputs))
246+
edge = edge.to_backend(AddMulPartitionerDemo())
247+
number_of_delegates = get_delegates(edge.exported_program().graph)
413248
# there will be 2 delegates: (mm + add) -> sub -> (mm + add)
414249
self.assertEqual(len(number_of_delegates), 2)
415250

@@ -429,9 +264,7 @@ def forward(self, a, x, b):
429264
m = Model()
430265
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
431266

432-
edge = to_edge(torch.export.export(m, inputs)).to_backend(
433-
AddMulPartitionerDemo()
434-
)
267+
edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())
435268

436269
graph_str = print_delegated_graph(edge.exported_program().graph_module)
437270
self.assertIn(

exir/tests/TARGETS

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
33
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
44
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
55

6+
oncall("executorch")
7+
68
python_library(
79
name = "lib",
810
srcs = [
911
"common.py",
1012
],
1113
deps = [
1214
"//caffe2:torch",
13-
"//executorch/exir:lib",
1415
"//executorch/exir:schema",
1516
],
1617
)
@@ -162,35 +163,16 @@ python_unittest(
162163
supports_static_listing = False,
163164
deps = [
164165
"fbsource//third-party/pypi/parameterized:parameterized",
165-
":asr_joiner",
166166
"//caffe2:torch",
167-
"//executorch/backends/fb/qnnpack/partition:qnnpack_partitioner",
168167
"//executorch/exir:lib",
169168
"//executorch/exir:memory_planning",
170169
"//executorch/exir:pass_base",
171170
"//executorch/exir:pass_manager",
172-
"//executorch/exir:print_program",
173-
"//executorch/exir:schema",
174-
"//executorch/exir/backend:backend_api",
175171
"//executorch/exir/passes:lib",
176172
"//executorch/exir/passes:sym_shape_eval_pass",
177173
],
178174
)
179175

180-
python_unittest(
181-
name = "experimental",
182-
srcs = [
183-
"test_experimental.py",
184-
],
185-
deps = [
186-
"//caffe2:torch",
187-
"//executorch/exir:error",
188-
"//executorch/exir:lib",
189-
"//executorch/exir/experimental:export_pt2",
190-
"//executorch/exir/experimental:lib",
191-
],
192-
)
193-
194176
python_unittest(
195177
name = "passes",
196178
srcs = [

0 commit comments

Comments
 (0)