
Commit f354a96

mcr229 authored and facebook-github-bot committed
new backend tests for lifted graphs
Summary:
Copying all the tests from test_backends.py to use lifted graphs instead. First step in migrating completely over to torch.export.

Reviewed By: cccclai

Differential Revision: D47887945

fbshipit-source-id: 539771ab04389f4f605a8c39ad892e6ba9673369
1 parent fb602f8 commit f354a96
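For context when reading the diffs below: "lifted" refers to the capture mode where exir.capture is configured with enable_aot=True, so parameters and buffers are lifted into explicit graph inputs, in line with torch.export. The snippet below is only an illustrative sketch of the two modes; the executorch.exir import path and the toy AddOne module are assumptions, not code from this commit.

# Illustrative sketch, not part of this commit.
import torch
import executorch.exir as exir  # import path assumed, not shown in this diff


class AddOne(torch.nn.Module):  # hypothetical toy module for illustration
    def forward(self, x):
        return x + 1


inputs = (torch.ones(2, 2),)

# Unlifted capture: the mode the existing test_backends.py keeps using.
unlifted = exir.capture(AddOne(), inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()

# Lifted capture: enable_aot=True produces the lifted form, where parameters
# and buffers appear as explicit graph inputs. This is the mode the new
# test_backends_lifted.py targets.
lifted = exir.capture(
    AddOne(), inputs, exir.CaptureConfig(pt2_mode=True, enable_aot=True)
).to_edge()

print(lifted.exported_program.graph_module.graph)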

5 files changed: +1499 −46 lines


backends/test/TARGETS

Lines changed: 31 additions & 0 deletions
@@ -112,6 +112,37 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_backends_lifted",
+    srcs = [
+        "test_backends_lifted.py",
+    ],
+    supports_static_listing = True,
+    deps = [
+        ":backend_with_compiler_demo",
+        ":hta_partitioner_demo",
+        ":op_partitioner_demo",
+        ":qnn_backend_demo",
+        "//caffe2:torch",
+        "//caffe2/functorch:functorch_src",
+        "//executorch/backends:backend_api",
+        "//executorch/backends:compile_spec_schema",
+        "//executorch/backends:partitioner",
+        "//executorch/exir:delegate",
+        "//executorch/exir:graph_module",
+        "//executorch/exir:lib",
+        "//executorch/exir:lowered_backend_module",
+        "//executorch/exir:print_program",
+        "//executorch/exir:schema",
+        "//executorch/exir/dialects:lib",
+        "//executorch/extension/pybindings:portable",  # @manual
+        "//executorch/extension/pytree:pylib",
+        "//executorch/kernels/portable:custom_ops_generated_lib",
+        "//executorch/kernels/quantized:custom_ops_generated_lib",
+        "//executorch/runtime/executor/test:test_backend_compiler_lib",
+    ],
+)
+
 python_unittest(
     name = "test_graph_partition",
     srcs = [
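Besides the Buck target above, a test file laid out like this can also be driven by the stock unittest runner. The invocation below is a hypothetical sketch that assumes test_backends_lifted.py sits next to test_backends.py under backends/test/; it is not part of this commit.

# Hypothetical direct run of the new test module via unittest discovery.
import unittest

if __name__ == "__main__":
    suite = unittest.defaultTestLoader.discover(
        "backends/test", pattern="test_backends_lifted.py"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)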

backends/test/hta_partitioner_demo.py

Lines changed: 48 additions & 23 deletions
@@ -58,19 +58,37 @@ def forward(self, x_raw, h, c):
         input_h = torch.ones([1, 32])
         input_c = torch.ones([1, 32])
 
+        pattern_lstm_conv_lifted = (
+            exir.capture(
+                LSTMConvPattern(),
+                (input_x, input_h, input_c),
+                exir.CaptureConfig(pt2_mode=True, enable_aot=True),
+            )
+            .to_edge()
+            .exported_program.graph_module
+        )
         pattern_lstm_conv = (
             exir.capture(
                 LSTMConvPattern(),
                 (input_x, input_h, input_c),
                 exir.CaptureConfig(pt2_mode=True),
             )
-            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+            .to_edge()
             .exported_program.graph_module
         )
 
         def sub(x, y):
             return torch.sub(x, y)
 
+        pattern_sub_lifted = (
+            exir.capture(
+                sub,
+                (input_x, input_h),
+                exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False),
+            )
+            .to_edge(exir.EdgeCompileConfig(_use_edge_ops=True))
+            .exported_program.graph_module
+        )
         pattern_sub = (
             exir.capture(
                 sub,
@@ -80,7 +98,12 @@ def sub(x, y):
             .to_edge()
             .exported_program.graph_module
         )
-        self.patterns = [pattern_lstm_conv.graph, pattern_sub.graph]
+        self.patterns = [
+            pattern_lstm_conv_lifted.graph,
+            pattern_lstm_conv.graph,
+            pattern_sub_lifted.graph,
+            pattern_sub.graph,
+        ]
 
         backend_id = QnnBackend.__name__
         self.delegation_spec = DelegationSpec(backend_id, [])
@@ -145,28 +168,18 @@ def generate_partition_list(self, graph_module) -> List[Partition]:
             ]
 
         """
-        partitions_from_all_pattern = [
-            generate_pattern_op_partitions(graph_module, patterns=[pattern])
-            for pattern in self.patterns
-        ]
-
-        # Check if all partitions are exclusive, this partitions don't support inclusive partitions.
-        is_exclusive = self.is_exclusive(partitions_from_all_pattern)
-
-        assert (
-            is_exclusive
-        ), "There exists inclusive partitions. Currently the fuse method only handle exclusive partitions."
+        partitions_from_all_pattern = generate_pattern_op_partitions(
+            graph_module, self.patterns
+        )
 
         # Assign a unique id for each partition
         partition_id = 0
 
-        # If want to support inclusive partitions, the logic can be done here to merge partitions etc.
         flat_proposed_partitions_with_unique_id = []
-        for partitions_from_one_pattern in partitions_from_all_pattern:
-            for partition in partitions_from_one_pattern:
-                partition.id = partition_id
-                flat_proposed_partitions_with_unique_id.append(partition)
-                partition_id += 1
+        for partition in partitions_from_all_pattern:
+            partition.id = partition_id
+            flat_proposed_partitions_with_unique_id.append(partition)
+            partition_id += 1
 
         return flat_proposed_partitions_with_unique_id
 
@@ -213,16 +226,28 @@ def forward(self, x_raw, h, c):
         input_h = torch.ones([1, 32])
         input_c = torch.ones([1, 32])
 
-        pattern_lstm_conv = (
+        pattern_lstm_conv_lifted = (
             exir.capture(
                 LSTMConvPattern(),
                 (input_x, input_h, input_c),
-                exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False),
+                exir.CaptureConfig(pt2_mode=True, enable_aot=True),
+            )
+            .to_edge()
+            .exported_program.graph_module
+        )
+        pattern_lstm_conv_unlifted = (
+            exir.capture(
+                LSTMConvPattern(),
+                (input_x, input_h, input_c),
+                exir.CaptureConfig(pt2_mode=True),
             )
-            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+            .to_edge()
             .exported_program.graph_module
         )
-        self.patterns = [pattern_lstm_conv.graph]
+        self.patterns = [
+            pattern_lstm_conv_lifted.graph,
+            pattern_lstm_conv_unlifted.graph,
+        ]
         # Only (lstm + conv) pattern is lowerable
 
         backend_id = QnnBackend.__name__
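The net effect of the partitioner changes above is that each pattern is now registered in both its lifted and unlifted forms, so the same demo partitioner matches graphs produced by either capture mode. A condensed sketch of that idea using the sub pattern from the diff; the import line is an assumption, since the diff does not show imports.

# Sketch only: build the same pattern graph twice, once lifted and once not,
# mirroring how self.patterns is populated above.
import torch
import executorch.exir as exir  # import path assumed


def sub(x, y):
    return torch.sub(x, y)


example_inputs = (torch.ones([1, 32]), torch.ones([1, 32]))


def pattern_graph(config):
    # Capture the pattern under the given config and keep only its fx graph,
    # following the .to_edge().exported_program.graph_module chain in the diff.
    return (
        exir.capture(sub, example_inputs, config)
        .to_edge()
        .exported_program.graph_module.graph
    )


patterns = [
    pattern_graph(exir.CaptureConfig(pt2_mode=True, enable_aot=True)),  # lifted
    pattern_graph(exir.CaptureConfig(pt2_mode=True)),  # unlifted
]
# These are the graphs the demo stores in self.patterns and later hands to
# generate_pattern_op_partitions(graph_module, self.patterns).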

backends/test/test_backends.py

Lines changed: 6 additions & 23 deletions
@@ -627,19 +627,11 @@ def forward(self, x_raw, h, c):
 
         traced = exir.capture(
             composite_m, inputs, exir.CaptureConfig(pt2_mode=True)
-        ).to_edge(
-            exir.EdgeCompileConfig(
-                _check_ir_validity=False,
-            )
-        )
+        ).to_edge()
 
         program_without_delegates = (
             exir.capture(CompositeModel(3), inputs)
-            .to_edge(
-                exir.EdgeCompileConfig(
-                    _check_ir_validity=False,
-                )
-            )
+            .to_edge()
             .to_executorch(
                 config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
             )
@@ -741,20 +733,16 @@ def forward(self, x_raw, h, c):
         traced = exir.capture(
             composite_m,
             inputs,
-            exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False),
-        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+            exir.CaptureConfig(pt2_mode=True),
+        ).to_edge()
 
         program_without_delegates = (
             exir.capture(
                 CompositeModel(3),
                 (input_x, input_h, input_c),
                 exir.CaptureConfig(pt2_mode=True),
             )
-            .to_edge(
-                exir.EdgeCompileConfig(
-                    _check_ir_validity=False,
-                )
-            )
+            .to_edge()
             .to_executorch(
                 config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
             )
@@ -998,13 +986,8 @@ def test_quantized_with_delegate(self) -> None:
             exir.CaptureConfig(
                 pt2_mode=True,
                 enable_aot=True,
-                _unlift=True,
             ),
-        ).to_edge(
-            exir.EdgeCompileConfig(
-                _check_ir_validity=False,
-            )
-        )
+        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
         FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
             converted_linear_gm.exported_program.graph_module.code
         )
