
Commit 8ad2094

JacobSzwejbka authored and facebook-github-bot committed
remove exir.capture from a bunch of tests (#2719)
Summary: exir.capture is deprecated. Use export instead.

Differential Revision: D55436633
1 parent 66c5fc8 commit 8ad2094
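
For reference, the migration applied throughout these tests looks roughly like the sketch below. This is an illustrative example built around a plain nn.Linear, not code taken from the diff; the APIs shown (torch.export.export, exir.to_edge, EdgeCompileConfig, exported_program()) are the ones the updated tests call.

import torch
from executorch import exir
from executorch.exir import to_edge
from torch.export import export

model = torch.nn.Linear(3, 3).eval()
inputs = (torch.randn(3, 3),)

# Deprecated pattern removed by this commit:
#   exir.capture(model, inputs, exir.CaptureConfig())
#       .to_edge()
#       .exported_program.graph_module

# Replacement pattern: torch.export.export produces an ExportedProgram,
# exir.to_edge wraps it, and exported_program() is now a method call.
graph_module = (
    to_edge(
        export(model, inputs),
        compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    )
    .exported_program()
    .graph_module
)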

File tree

6 files changed: +76 −431 lines changed

exir/backend/test/test_utils.py

Lines changed: 42 additions & 209 deletions
@@ -8,7 +8,7 @@
 
 import torch
 from executorch import exir
-from executorch.exir import CaptureConfig, to_edge
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.partitioner import Partitioner, PartitionResult
 from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
@@ -17,30 +17,13 @@
     get_non_lowered_nodes,
     is_identical_graph,
     print_delegated_graph,
-    remove_first_quant_and_last_dequant,
-    replace_quantized_partition_with_op,
 )
 
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
-from torch.ao.quantization import get_default_qconfig  # @manual
-from torch.ao.quantization.backend_config.executorch import (
-    get_executorch_backend_config,
-)
-from torch.ao.quantization.qconfig_mapping import (
-    _get_symmetric_qnnpack_qconfig_mapping,
-    QConfigMapping,
-)
-from torch.ao.quantization.quantize_fx import (
-    _convert_to_reference_decomposed_fx,
-    prepare_fx,
-)
-from torch.export import ExportedProgram
+from torch.export import export, ExportedProgram
 from torch.fx import symbolic_trace
-from torch.fx.passes.utils.fuser_utils import legalize_graph
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 from torch.library import Library
-from torch.testing import FileCheck
 
 T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
 T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
@@ -110,22 +93,24 @@ def forward(self, x, y):
                 return x + 1, y + 2
 
         graph_module_1: torch.fx.GraphModule = (
-            exir.capture(
-                MyModule1(),
-                (torch.rand(3, 4), torch.rand(3, 4)),
-                CaptureConfig(),
+            to_edge(
+                export(
+                    MyModule1(),
+                    (torch.rand(3, 4), torch.rand(3, 4)),
+                )
             )
-            .to_edge()
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
         graph_module_2: torch.fx.GraphModule = (
-            exir.capture(
-                MyModule2(),
-                (torch.rand(3, 4), torch.rand(3, 4)),
-                CaptureConfig(),
+            to_edge(
+                export(
+                    MyModule2(),
+                    (torch.rand(3, 4), torch.rand(3, 4)),
+                )
             )
-            .to_edge()
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
         is_matched = is_identical_graph(graph_module_1, graph_module_2)
         self.assertFalse(is_matched)
@@ -144,40 +129,25 @@ def forward(self, x):
 
         inputs = (torch.ones(3, 3),)
 
-        # Large model graph:
-        # opcode         name               target              args                                         kwargs
-        # -------------  -----------------  ------------------  -------------------------------------------  --------
-        # placeholder    ph_0               ph_0                ()                                           {}
-        # get_attr       _param_constant0   _param_constant0    ()                                           {}
-        # call_function  add_tensor         aten.add.Tensor     (ph_0, _param_constant0)                     {}
-        # get_attr       _param_constant1   _param_constant1    ()                                           {}
-        # get_attr       _tensor_constant0  _tensor_constant0   ()                                           {}
-        # call_function  addmm_default      aten.addmm.default  (_param_constant1, ph_0, _tensor_constant0)  {}
-        # output         output             output              ([add_tensor, addmm_default],)               {}
-
         large_model = (
-            exir.capture(
-                LargeModel(),
-                inputs,
-                CaptureConfig(),
+            to_edge(
+                export(
+                    LargeModel(),
+                    inputs,
+                ),
+                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
             )
-            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
 
-        # Pattern graph:
-        # opcode         name               target              args                                         kwargs
-        # -------------  -----------------  ------------------  -------------------------------------------  --------
-        # placeholder    ph_0               ph_0                ()                                           {}
-        # get_attr       _param_constant0   _param_constant0    ()                                           {}
-        # get_attr       _tensor_constant0  _tensor_constant0   ()                                           {}
-        # call_function  addmm_default      aten.addmm.default  (_param_constant0, ph_0, _tensor_constant0)  {}
-        # output         output             output              ([addmm_default],)                           {}
-
         pattern = (
-            exir.capture(torch.nn.Linear(3, 3), inputs, CaptureConfig())
-            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
-            .exported_program.graph_module.graph
+            to_edge(
+                export(torch.nn.Linear(3, 3), inputs),
+                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+            )
+            .exported_program()
+            .graph_module.graph
         )
 
         subgraph_matcher = SubgraphMatcher(pattern)
@@ -186,65 +156,6 @@ def forward(self, x):
         # Should find exact one match
         self.assertEqual(len(match_result), 1)
 
-    def test_remove_first_quant_and_last_dequant(self):
-        qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
-        linear = torch.nn.Linear(3, 4).eval()
-
-        example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
-        prepared_linear = prepare_fx(
-            linear,
-            qconfig_mapping,
-            example_inputs,
-            backend_config=get_executorch_backend_config(),
-        )
-
-        converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
-            prepared_linear,
-        )
-
-        actual_static_quant_linear = (
-            exir.capture(
-                converted_linear,
-                example_inputs,
-                CaptureConfig(
-                    enable_functionalization=False,
-                ),
-            )
-            .to_edge(
-                exir.EdgeCompileConfig(
-                    _check_ir_validity=False,
-                )
-            )
-            .exported_program.graph_module
-        )
-
-        # Original graph has exactly 3 dequantize ops and 3 quantize ops
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-            3,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-            3,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-
-        # Remove first and last dequant in static quant
-        remove_first_quant_and_last_dequant(actual_static_quant_linear)
-
-        # Original graph has exactly 2 dequantize ops and 2 quantize ops
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-            2,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-            2,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-
     def test_invalid_partitioner_without_partitioner(self):
         """
         Tests replacing literals with placeholders in the case there are
@@ -267,13 +178,10 @@ def partition(
                     tagged_exported_program=edge_exported_program, partition_tags=None
                 )
 
-        exported_program = exir.capture(
-            torch.nn.Linear(3, 3),
-            (torch.randn(3, 3),),
-            CaptureConfig(),
-        ).to_edge(
-            exir.EdgeCompileConfig(
-                _check_ir_validity=False,
+        exported_program = to_edge(
+            export(
+                torch.nn.Linear(3, 3),
+                (torch.randn(3, 3),),
             )
         )
 
@@ -282,7 +190,7 @@ def partition(
             AssertionError,
             error_msg,
         ):
-            _ = to_backend(exported_program.exported_program, InvalidPartitioner())
+            _ = to_backend(exported_program.exported_program(), InvalidPartitioner())
 
     test_lib = Library("test_lib", "DEF")
 
@@ -293,74 +201,6 @@ def partition(
     def q_linear(x, weight, bias):
         return x
 
-    def test_replace_quantized_partition_with_op(self):
-        class LinearModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(3, 4)
-
-            def forward(self, input):
-                return self.linear(input)
-
-        linear_model = LinearModel()
-        example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
-        prepared_linear = prepare_fx(
-            linear_model,
-            QConfigMapping().set_object_type(
-                torch.nn.Linear,
-                get_default_qconfig("qnnpack"),
-            ),
-            example_inputs,
-            backend_config=get_executorch_backend_config(),
-        )
-
-        converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
-            prepared_linear,
-        )
-
-        actual_static_quant_linear = (
-            exir.capture(
-                converted_linear,
-                example_inputs,
-                CaptureConfig(
-                    enable_functionalization=False,
-                ),
-            )
-            .to_edge(
-                exir.EdgeCompileConfig(
-                    _check_ir_validity=False,
-                ),
-            )
-            .exported_program.graph_module
-        )
-
-        source_partitions_by_module = get_source_partitions(
-            actual_static_quant_linear.graph,
-            [torch.ao.nn.quantized.reference.modules.linear.Linear],
-        )
-
-        replace_quantized_partition_with_op(
-            actual_static_quant_linear,
-            list(source_partitions_by_module.values())[0][0],
-            torch.ops.test_lib.test_q_linear,
-        )
-
-        legalize_graph(actual_static_quant_linear)
-
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-            1,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-            1,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
-            actual_static_quant_linear.code
-        )
-
     def test_get_non_lowered_nodes(self):
         class Model(torch.nn.Module):
             def __init__(self):
@@ -376,12 +216,9 @@ def forward(self, a, x, b):
 
         m = Model()
         inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-        edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
-        edge.exported_program = to_backend(
-            edge.exported_program, AddMulPartitionerDemo()
-        )
-        edge.dump()
-        number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
+        edge = to_edge(export(m, inputs))
+        edge = edge.to_backend(AddMulPartitionerDemo())
+        number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
         # Only sub is not not lowerable
         self.assertEqual(len(number_of_cpu_nodes), 1)
 
@@ -400,11 +237,9 @@ def forward(self, a, x, b):
 
         m = Model()
         inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-        edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
-        edge.exported_program = to_backend(
-            edge.exported_program, AddMulPartitionerDemo()
-        )
-        number_of_delegates = get_delegates(edge.exported_program.graph)
+        edge = to_edge(export(m, inputs))
+        edge = edge.to_backend(AddMulPartitionerDemo())
+        number_of_delegates = get_delegates(edge.exported_program().graph)
         # there will be 2 delegates: (mm + add) -> sub -> (mm + add)
         self.assertEqual(len(number_of_delegates), 2)
 
@@ -424,9 +259,7 @@ def forward(self, a, x, b):
         m = Model()
         inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
 
-        edge = to_edge(torch.export.export(m, inputs)).to_backend(
-            AddMulPartitionerDemo()
-        )
+        edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())
 
         graph_str = print_delegated_graph(edge.exported_program().graph_module)
         self.assertIn(
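
The delegation tests above also move from assigning edge.exported_program via to_backend(...) to the EdgeProgramManager.to_backend method. Below is a minimal sketch of that new flow, assuming a toy module whose mm/add nodes the AddMulPartitionerDemo partitioner claims for delegation; the Model here is illustrative, not the test's exact module.

import torch
from executorch.exir import to_edge
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from torch.export import export

class Model(torch.nn.Module):
    def forward(self, a, x, b):
        # mm and add should be claimed by the demo partitioner; sub stays on CPU.
        return torch.sub(torch.add(torch.mm(a, x), b), b)

m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

# to_backend returns a new EdgeProgramManager with delegate call nodes embedded.
edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())
print(edge.exported_program().graph_module.graph)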

exir/tests/TARGETS

Lines changed: 1 addition & 20 deletions
@@ -162,35 +162,16 @@ python_unittest(
     supports_static_listing = False,
     deps = [
         "fbsource//third-party/pypi/parameterized:parameterized",
-        ":asr_joiner",
         "//caffe2:torch",
-        "//executorch/backends/fb/qnnpack/partition:qnnpack_partitioner",
         "//executorch/exir:lib",
         "//executorch/exir:memory_planning",
         "//executorch/exir:pass_base",
         "//executorch/exir:pass_manager",
-        "//executorch/exir:print_program",
-        "//executorch/exir:schema",
-        "//executorch/exir/backend:backend_api",
         "//executorch/exir/passes:lib",
         "//executorch/exir/passes:sym_shape_eval_pass",
     ],
 )
 
-python_unittest(
-    name = "experimental",
-    srcs = [
-        "test_experimental.py",
-    ],
-    deps = [
-        "//caffe2:torch",
-        "//executorch/exir:error",
-        "//executorch/exir:lib",
-        "//executorch/exir/experimental:export_pt2",
-        "//executorch/exir/experimental:lib",
-    ],
-)
-
 python_unittest(
     name = "passes",
     srcs = [
@@ -251,7 +232,6 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir:pass_base",
-        "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:backend_details",
        "//executorch/exir/backend:compile_spec_schema",
         "//executorch/exir/backend:partitioner",
@@ -430,6 +410,7 @@ python_unittest(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/exir:dim_order_utils",
         "//executorch/exir:lib",
     ],
 )
