Skip to content

Commit eedd262

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from a bunch of tests (#2719)
Summary: exir.capture is deprecated. use export instead Differential Revision: D55436633
1 parent 79a4ba0 commit eedd262

File tree

6 files changed

+73
-404
lines changed

6 files changed

+73
-404
lines changed

exir/backend/test/test_utils.py

Lines changed: 41 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
_convert_to_reference_decomposed_fx,
3535
prepare_fx,
3636
)
37-
from torch.export import ExportedProgram
37+
from torch.export import export, ExportedProgram
3838
from torch.fx import symbolic_trace
3939
from torch.fx.passes.utils.fuser_utils import legalize_graph
4040
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -110,22 +110,24 @@ def forward(self, x, y):
110110
return x + 1, y + 2
111111

112112
graph_module_1: torch.fx.GraphModule = (
113-
exir.capture(
114-
MyModule1(),
115-
(torch.rand(3, 4), torch.rand(3, 4)),
116-
CaptureConfig(),
113+
to_edge(
114+
export(
115+
MyModule1(),
116+
(torch.rand(3, 4), torch.rand(3, 4)),
117+
)
117118
)
118-
.to_edge()
119-
.exported_program.graph_module
119+
.exported_program()
120+
.graph_module
120121
)
121122
graph_module_2: torch.fx.GraphModule = (
122-
exir.capture(
123-
MyModule2(),
124-
(torch.rand(3, 4), torch.rand(3, 4)),
125-
CaptureConfig(),
123+
to_edge(
124+
export(
125+
MyModule2(),
126+
(torch.rand(3, 4), torch.rand(3, 4)),
127+
)
126128
)
127-
.to_edge()
128-
.exported_program.graph_module
129+
.exported_program()
130+
.graph_module
129131
)
130132
is_matched = is_identical_graph(graph_module_1, graph_module_2)
131133
self.assertFalse(is_matched)
@@ -144,40 +146,25 @@ def forward(self, x):
144146

145147
inputs = (torch.ones(3, 3),)
146148

147-
# Large model graph:
148-
# opcode name target args kwargs
149-
# ------------- ----------------- ------------------ ------------------------------------------- --------
150-
# placeholder ph_0 ph_0 () {}
151-
# get_attr _param_constant0 _param_constant0 () {}
152-
# call_function add_tensor aten.add.Tensor (ph_0, _param_constant0) {}
153-
# get_attr _param_constant1 _param_constant1 () {}
154-
# get_attr _tensor_constant0 _tensor_constant0 () {}
155-
# call_function addmm_default aten.addmm.default (_param_constant1, ph_0, _tensor_constant0) {}
156-
# output output output ([add_tensor, addmm_default],) {}
157-
158149
large_model = (
159-
exir.capture(
160-
LargeModel(),
161-
inputs,
162-
CaptureConfig(),
150+
to_edge(
151+
export(
152+
LargeModel(),
153+
inputs,
154+
),
155+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
163156
)
164-
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
165-
.exported_program.graph_module
157+
.exported_program()
158+
.graph_module
166159
)
167160

168-
# Pattern graph:
169-
# opcode name target args kwargs
170-
# ------------- ----------------- ------------------ ------------------------------------------- --------
171-
# placeholder ph_0 ph_0 () {}
172-
# get_attr _param_constant0 _param_constant0 () {}
173-
# get_attr _tensor_constant0 _tensor_constant0 () {}
174-
# call_function addmm_default aten.addmm.default (_param_constant0, ph_0, _tensor_constant0) {}
175-
# output output output ([addmm_default],) {}
176-
177161
pattern = (
178-
exir.capture(torch.nn.Linear(3, 3), inputs, CaptureConfig())
179-
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
180-
.exported_program.graph_module.graph
162+
to_edge(
163+
export(torch.nn.Linear(3, 3), inputs),
164+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
165+
)
166+
.exported_program()
167+
.graph_module.graph
181168
)
182169

183170
subgraph_matcher = SubgraphMatcher(pattern)
@@ -186,65 +173,6 @@ def forward(self, x):
186173
# Should find exact one match
187174
self.assertEqual(len(match_result), 1)
188175

189-
def test_remove_first_quant_and_last_dequant(self):
190-
qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
191-
linear = torch.nn.Linear(3, 4).eval()
192-
193-
example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
194-
prepared_linear = prepare_fx(
195-
linear,
196-
qconfig_mapping,
197-
example_inputs,
198-
backend_config=get_executorch_backend_config(),
199-
)
200-
201-
converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
202-
prepared_linear,
203-
)
204-
205-
actual_static_quant_linear = (
206-
exir.capture(
207-
converted_linear,
208-
example_inputs,
209-
CaptureConfig(
210-
enable_functionalization=False,
211-
),
212-
)
213-
.to_edge(
214-
exir.EdgeCompileConfig(
215-
_check_ir_validity=False,
216-
)
217-
)
218-
.exported_program.graph_module
219-
)
220-
221-
# Original graph has exactly 3 dequantize ops and 3 quantize ops
222-
FileCheck().check_count(
223-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
224-
3,
225-
exactly=True,
226-
).run(actual_static_quant_linear.code)
227-
FileCheck().check_count(
228-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
229-
3,
230-
exactly=True,
231-
).run(actual_static_quant_linear.code)
232-
233-
# Remove first and last dequant in static quant
234-
remove_first_quant_and_last_dequant(actual_static_quant_linear)
235-
236-
# Original graph has exactly 2 dequantize ops and 2 quantize ops
237-
FileCheck().check_count(
238-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
239-
2,
240-
exactly=True,
241-
).run(actual_static_quant_linear.code)
242-
FileCheck().check_count(
243-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
244-
2,
245-
exactly=True,
246-
).run(actual_static_quant_linear.code)
247-
248176
def test_invalid_partitioner_without_partitioner(self):
249177
"""
250178
Tests replacing literals with placeholders in the case there are
@@ -267,13 +195,10 @@ def partition(
267195
tagged_exported_program=edge_exported_program, partition_tags=None
268196
)
269197

270-
exported_program = exir.capture(
271-
torch.nn.Linear(3, 3),
272-
(torch.randn(3, 3),),
273-
CaptureConfig(),
274-
).to_edge(
275-
exir.EdgeCompileConfig(
276-
_check_ir_validity=False,
198+
exported_program = to_edge(
199+
export(
200+
torch.nn.Linear(3, 3),
201+
(torch.randn(3, 3),),
277202
)
278203
)
279204

@@ -282,7 +207,7 @@ def partition(
282207
AssertionError,
283208
error_msg,
284209
):
285-
_ = to_backend(exported_program.exported_program, InvalidPartitioner())
210+
_ = to_backend(exported_program.exported_program(), InvalidPartitioner())
286211

287212
test_lib = Library("test_lib", "DEF")
288213

@@ -293,74 +218,6 @@ def partition(
293218
def q_linear(x, weight, bias):
294219
return x
295220

296-
def test_replace_quantized_partition_with_op(self):
297-
class LinearModel(torch.nn.Module):
298-
def __init__(self):
299-
super().__init__()
300-
self.linear = torch.nn.Linear(3, 4)
301-
302-
def forward(self, input):
303-
return self.linear(input)
304-
305-
linear_model = LinearModel()
306-
example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
307-
prepared_linear = prepare_fx(
308-
linear_model,
309-
QConfigMapping().set_object_type(
310-
torch.nn.Linear,
311-
get_default_qconfig("qnnpack"),
312-
),
313-
example_inputs,
314-
backend_config=get_executorch_backend_config(),
315-
)
316-
317-
converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
318-
prepared_linear,
319-
)
320-
321-
actual_static_quant_linear = (
322-
exir.capture(
323-
converted_linear,
324-
example_inputs,
325-
CaptureConfig(
326-
enable_functionalization=False,
327-
),
328-
)
329-
.to_edge(
330-
exir.EdgeCompileConfig(
331-
_check_ir_validity=False,
332-
),
333-
)
334-
.exported_program.graph_module
335-
)
336-
337-
source_partitions_by_module = get_source_partitions(
338-
actual_static_quant_linear.graph,
339-
[torch.ao.nn.quantized.reference.modules.linear.Linear],
340-
)
341-
342-
replace_quantized_partition_with_op(
343-
actual_static_quant_linear,
344-
list(source_partitions_by_module.values())[0][0],
345-
torch.ops.test_lib.test_q_linear,
346-
)
347-
348-
legalize_graph(actual_static_quant_linear)
349-
350-
FileCheck().check_count(
351-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
352-
1,
353-
exactly=True,
354-
).run(actual_static_quant_linear.code)
355-
FileCheck().check_count(
356-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
357-
1,
358-
exactly=True,
359-
).run(actual_static_quant_linear.code)
360-
FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
361-
actual_static_quant_linear.code
362-
)
363-
364221
def test_get_non_lowered_nodes(self):
365222
class Model(torch.nn.Module):
366223
def __init__(self):
@@ -376,12 +233,9 @@ def forward(self, a, x, b):
376233

377234
m = Model()
378235
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
379-
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
380-
edge.exported_program = to_backend(
381-
edge.exported_program, AddMulPartitionerDemo()
382-
)
383-
edge.dump()
384-
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
236+
edge = to_edge(export(m, inputs))
237+
edge = edge.to_backend(AddMulPartitionerDemo())
238+
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
385239
# Only sub is not not lowerable
386240
self.assertEqual(len(number_of_cpu_nodes), 1)
387241

@@ -400,11 +254,9 @@ def forward(self, a, x, b):
400254

401255
m = Model()
402256
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
403-
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
404-
edge.exported_program = to_backend(
405-
edge.exported_program, AddMulPartitionerDemo()
406-
)
407-
number_of_delegates = get_delegates(edge.exported_program.graph)
257+
edge = to_edge(export(m, inputs))
258+
edge = edge.to_backend(AddMulPartitionerDemo())
259+
number_of_delegates = get_delegates(edge.exported_program().graph)
408260
# there will be 2 delegates: (mm + add) -> sub -> (mm + add)
409261
self.assertEqual(len(number_of_delegates), 2)
410262

@@ -424,9 +276,7 @@ def forward(self, a, x, b):
424276
m = Model()
425277
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
426278

427-
edge = to_edge(torch.export.export(m, inputs)).to_backend(
428-
AddMulPartitionerDemo()
429-
)
279+
edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())
430280

431281
graph_str = print_delegated_graph(edge.exported_program().graph_module)
432282
self.assertIn(

exir/tests/TARGETS

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,6 @@ python_unittest(
177177
],
178178
)
179179

180-
python_unittest(
181-
name = "experimental",
182-
srcs = [
183-
"test_experimental.py",
184-
],
185-
deps = [
186-
"//caffe2:torch",
187-
"//executorch/exir:error",
188-
"//executorch/exir:lib",
189-
"//executorch/exir/experimental:export_pt2",
190-
"//executorch/exir/experimental:lib",
191-
],
192-
)
193-
194180
python_unittest(
195181
name = "passes",
196182
srcs = [

exir/tests/common.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -86,37 +86,6 @@ def get_test_program() -> Program:
8686
)
8787

8888

89-
# pyre-ignore
90-
def get_graph_module_with_op(op: Callable, args: Any) -> torch.fx.GraphModule:
91-
"""
92-
Constructs an torch.fx.GraphModule containing just a call to the given op.
93-
94-
Args:
95-
op: A callable op
96-
args: Sample arguments to this given op
97-
98-
Returns:
99-
torch.fx.GraphModule with a graph like: inputs -> op -> output
100-
"""
101-
102-
trace_args, in_spec = pytree.tree_flatten(args)
103-
104-
graph = torch.fx.Graph()
105-
with graph.inserting_before(graph._root):
106-
input_nodes = []
107-
for i in range(len(trace_args)):
108-
input_nodes.append(graph.placeholder(f"ph_{i}"))
109-
110-
op_node = graph.call_function(op, tuple(input_nodes))
111-
graph.output(op_node)
112-
113-
graph_module = torch.fx.GraphModule(torch.nn.Module(), graph)
114-
graph_module.recompile()
115-
116-
graph_module = exir.capture(graph_module, args).to_edge().module
117-
return graph_module
118-
119-
12089
def register_additional_test_aten_ops() -> None:
12190
# TODO: either mark those ops as canonical in native_functions.yaml,
12291
# or stop using graphs with those in tests.

0 commit comments

Comments
 (0)