
Commit e114205

metascroy authored and facebook-github-bot committed
Elide static view copies (#8197)
Summary: This adds an ExecuTorch config option to elide static views.

Reviewed By: JacobSzwejbka, larryliu0820, hsharma35

Differential Revision: D68984189

Pulled By: metascroy
1 parent b1d76c9 commit e114205

File tree: 4 files changed, +65 −20 lines


exir/capture/_config.py

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ class ExecutorchBackendConfig:
 
     # If set to true, view_copy operations will be converted to lightweight
     # view operations in the ET runtime
+    # Moreover, static views will be elided from the ExecuTorch graph
     remove_view_copy: bool = True
 
     # If set to true, all constant tensors will be stored in a separate file,
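For context, the flag is read when lowering to ExecuTorch. A minimal usage sketch (the toy model and variable names are illustrative; the imports mirror the test files touched by this commit and assume ExecutorchBackendConfig and to_edge are exported from executorch.exir):

    import torch
    import torch.nn as nn
    from executorch.exir import ExecutorchBackendConfig, to_edge

    class Model(nn.Module):
        def forward(self, x):
            # Static input shape, so this view's output shape is fixed too.
            return 2 * (x + x).view(-1, 1)

    ep = torch.export.export(Model().eval(), (torch.rand(5, 6),), strict=True)
    # remove_view_copy=True (the default) converts view_copy to lightweight
    # views and, with this change, elides static memory-planned views entirely.
    etpm = to_edge(ep).to_executorch(
        config=ExecutorchBackendConfig(remove_view_copy=True)
    )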

exir/emit/_emitter.py

Lines changed: 10 additions & 0 deletions
@@ -943,6 +943,16 @@ def _emit_control_flow(
     def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
         assert len(args) == 2
 
+        # Elide the view if it is static and memory planned
+        spec = self.node.meta["spec"]
+        is_static = spec.is_static_shape_tensor
+        is_memory_planned = (spec.mem_id is not None) and (spec.mem_offset is not None)
+        is_memory_planned = is_memory_planned or (
+            spec.const and spec.storage is not None
+        )
+        if is_static and is_memory_planned:
+            return self._emit_spec(spec)
+
         self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
         size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
         out_arg = self._emit_argument(
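Read in isolation, the new early return fires only when the view's output shape is fully static and its output already has planned storage: either a (mem_id, mem_offset) pair assigned by memory planning, or constant storage. A standalone sketch of that predicate (FakeSpec is a hypothetical stand-in for the emitter's tensor spec; the field names match the diff above):

    from typing import Optional

    class FakeSpec:
        """Hypothetical stand-in for the tensor spec read in _emit_view."""
        def __init__(
            self,
            is_static_shape_tensor: bool,
            mem_id: Optional[int] = None,
            mem_offset: Optional[int] = None,
            const: bool = False,
            storage: Optional[bytes] = None,
        ):
            self.is_static_shape_tensor = is_static_shape_tensor
            self.mem_id = mem_id
            self.mem_offset = mem_offset
            self.const = const
            self.storage = storage

    def can_elide_view(spec: FakeSpec) -> bool:
        # Same condition as the early return above: static shape AND planned storage.
        planned = (spec.mem_id is not None and spec.mem_offset is not None) or (
            spec.const and spec.storage is not None
        )
        return spec.is_static_shape_tensor and planned

    assert can_elide_view(FakeSpec(True, mem_id=0, mem_offset=64))       # elided
    assert not can_elide_view(FakeSpec(False, mem_id=0, mem_offset=64))  # dynamic: kept
    assert not can_elide_view(FakeSpec(True))                            # unplanned: kept

When the predicate holds, _emit_view returns the existing spec instead of emitting an instruction, so the elided view costs nothing at runtime.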

exir/emit/test/test_emit.py

Lines changed: 22 additions & 13 deletions
@@ -331,29 +331,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             "aten::sin",
             "aten::relu",
             "aten::max",
-            "executorch_prim::et_view",  # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False
         ]
 
+        def expected_view_ops(config):
+            if config.remove_view_copy:
+                return []
+            else:
+                return ["aten::view_copy"]
+
         for opname in removed_ops:
             self.assertEqual(
                 self.count_node(edge.exported_program().graph_module, opname), 0
             )
         for opname in expected_ops:
-            if (
-                opname != "executorch_prim::et_view"
-            ):  # et_view appears as call_function with target = memory.view in graph
-                self.assertTrue(
-                    self.count_node(edge.exported_program().graph_module, opname) >= 1
-                )
-
-        program = edge.to_executorch().executorch_program
-        for opname in removed_ops:
             self.assertTrue(
-                all(op.name != opname for op in program.execution_plan[0].operators)
+                self.count_node(edge.exported_program().graph_module, opname) >= 1
             )
-        for opname in expected_ops:
+
+        for remove_view_copy in [True, False]:
+            config = exir.ExecutorchBackendConfig(remove_view_copy=remove_view_copy)
+            edge_copy = deepcopy(edge)
+            program = edge_copy.to_executorch(config=config).executorch_program
+            for opname in removed_ops:
+                self.assertTrue(
+                    all(op.name != opname for op in program.execution_plan[0].operators)
+                )
+            for opname in expected_ops + expected_view_ops(config):
+                self.assertTrue(
+                    any(op.name == opname for op in program.execution_plan[0].operators)
+                )
             self.assertTrue(
-                any(op.name == opname for op in program.execution_plan[0].operators)
+                len(program.execution_plan[0].operators)
+                == len(expected_ops + expected_view_ops(config))
             )
 
     def test_operators_unique(self) -> None:
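Each configuration above runs against a deepcopy of the edge program, and the final assertion pins the operator table down exactly rather than as a superset. A small inspection helper in the same spirit (the helper name is mine; it uses only fields the test already exercises):

    def operator_names(executorch_program):
        # Flat list of operator names in the first execution plan,
        # the same list the assertions above iterate over.
        return [op.name for op in executorch_program.execution_plan[0].operators]

    # With remove_view_copy=True and a fully static model, neither
    # "executorch_prim::et_view" nor "aten::view_copy" should appear.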

exir/tests/test_remove_view_copy.py

Lines changed: 32 additions & 7 deletions
@@ -194,16 +194,41 @@ def test_spec(self) -> None:
         self.assertEqual(plan.operators[2].name, "aten::view_copy")
 
         instructions = plan.chains[0].instructions
-        self.assertEqual(len(instructions), 7)
+        self.assertEqual(len(instructions), 5)
 
         self.assertEqual(instructions[0].instr_args.op_index, 0)  # view @ idx2
-        self.assertEqual(instructions[1].instr_args.op_index, 0)  # view @ idx3
-        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx6
-        self.assertEqual(instructions[3].instr_args.op_index, 0)  # view @ idx7
-        self.assertEqual(instructions[4].instr_args.op_index, 1)  # aten:mul @ idx9
+        self.assertEqual(instructions[1].instr_args.op_index, 1)  # aten:mul @ idx6
+        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx9
         self.assertEqual(
-            instructions[5].instr_args.op_index, 2
+            instructions[3].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
         self.assertEqual(
-            instructions[6].instr_args.op_index, 2
+            instructions[4].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
+
+    def test_elide_static_views_does_not_remove_dynamic_views(self) -> None:
+        class TestModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + x
+                x = x.view(-1, 1)
+                return 2 * x
+
+        model = TestModel()
+        model.eval()
+        example_inputs = (torch.rand(5, 6),)
+        dynamic_shapes = {"x": {0: torch.export.Dim("dim0", min=1, max=10)}}
+        ep = torch.export.export(
+            model, example_inputs, strict=True, dynamic_shapes=dynamic_shapes
+        )
+        etpm = to_edge(ep).to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
+            ),
+        )
+        plan = etpm.executorch_program.execution_plan[0]
+        op_names = [op.name for op in plan.operators]
+        self.assertTrue("executorch_prim::et_view" in op_names)
