
Commit e94831e

metascroy authored and facebook-github-bot committed
Elide static view copies
Summary: This adds an ExecuTorch config option to elide static views.

Differential Revision: D68984189
1 parent 2cfba1a commit e94831e

File tree

4 files changed (+66, −19 lines)

exir/capture/_config.py
Lines changed: 1 addition & 0 deletions

@@ -87,6 +87,7 @@ class ExecutorchBackendConfig:
 
     # If set to true, view_copy operations will be converted to lightweight
    # view operations in the ET runtime
+    # Moreover, static views will be elided from the ExecuTorch graph
    remove_view_copy: bool = True
 
    # If set to true, all constant tensors will be stored in a separate file,
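For context, a minimal sketch of how this option is exercised end to end (the module, names, and shapes here are illustrative, not part of this commit):

    import torch
    from executorch.exir import ExecutorchBackendConfig, to_edge

    class Static2x3(torch.nn.Module):
        def forward(self, x):
            # A static view: with this change it can be elided at emit time.
            return x.view(3, 2) * 2

    ep = torch.export.export(Static2x3(), (torch.rand(2, 3),))
    # remove_view_copy=True (the default) already replaced view_copy with
    # lightweight et_view ops; static, memory-planned views are now elided too.
    et_prog = to_edge(ep).to_executorch(
        config=ExecutorchBackendConfig(remove_view_copy=True)
    )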

exir/emit/_emitter.py
Lines changed: 12 additions & 0 deletions

@@ -943,6 +943,18 @@ def _emit_control_flow(
     def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
         assert len(args) == 2
 
+        # Elide the view if it is static and memory planned
+        spec = self.node.meta["spec"]
+        is_static = spec.is_static_shape_tensor
+        is_memory_planned = (spec.mem_id is not None) and (
+            spec.mem_offset is not None
+        )
+        is_memory_planned = is_memory_planned or (
+            spec.const and spec.storage is not None
+        )
+        if is_static and is_memory_planned:
+            return self._emit_spec(spec)
+
         self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
         size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
         out_arg = self._emit_argument(
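When both conditions hold, _emit_view emits no instruction at all and simply returns the value's spec, so the view costs nothing at runtime. A standalone restatement of the elision predicate (a sketch over the same spec fields used above, not an ExecuTorch API):

    def should_elide_view(spec) -> bool:
        # No symbolic dimensions: the view's output layout is known statically.
        is_static = spec.is_static_shape_tensor
        # Memory planned: a concrete (mem_id, mem_offset) arena slot exists...
        planned = spec.mem_id is not None and spec.mem_offset is not None
        # ...or the tensor is a constant whose storage ships with the program.
        planned = planned or (spec.const and spec.storage is not None)
        return is_static and planned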

exir/emit/test/test_emit.py
Lines changed: 21 additions & 12 deletions

@@ -331,30 +331,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             "aten::sin",
             "aten::relu",
             "aten::max",
-            "executorch_prim::et_view",  # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False
         ]
 
+        def expected_view_ops(config):
+            if config.remove_view_copy:
+                return []
+            else:
+                return ["aten::view_copy"]
+
+
         for opname in removed_ops:
             self.assertEqual(
                 self.count_node(edge.exported_program().graph_module, opname), 0
             )
         for opname in expected_ops:
-            if (
-                opname != "executorch_prim::et_view"
-            ):  # et_view appears as call_function with target = memory.view in graph
                 self.assertTrue(
                     self.count_node(edge.exported_program().graph_module, opname) >= 1
                 )
 
-        program = edge.to_executorch().executorch_program
-        for opname in removed_ops:
-            self.assertTrue(
-                all(op.name != opname for op in program.execution_plan[0].operators)
-            )
-        for opname in expected_ops:
-            self.assertTrue(
-                any(op.name == opname for op in program.execution_plan[0].operators)
+        for remove_view_copy in [True, False]:
+            config = exir.ExecutorchBackendConfig(
+                remove_view_copy=remove_view_copy
             )
+            edge_copy = deepcopy(edge)
+            program = edge_copy.to_executorch(config=config).executorch_program
+            for opname in removed_ops:
+                self.assertTrue(
+                    all(op.name != opname for op in program.execution_plan[0].operators)
+                )
+            for opname in expected_ops + expected_view_ops(config):
+                self.assertTrue(
+                    any(op.name == opname for op in program.execution_plan[0].operators)
+                )
+            self.assertTrue(len(program.execution_plan[0].operators) == len(expected_ops + expected_view_ops(config)))
 
     def test_operators_unique(self) -> None:
         class OpRepeatedModule(torch.nn.Module):

exir/tests/test_remove_view_copy.py
Lines changed: 32 additions & 7 deletions

@@ -194,16 +194,41 @@ def test_spec(self) -> None:
         self.assertEqual(plan.operators[2].name, "aten::view_copy")
 
         instructions = plan.chains[0].instructions
-        self.assertEqual(len(instructions), 7)
+        self.assertEqual(len(instructions), 5)
 
         self.assertEqual(instructions[0].instr_args.op_index, 0)  # view @ idx2
-        self.assertEqual(instructions[1].instr_args.op_index, 0)  # view @ idx3
-        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx6
-        self.assertEqual(instructions[3].instr_args.op_index, 0)  # view @ idx7
-        self.assertEqual(instructions[4].instr_args.op_index, 1)  # aten:mul @ idx9
+        self.assertEqual(instructions[1].instr_args.op_index, 1)  # aten:mul @ idx6
+        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx9
         self.assertEqual(
-            instructions[5].instr_args.op_index, 2
+            instructions[3].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
         self.assertEqual(
-            instructions[6].instr_args.op_index, 2
+            instructions[4].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
+
+    def test_elide_static_views_does_not_remove_dynamic_views(self) -> None:
+        class TestModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + x
+                x = x.view(-1, 1)
+                return 2 * x
+
+        model = TestModel()
+        model.eval()
+        example_inputs = (torch.rand(5, 6),)
+        dynamic_shapes = {"x": {0: torch.export.Dim("dim0", min=1, max=10)}}
+        ep = torch.export.export(model, example_inputs, strict=True, dynamic_shapes=dynamic_shapes)
+        etpm = to_edge(ep).to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
+            ),
+        )
+        plan = etpm.executorch_program.execution_plan[0]
+        op_names = [op.name for op in plan.operators]
+        self.assertTrue("executorch_prim::et_view" in op_names)
+
+
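The inverse check would also hold: with a fully static input, the view should disappear from the plan entirely. A hypothetical companion test, not part of this commit (name and shapes are illustrative):

    def test_elide_static_views_removes_static_views(self) -> None:
        class StaticModel(nn.Module):
            def forward(self, x):
                return 2 * (x + x).view(-1, 1)

        # Same shapes as above, but exported without dynamic_shapes.
        ep = torch.export.export(StaticModel(), (torch.rand(5, 6),), strict=True)
        etpm = to_edge(ep).to_executorch(
            config=ExecutorchBackendConfig(remove_view_copy=True)
        )
        plan = etpm.executorch_program.execution_plan[0]
        op_names = [op.name for op in plan.operators]
        # The static, memory-planned view is elided, so et_view never
        # reaches the emitted program.
        self.assertTrue("executorch_prim::et_view" not in op_names)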
