
Commit cd320e0

metascroy authored and facebook-github-bot committed
Replace view copy with view (3/3) (#2463)
Summary:
Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib

This stack replaces view_copy nodes with memory.view nodes.

In the first diff (D54816555), I write a pass to normalize view_copy nodes by making their base point to the upstream non-view node. This means that if we have something like op -> view_copy1 -> view_copy2, then after normalization both view copies will point to op as their base (assuming op is not a view node). Note that this pass combined with dead-code elimination removes redundant view copies, because a redundant view copy will have no users after this pass. (A sketch of this normalization appears after the commit metadata below.)

In the second diff (D54827305), I write a pass to convert view_copy nodes to memory.view nodes. A memory.view is similar to torch.ops.aten.view.default, but it is its own function so that we can handle it specially during memory planning and emission. A memory.view node has a special TensorSpec of type _MemoryViewSpec. This spec is immutable and dynamically looks up non-size-related fields from its base's TensorSpec. Because it is immutable, fields on a _MemoryViewSpec cannot be set, but if a field is updated on the base spec, that update is reflected in the memory.view node's _MemoryViewSpec.

Not all view_copy nodes are converted to memory.view nodes; only static nodes that are memory planned are converted. Not all static nodes are memory planned in ExecuTorch: for example, there is an option to turn off memory planning for input nodes, and outputs from some higher-order ops like cond are not memory planned. Which nodes are memory planned is not easily available, and I did not try to cover all cases of nodes that can be converted. We can expand this list over time.

In the third diff (D54827438), I implement the actual view_copy elimination. The ExecutorchBackendConfig gets a new option, remove_view_copy. If remove_view_copy = True, the memory planning passes are [NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass]; if remove_view_copy = False, the memory planning passes are [config.to_out_var_pass, config.memory_planning_pass] (the state today).

Let's look at the flow when remove_view_copy = True: NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass. The first two steps are just the passes from the first and second diffs described above. In config.to_out_var_pass, the memory.view nodes are skipped. In config.memory_planning_pass, when a spec is requested for a memory.view node (e.g., to update the lifetime), we return the spec of its base. Returning the spec of the base means that whenever we see a memory.view node, we actually update the lifetime of the base to cover it, and the memory.view node's special _MemoryViewSpec sees this update reflected. (Note that an exception would be thrown if we kept the usual flow and returned the spec for the memory.view node itself, because the special _MemoryViewSpec is immutable and would not allow the memory_planning_pass to update its lifetime.) Finally, during emission the memory.view is emitted as an evalue.

There are two more diffs on the stack, D54866523 and D54866539. The first replaces the old RemoveRedundantViewCopy pass with NormalizeViewCopyBasePass + dead-code elimination. The second converts view-like ops (squeeze, unsqueeze, slice) to view ops when it is safe to do so, to take advantage of the view_copy elimination.
Reviewed By: larryliu0820

Differential Revision: D54827438
1 parent 71c619e commit cd320e0
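
To make the normalization step from the first diff concrete, here is a minimal sketch of the idea using standard torch.fx APIs. The function names are illustrative and only the single aten.view_copy.default target is matched; the actual NormalizeViewCopyBasePass may differ.

import torch
from torch import fx


def _is_view_copy(node: object) -> bool:
    # Illustrative check; the real pass would handle every view_copy
    # target, not just aten.view_copy.default.
    return (
        isinstance(node, fx.Node)
        and node.op == "call_function"
        and node.target == torch.ops.aten.view_copy.default
    )


def normalize_view_copy_bases(gm: fx.GraphModule) -> None:
    # view_copy(base, size) takes the *output* shape as its size argument,
    # so pointing the base past upstream view copies preserves semantics.
    for node in gm.graph.nodes:
        if _is_view_copy(node):
            base = node.args[0]
            while _is_view_copy(base):
                base = base.args[0]
            node.args = (base, *node.args[1:])
    # Intermediate view copies may now have no users; dead-code
    # elimination then removes them.
    gm.graph.eliminate_dead_code()
    gm.recompile()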

File tree

13 files changed: +504 −122 lines

examples/selective_build/CMakeLists.txt

Lines changed: 4 additions & 1 deletion

@@ -118,7 +118,10 @@ add_executable(selective_build_test ${_executor_runner__srcs})
 if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
   target_link_options(selective_build_test PRIVATE "LINKER:--gc-sections")
 endif()
-target_link_libraries(selective_build_test executorch gflags select_build_lib)
+target_link_libraries(selective_build_test PRIVATE executorch gflags select_build_lib)
+target_link_options_shared_lib(selective_build_test)
+target_link_options_shared_lib(select_build_lib)
+target_link_options_shared_lib(executorch)
 target_compile_options(selective_build_test PUBLIC ${_common_compile_options})

 # Print all summary

exir/capture/_config.py

Lines changed: 4 additions & 0 deletions

@@ -75,3 +75,7 @@ class ExecutorchBackendConfig:
     # be a power of 2. If not provided, uses the value in the schema file.
     delegate_alignment: Optional[int] = None
     sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass()
+
+    # If set to true, view_copy operations will be converted to lightweight
+    # view operations in the ET runtime
+    remove_view_copy: bool = True
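
For reference, a hedged usage sketch of the new flag; this assumes ExecutorchBackendConfig and to_edge are importable from executorch.exir, as in the tests in this commit.

import torch
from torch.export import export

from executorch.exir import ExecutorchBackendConfig, to_edge


class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(-1).relu()


edge = to_edge(export(Model(), (torch.randn(2, 3),)))

# remove_view_copy defaults to True; pass False to keep view_copy
# operators instead of lightweight et_view operations.
program = edge.to_executorch(ExecutorchBackendConfig(remove_view_copy=False))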

exir/emit/_emitter.py

Lines changed: 29 additions & 0 deletions

@@ -844,6 +844,32 @@ def _emit_control_flow(
            )
        )

+    def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
+        assert len(args) == 2
+
+        self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
+        size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
+        out_arg = self._emit_argument(
+            self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
+        )
+
+        op_idx, op = self._get_operator(
+            name="executorch_prim::et_view",
+            overload="default",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_idx,
+                args=[
+                    self_arg.id,
+                    size_arg.id,
+                    out_arg.id,
+                ],
+            )
+        )
+        self.chain.instructions.append(kernel)
+        return out_arg
+
    def _add_debug_handle(self, emitter_id: int, target: _Target) -> None:
        """Updates the debug handle information for the current node.

@@ -1198,6 +1224,9 @@ def call_function(
            assert len(args) == 1
            return self._emit_spec(self.node.meta["spec"])

+        elif target == memory.view:
+            return self._emit_view(args)
+
        elif target == memory.free:
            assert len(args) == 1
            # pyre-ignore

exir/emit/test/test_emit.py

Lines changed: 12 additions & 4 deletions

@@ -265,16 +265,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        edge = to_edge(export(f, inputs))

        removed_ops = ["aten::relu_", "aten::view"]
-        expected_ops = ["aten::sin", "aten::relu", "aten::max", "aten::view_copy"]
+        expected_ops = [
+            "aten::sin",
+            "aten::relu",
+            "aten::max",
+            "executorch_prim::et_view",  # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False
+        ]

        for opname in removed_ops:
            self.assertEqual(
                self.count_node(edge.exported_program().graph_module, opname), 0
            )
        for opname in expected_ops:
-            self.assertTrue(
-                self.count_node(edge.exported_program().graph_module, opname) >= 1
-            )
+            if (
+                opname != "executorch_prim::et_view"
+            ):  # et_view appears as call_function with target = memory.view in graph
+                self.assertTrue(
+                    self.count_node(edge.exported_program().graph_module, opname) >= 1
+                )

        program = edge.to_executorch().executorch_program
        for opname in removed_ops:
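
Since et_view appears in the edge graph as a call_function node with target = memory.view (per the comment in the test above), a hedged way to count such nodes directly would be the following sketch, assuming exir's memory module is importable:

from executorch.exir import memory


def count_memory_view_nodes(graph_module) -> int:
    # memory.view nodes carry the et_view semantics before emission.
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and node.target == memory.view
    )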

exir/memory_planning.py

Lines changed: 8 additions & 1 deletion

@@ -397,6 +397,7 @@ def collect_specs_from_nodes(  # noqa: C901
        or node.target
        in [
            memory.alloc,
+            memory.view,
            operator.getitem,
            torch.ops.higher_order.cond,
            exir_while,

@@ -534,7 +535,13 @@ def get_node_tensor_specs(
        has no tensor specs.
    """
    # get tensor specs
-    specs = node.meta.get("spec")
+    if node.target == memory.view:
+        base = node.args[0]
+        assert isinstance(base, torch.fx.Node)
+        specs = base.meta.get("spec")
+    else:
+        specs = node.meta.get("spec")
+
    if isinstance(specs, TensorSpec):
        specs = [specs]
    if not isinstance(specs, (list, tuple)):
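
To make the spec redirection above concrete, here is a minimal sketch of the immutable, base-deferring contract the summary describes for _MemoryViewSpec. The class and field names are illustrative, not the actual implementation.

class MemoryViewSpecSketch:
    # Illustrative stand-in for _MemoryViewSpec: it owns only its size
    # fields and reads everything else from the base spec at access time.
    def __init__(self, base: object, shape: tuple) -> None:
        object.__setattr__(self, "_base", base)
        object.__setattr__(self, "shape", tuple(shape))

    def __getattr__(self, name: str):
        # Called only for attributes not found on this object, so fields
        # like lifetime, mem_id, and mem_offset resolve against the base;
        # updates made to the base spec are immediately visible here.
        return getattr(object.__getattribute__(self, "_base"), name)

    def __setattr__(self, name: str, value: object) -> None:
        # Immutable by design: memory planning must update the base spec,
        # which is why get_node_tensor_specs above returns the base's spec.
        raise AttributeError("spec of a memory.view node is immutable")

With this shape, extending the base spec's lifetime to cover a memory.view node is a single update on the base, and the view's spec observes it automatically.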

exir/passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -248,6 +248,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
            # we won't see it in the input graph to the to_out_variant pass, unless
            # it's retraced after running to_out_variant with the first trace.
            memory.alloc,
+            memory.view,
            executorch_call_delegate,
            torch.ops.aten.copy_.default,
        }
