Skip to content

Commit fba2655

Browse files
dbort and facebook-github-bot
authored and committed
Remove direct calls to serialize_to_flatbuffer
Summary: Use `to_executorch().buffer` instead.

Reviewed By: JacobSzwejbka

Differential Revision: D48366593

fbshipit-source-id: 1a5bc5b13a6fb053d26e72fe15d988a7d20715b8
1 parent ffc28e6 commit fba2655

File tree

1 file changed

+22
-34
lines changed

1 file changed

+22
-34
lines changed

backends/qnnpack/test/test_qnnpack.py

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515

1616
# import the xnnpack backend implementation
1717
from executorch.backends.qnnpack.qnnpack_preprocess import QnnpackBackend
18-
from executorch.exir import CaptureConfig
18+
from executorch.exir import CaptureConfig, ExecutorchProgram
1919

2020
from executorch.exir.backend.backend_api import to_backend, validation_disabled
2121

22-
from executorch.exir.serialize import serialize_to_flatbuffer
23-
2422
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
2523
from executorch.extension.pybindings.portable import ( # @manual
2624
_load_for_executorch_from_buffer,
@@ -114,20 +112,18 @@ def forward(self, x):
114112
example_inputs = (torch.rand(self.input_dims),)
115113

116114
composite_model(*example_inputs)
117-
program = (
115+
executorch_program: ExecutorchProgram = (
118116
exir.capture(composite_model, example_inputs, exir.CaptureConfig())
119117
.to_edge(EDGE_COMPILE_CONFIG)
120118
.to_executorch(config=EXECUTORCH_BACKEND_CONFIG)
121-
.program
122119
)
123120
self.assertEqual(
124-
program.execution_plan[0].delegates[0].id,
121+
executorch_program.program.execution_plan[0].delegates[0].id,
125122
QnnpackBackend.__name__,
126123
)
127124

128125
# Step 4: Run model and check outputs
129-
buffer = serialize_to_flatbuffer(program)
130-
executorch_module = _load_for_executorch_from_buffer(buffer)
126+
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
131127
inputs_flattened, _ = tree_flatten(example_inputs)
132128
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
133129
ref_output = composite_model(*example_inputs)
@@ -199,20 +195,18 @@ def forward(self, x):
199195
example_inputs = (torch.rand(self.input_dims),)
200196

201197
composite_model(*example_inputs)
202-
program = (
198+
executorch_program: ExecutorchProgram = (
203199
exir.capture(composite_model, example_inputs, exir.CaptureConfig())
204200
.to_edge(EDGE_COMPILE_CONFIG)
205201
.to_executorch(config=EXECUTORCH_BACKEND_CONFIG)
206-
.program
207202
)
208203
self.assertEqual(
209-
program.execution_plan[0].delegates[0].id,
204+
executorch_program.program.execution_plan[0].delegates[0].id,
210205
QnnpackBackend.__name__,
211206
)
212207

213208
# Step 4: Run model and check outputs
214-
buffer = serialize_to_flatbuffer(program)
215-
executorch_module = _load_for_executorch_from_buffer(buffer)
209+
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
216210
inputs_flattened, _ = tree_flatten(example_inputs)
217211
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
218212
ref_output = composite_model(*example_inputs)
@@ -274,20 +268,18 @@ def forward(self, x):
274268
example_inputs = (torch.rand(self.input_dims),)
275269

276270
composite_model(*example_inputs)
277-
program = (
271+
executorch_program: ExecutorchProgram = (
278272
exir.capture(composite_model, example_inputs, capture_config)
279273
.to_edge(EDGE_COMPILE_CONFIG)
280274
.to_executorch(config=EXECUTORCH_BACKEND_CONFIG)
281-
.program
282275
)
283276
self.assertEqual(
284-
program.execution_plan[0].delegates[0].id,
277+
executorch_program.program.execution_plan[0].delegates[0].id,
285278
QnnpackBackend.__name__,
286279
)
287280

288281
# Step 4: Run model and check outputs
289-
buffer = serialize_to_flatbuffer(program)
290-
executorch_module = _load_for_executorch_from_buffer(buffer)
282+
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
291283
inputs_flattened, _ = tree_flatten(example_inputs)
292284
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
293285
ref_output = composite_model(*example_inputs)
@@ -359,20 +351,18 @@ def forward(self, x):
359351
example_inputs = (torch.rand(self.input_dims),)
360352

361353
composite_model(*example_inputs)
362-
program = (
354+
executorch_program: ExecutorchProgram = (
363355
exir.capture(composite_model, example_inputs, capture_config)
364356
.to_edge(EDGE_COMPILE_CONFIG)
365357
.to_executorch(config=EXECUTORCH_BACKEND_CONFIG)
366-
.program
367358
)
368359
self.assertEqual(
369-
program.execution_plan[0].delegates[0].id,
360+
executorch_program.program.execution_plan[0].delegates[0].id,
370361
QnnpackBackend.__name__,
371362
)
372363

373364
# Step 4: Run model and check outputs
374-
buffer = serialize_to_flatbuffer(program)
375-
executorch_module = _load_for_executorch_from_buffer(buffer)
365+
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
376366
inputs_flattened, _ = tree_flatten(example_inputs)
377367
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
378368
ref_output = composite_model(*example_inputs)
@@ -433,20 +423,18 @@ def forward(self, x):
433423
example_inputs = (torch.rand(self.input_dims),)
434424

435425
composite_model(*example_inputs)
436-
program = (
426+
executorch_program: ExecutorchProgram = (
437427
exir.capture(composite_model, example_inputs, exir.CaptureConfig())
438428
.to_edge(EDGE_COMPILE_CONFIG)
439429
.to_executorch(config=EXECUTORCH_BACKEND_CONFIG)
440-
.program
441430
)
442431
self.assertEqual(
443-
program.execution_plan[0].delegates[0].id,
432+
executorch_program.program.execution_plan[0].delegates[0].id,
444433
QnnpackBackend.__name__,
445434
)
446435

447436
# Step 4: Run model and check outputs
448-
buffer = serialize_to_flatbuffer(program)
449-
executorch_module = _load_for_executorch_from_buffer(buffer)
437+
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
450438
inputs_flattened, _ = tree_flatten(example_inputs)
451439
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
452440
ref_output = composite_model(*example_inputs)
@@ -515,7 +503,9 @@ def test_qnnpack_per_channel_dynamic_qlinear_via_partitioner(self):
515503
lowered_module.exported_program.graph_module.code
516504
)
517505

518-
program = lowered_module.to_executorch(config=EXECUTORCH_BACKEND_CONFIG).program
506+
executorch_program: ExecutorchProgram = lowered_module.to_executorch(
507+
config=EXECUTORCH_BACKEND_CONFIG
508+
)
519509

520510
# TODO(T143084047)
521511
# class CompositeModule(torch.nn.Module):
@@ -530,23 +520,21 @@ def test_qnnpack_per_channel_dynamic_qlinear_via_partitioner(self):
530520
# example_inputs = (torch.rand(self.input_dims),)
531521

532522
# composite_model(*example_inputs)
533-
# program = (
523+
# executorch_program: ExecutorchProgram = (
534524
# exir.capture(
535525
# composite_model, example_inputs, exir.CaptureConfig()
536526
# )
537527
# .to_edge(EDGE_COMPILE_CONFIG)
538528
# .to_executorch()
539-
# .program
540529
# )
541530

542531
self.assertEqual(
543-
program.execution_plan[0].delegates[0].id,
532+
executorch_program.program.execution_plan[0].delegates[0].id,
544533
QnnpackBackend.__name__,
545534
)
546535

547536
# Step 4: Run model and check outputs
548-
buffer = serialize_to_flatbuffer(program)
549-
executorch_module = _load_for_executorch_from_buffer(buffer)
537+
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
550538
inputs_flattened, _ = tree_flatten(example_inputs)
551539
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
552540
ref_output = captured_mod(*example_inputs)

0 commit comments

Comments (0)