Cleanup memory passes tests. #7788

Merged: 1 commit, Jan 22, 2025
162 changes: 161 additions & 1 deletion backends/cadence/aot/tests/test_memory_passes.py
@@ -1,7 +1,9 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import logging
import math
import unittest
from typing import cast

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
@@ -110,7 +112,121 @@ def forward(self, x):


class TestMemTransform(unittest.TestCase):
def test_optimize_cat(self):
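# Helper: a nop cat is only valid if every input already lives inside the output
# buffer, i.e. it shares the output's mem_id and sits at the output's mem_offset
# plus the accumulated extent of the preceding inputs along the cat dimension.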
def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None:
spec = node.meta.get("spec", None)
self.assertIsNotNone(spec)
dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
outer_size = math.prod(spec.shape[:dim])
self.assertEqual(
outer_size,
1,
f"{node=} has wrong outer size: {outer_size=}, expected 1.",
)
inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
dim_offset = 0
for arg in cast(list[torch.fx.Node], node.args[0]):
arg_spec = arg.meta.get("spec", None)
self.assertEqual(arg_spec.mem_id, spec.mem_id)
self.assertEqual(
arg_spec.mem_offset,
spec.mem_offset + dim_offset * inner_dim_elements,
f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
)
dim_offset += arg_spec.shape[dim]

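# Helper: a nop slice is only valid if the output aliases the input buffer,
# offset by start * (bytes per step along the slice dimension).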
def _verify_slice_nop_memory_alloc(self, node: torch.fx.Node) -> None:
spec = node.meta.get("spec", None)
self.assertIsNotNone(spec)
dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
outer_size = math.prod(spec.shape[:dim])
self.assertEqual(
outer_size,
1,
f"{node=} has wrong outer size: {outer_size=}, expected 1.",
)
inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
start: int = (
cast(int, node.args[2])
if (len(node.args) > 2 and node.args[2] is not None)
else 0
)
arg = cast(torch.fx.Node, node.args[0])
arg_spec = arg.meta.get("spec", None)
self.assertEqual(arg_spec.mem_id, spec.mem_id)
self.assertEqual(
spec.mem_offset,
arg_spec.mem_offset + start * inner_dim_elements,
f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
)

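# Helper: a nop select is only valid if the output aliases the input buffer,
# offset by index * (bytes per step along the select dimension).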
def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
spec = node.meta.get("spec", None)
self.assertIsNotNone(spec)
dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
outer_size = math.prod(spec.shape[:dim])
self.assertEqual(
outer_size,
1,
f"{node=} has wrong outer size: {outer_size=}, expected 1.",
)
inner_dim_elements = math.prod(spec.shape[dim:]) * spec.dtype.itemsize
index: int = (
cast(int, node.args[2])
if (len(node.args) > 2 and node.args[2] is not None)
else 0
)
arg = cast(torch.fx.Node, node.args[0])
arg_spec = arg.meta.get("spec", None)
self.assertEqual(arg_spec.mem_id, spec.mem_id)
self.assertEqual(
spec.mem_offset,
arg_spec.mem_offset + index * inner_dim_elements,
f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} for select on {dim=} {index=}, "
f"but output has {spec.mem_offset=}"
f"{spec=} {arg_spec=}",
)

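# Entry point used by the tests below: walk the graph and check the memory
# layout of every nop cat/slice/select node produced by the optimization.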
def verify_nop_memory_alloc(self, graph_module):
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten._cat_nop.out
):
self._verify_cat_nop_memory_alloc(node)

for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
):
self._verify_slice_nop_memory_alloc(node)

for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten._select_copy_nop.int_out
):
self._verify_select_nop_memory_alloc(node)

def test_optimize_cat_on_placeholders(self):
class Cat(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.cat((x, y))

x = torch.ones(3, 6)
y = torch.ones(2, 6)
# Optimizing cat ops happens only at opt_level 2+, and requires the memory
# planning pass to run:
graph_module = (
compiler.export_to_executorch_gen_etrecord(
Cat(), (x, y), opt_level=2, mem_algo=1
)
.exported_program()
.graph_module
)
logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
graph_module.graph.eliminate_dead_code()
# Assert that cat op is optimized away
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
# Assert that cat op is replaced by its nop version post optimization
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_outermost(self):
class OptimizeCatFeasible1(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +251,9 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
# Assert that cat op is replaced by its nop version post optimization
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_non_outermost(self):
class OptimizeCatFeasible2(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +278,9 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
# Assert that cat op is replaced by its nop version post optimization
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_non_outermost(self):
class OptimizeCatInfeasible1(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +304,9 @@ def forward(self, x, y):
# Assert that cat op is not optimized away, since the concat is not
# along the outermost dim
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_non_outermost1(self):
class OptimizeCatInfeasible2(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +331,7 @@ def forward(self, x, y):
# offsets are not multiple of 8 bytes, and the cat is not the output
# of the graph.
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_slice(self):
class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +360,7 @@ def forward(self, x):
graph_module.graph.eliminate_dead_code()
# Assert that cat op is optimized away
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_slice_infeasible(self):
class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +386,7 @@ def forward(self, x, y):
graph_module.graph.eliminate_dead_code()
# Assert that cat op is not optimized away
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_Tensor(self):
class SliceTensor(torch.nn.Module):
@@ -323,6 +448,7 @@ def forward(self, x, y, z):
self.assertEqual(
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_select_Tensor(self):
class SelectTensor(torch.nn.Module):
@@ -387,6 +513,7 @@ def forward(self, x, y, z):
self.assertEqual(
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
)
self.verify_nop_memory_alloc(graph_module)

# TODO: Test fails due to memory planning
@unittest.expectedFailure
@@ -416,6 +543,32 @@ def forward(self, x, y):
graph_module.graph.eliminate_dead_code()
# Assert that cat op is not optimized away
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_then_slice_on_mutable_buffer(self):
class CatWithPadding(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
zeros = torch.zeros(padding_shape)
self.register_buffer("padding", zeros)

def forward(self, x, y):
x = x.view(3, 5)
cat = torch.ops.aten.cat((x, self.padding.clone()))
slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
self.padding.copy_(slice_copy)
return cat.view(-1) + y

x = torch.ones(15)
y = torch.ones(1)
et_prog_manager = compiler.export_to_executorch_gen_etrecord(
CatWithPadding((1, 5)), (x, y), opt_level=3
)
graph_module = et_prog_manager.exported_program().graph_module
logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_view(self):
class CatViewFeasible(torch.nn.Module):
@@ -442,6 +595,7 @@ def forward(self, x, y):
# Assert that cat op is optimized away
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_with_repeated_args(self):
class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +619,7 @@ def forward(self, x):
# Assert that cat op is not optimized away
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_with_placeholder(self):
class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +647,7 @@ def forward(self, x, y):
# Assert that cat op is not optimized away
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat(self) -> None:
class Model(torch.nn.Module):
@@ -522,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
)
self.assertEqual(count_node(graph_module, memory.view), 2)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_copy(self) -> None:
class Model(torch.nn.Module):
@@ -553,6 +710,7 @@ def forward(self, x) -> torch.Tensor:
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
)
self.assertEqual(count_node(graph_module, memory.view), 2)
self.verify_nop_memory_alloc(graph_module)

def test_cat_then_cat(self) -> None:
class Model(torch.nn.Module):
@@ -579,6 +737,7 @@ def forward(self, x) -> torch.Tensor:
graph_module.print_readable()
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_view_for_unallocated_output(self):
class Model(torch.nn.Module):
@@ -602,3 +761,4 @@ def forward(self, x, y):
.graph_module
)
self.assertEqual(count_node(graph_module, memory.view), 1)
self.verify_nop_memory_alloc(graph_module)
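
A minimal sketch of how these tests might be run locally is shown below. It assumes the executorch repository root is on PYTHONPATH and that the dotted module name matches the file path in this diff (backends/cadence/aot/tests/test_memory_passes.py); both are assumptions, so adjust them to your checkout.

import unittest

# Load just the memory-pass tests touched by this PR and run them verbosely.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "backends.cadence.aot.tests.test_memory_passes"
)
unittest.TextTestRunner(verbosity=2).run(suite)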