Skip to content

Replace view copy with view (3/3) #2463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/selective_build/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ add_executable(selective_build_test ${_executor_runner__srcs})
if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
target_link_options(selective_build_test PRIVATE "LINKER:--gc-sections")
endif()
target_link_libraries(selective_build_test executorch gflags select_build_lib)
target_link_libraries(
selective_build_test PRIVATE executorch gflags select_build_lib
)
target_link_options_shared_lib(select_build_lib)
target_link_options_shared_lib(executorch)
target_compile_options(selective_build_test PUBLIC ${_common_compile_options})

# Print all summary
Expand Down
4 changes: 4 additions & 0 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,7 @@ class ExecutorchBackendConfig:
# be a power of 2. If not provided, uses the value in the schema file.
delegate_alignment: Optional[int] = None
sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass()

# If set to True, view_copy operations will be converted to lightweight
# view operations in the ET runtime.
remove_view_copy: bool = True
29 changes: 29 additions & 0 deletions exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,32 @@ def _emit_control_flow(
)
)

def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
    """Emit a lightweight view as an ``executorch_prim::et_view`` kernel call.

    Expects exactly two arguments: the base tensor being viewed and the
    target size list. Emits the output tensor from this node's spec and
    appends a KernelCall instruction wiring (base, size, out) together.
    Returns the emitted output value.
    """
    assert len(args) == 2

    # Emit the base tensor, the requested size list, and the output tensor.
    base = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
    shape = self._emit_argument(args[1], torch.ListType.ofInts())
    result = self._emit_argument(
        self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
    )

    # Resolve the et_view primitive operator index in the program's op table.
    op_index, _op = self._get_operator(
        name="executorch_prim::et_view",
        overload="default",
    )
    instruction = Instruction(
        KernelCall(
            op_index,
            args=[base.id, shape.id, result.id],
        )
    )
    self.chain.instructions.append(instruction)
    return result

def _add_debug_handle(self, emitter_id: int, target: _Target) -> None:
"""Updates the debug handle information for the current node.
Expand Down Expand Up @@ -1198,6 +1224,9 @@ def call_function(
assert len(args) == 1
return self._emit_spec(self.node.meta["spec"])

elif target == memory.view:
return self._emit_view(args)

elif target == memory.free:
assert len(args) == 1
# pyre-ignore
Expand Down
16 changes: 12 additions & 4 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
edge = to_edge(export(f, inputs))

removed_ops = ["aten::relu_", "aten::view"]
expected_ops = ["aten::sin", "aten::relu", "aten::max", "aten::view_copy"]
expected_ops = [
"aten::sin",
"aten::relu",
"aten::max",
"executorch_prim::et_view", # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False
]

for opname in removed_ops:
self.assertEqual(
self.count_node(edge.exported_program().graph_module, opname), 0
)
for opname in expected_ops:
self.assertTrue(
self.count_node(edge.exported_program().graph_module, opname) >= 1
)
if (
opname != "executorch_prim::et_view"
): # et_view appears as call_function with target = memory.view in graph
self.assertTrue(
self.count_node(edge.exported_program().graph_module, opname) >= 1
)

program = edge.to_executorch().executorch_program
for opname in removed_ops:
Expand Down
9 changes: 8 additions & 1 deletion exir/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def collect_specs_from_nodes( # noqa: C901
or node.target
in [
memory.alloc,
memory.view,
operator.getitem,
torch.ops.higher_order.cond,
exir_while,
Expand Down Expand Up @@ -534,7 +535,13 @@ def get_node_tensor_specs(
has no tensor specs.
"""
# get tensor specs
specs = node.meta.get("spec")
if node.target == memory.view:
base = node.args[0]
assert isinstance(base, torch.fx.Node)
specs = base.meta.get("spec")
else:
specs = node.meta.get("spec")

if isinstance(specs, TensorSpec):
specs = [specs]
if not isinstance(specs, (list, tuple)):
Expand Down
1 change: 1 addition & 0 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
# we won't see it in the input graph to the to_out_variant pass, unless
# it's retraced after running to_out_variant with the first trace.
memory.alloc,
memory.view,
executorch_call_delegate,
torch.ops.aten.copy_.default,
}
Expand Down
Loading