Skip to content

remove exir.capture from test_rpc.py #3102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions exir/backend/test/demos/rpc/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
ExecutorBackendPartitioner,
Expand All @@ -20,6 +21,7 @@
from executorch.extension.pybindings.portable_lib import ( # @manual
_load_for_executorch_from_buffer,
)
from torch.export import export
from torch.utils._pytree import tree_flatten

"""
Expand Down Expand Up @@ -101,16 +103,15 @@ def test_delegate_whole_program(self):

simple_net = self.get_a_simple_net()
simple_net_input = simple_net.get_example_inputs()
exported_program = exir.capture(
simple_net, simple_net_input, exir.CaptureConfig()
).to_edge(
exir.EdgeCompileConfig(
exported_program = to_edge(
export(simple_net, simple_net_input),
compile_config=exir.EdgeCompileConfig(
_check_ir_validity=False,
)
),
)
# delegate the whole graph to the client executor
lowered_module = to_backend(
ExecutorBackend.__name__, exported_program.exported_program, []
ExecutorBackend.__name__, exported_program.exported_program(), []
)

class CompositeModule(torch.nn.Module):
Expand All @@ -123,11 +124,7 @@ def forward(self, *args):

composite_model = CompositeModule()

exec_prog = (
exir.capture(composite_model, simple_net_input, exir.CaptureConfig())
.to_edge()
.to_executorch()
)
exec_prog = to_edge(export(composite_model, simple_net_input)).to_executorch()

executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)

Expand Down Expand Up @@ -162,18 +159,14 @@ def forward(self, a, x, b):
model = Model()
inputs = (torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2))

exported_program = exir.capture(model, inputs, exir.CaptureConfig()).to_edge()
exported_program = to_edge(export(model, inputs))

# First lower to demo backend
demo_backend_lowered = exported_program
demo_backend_lowered.exported_program = to_backend(
exported_program.exported_program, AddMulPartitionerDemo()
)
demo_backend_lowered = exported_program.to_backend(AddMulPartitionerDemo())

# Then lower to executor backend
executor_backend_lowered = demo_backend_lowered
executor_backend_lowered.exported_program = to_backend(
demo_backend_lowered.exported_program, ExecutorBackendPartitioner()
executor_backend_lowered = demo_backend_lowered.to_backend(
ExecutorBackendPartitioner()
)

prog_buffer = executor_backend_lowered.to_executorch()
Expand Down