Skip to content

Commit 8715897

Browse files
hsharma35 and facebook-github-bot
authored and committed
Cleanup memory passes tests. (#7788)
Summary: Add verifiers for memory allocation. Reviewed By: zonglinpeng, mcremon-meta Differential Revision: D68446633
1 parent 466d98f commit 8715897

File tree

1 file changed

+129
-1
lines changed

1 file changed

+129
-1
lines changed

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

3+
import logging
34
import math
45
import unittest
6+
from typing import cast
57

68
import executorch.backends.cadence.aot.ops_registrations # noqa
79
import torch
@@ -110,7 +112,89 @@ def forward(self, x):
110112

111113

112114
class TestMemTransform(unittest.TestCase):
113-
def test_optimize_cat(self):
115+
def verify_cat_nop_memory_alloc(self, cat_node: torch.fx.Node) -> None:
116+
spec = cat_node.meta.get("spec", None)
117+
self.assertIsNotNone(spec)
118+
dim: int = cast(int, cat_node.args[1]) if len(cat_node.args) > 1 else 0
119+
cat_outer_size = math.prod(spec.shape[:dim])
120+
self.assertEqual(
121+
cat_outer_size,
122+
1,
123+
f"{cat_node=} has wrong outer size: {cat_outer_size=}, expected 1.",
124+
)
125+
inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
126+
dim_offset = 0
127+
for arg in cast(list[torch.fx.Node], cat_node.args[0]):
128+
arg_spec = arg.meta.get("spec", None)
129+
self.assertEqual(arg_spec.mem_id, spec.mem_id)
130+
self.assertEqual(
131+
arg_spec.mem_offset,
132+
spec.mem_offset + dim_offset * inner_dim_elements,
133+
f"{arg=} for node {cat_node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
134+
)
135+
dim_offset += arg_spec.shape[dim]
136+
137+
def verify_slice_nop_memory_alloc(self, slice_node: torch.fx.Node) -> None:
138+
spec = slice_node.meta.get("spec", None)
139+
self.assertIsNotNone(spec)
140+
dim: int = cast(int, slice_node.args[1]) if len(slice_node.args) > 1 else 0
141+
cat_outer_size = math.prod(spec.shape[:dim])
142+
self.assertEqual(
143+
cat_outer_size,
144+
1,
145+
f"{slice_node=} has wrong outer size: {cat_outer_size=}, expected 1.",
146+
)
147+
inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
148+
start: int = (
149+
cast(int, slice_node.args[2])
150+
if (len(slice_node.args) > 2 and slice_node.args[2] is not None)
151+
else 0
152+
)
153+
arg = cast(torch.fx.Node, slice_node.args[0])
154+
arg_spec = arg.meta.get("spec", None)
155+
self.assertEqual(arg_spec.mem_id, spec.mem_id)
156+
self.assertEqual(
157+
spec.mem_offset,
158+
arg_spec.mem_offset + start * inner_dim_elements,
159+
f"{arg=} for node {slice_node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for cat on {dim=}, but output has {spec.mem_offset=}",
160+
)
161+
162+
def verify_nop_memory_alloc(self, graph_module):
    """Run the memory-layout checks on every nop cat and nop slice node.

    Looks up all `_cat_nop.out` call_function nodes in the graph and verifies
    each with `verify_cat_nop_memory_alloc`, then does the same for all
    `_slice_copy_nop.Tensor_out` nodes with `verify_slice_nop_memory_alloc`.
    A graph that contains neither kind of node passes trivially.
    """
    cat_nops = graph_module.graph.find_nodes(
        op="call_function", target=torch.ops.aten._cat_nop.out
    )
    for node in cat_nops:
        self.verify_cat_nop_memory_alloc(node)

    slice_nops = graph_module.graph.find_nodes(
        op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
    )
    for node in slice_nops:
        self.verify_slice_nop_memory_alloc(node)
172+
173+
def test_optimize_cat_on_placeholders(self):
    """A cat whose operands are both graph inputs is replaced by its nop version."""

    class Cat(torch.nn.Module):
        def forward(self, x, y):
            return torch.ops.aten.cat((x, y))

    example_inputs = (torch.ones(3, 6), torch.ones(2, 6))
    # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
    # pass to run:
    program_manager = compiler.export_to_executorch_gen_etrecord(
        Cat(), example_inputs, opt_level=2, mem_algo=1
    )
    graph_module = program_manager.exported_program().graph_module
    logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
    graph_module.graph.eliminate_dead_code()

    # The regular cat op must be gone ...
    remaining_cats = count_node(graph_module, torch.ops.aten.cat.out)
    self.assertEqual(remaining_cats, 0)
    # ... and exactly one nop cat must have taken its place.
    nop_cats = count_node(graph_module, torch.ops.aten._cat_nop.out)
    self.assertEqual(nop_cats, 1)
    self.verify_nop_memory_alloc(graph_module)
196+
197+
def test_optimize_cat_outermost(self):
114198
class OptimizeCatFeasible1(torch.nn.Module):
115199
def forward(self, x, y):
116200
x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +219,9 @@ def forward(self, x, y):
135219
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
136220
# Assert that cat op is replaced by its nop version post optimization
137221
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
222+
self.verify_nop_memory_alloc(graph_module)
138223

224+
def test_optimize_cat_non_outermost(self):
139225
class OptimizeCatFeasible2(torch.nn.Module):
140226
def forward(self, x, y):
141227
x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +246,9 @@ def forward(self, x, y):
160246
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
161247
# Assert that cat op is replaced by its nop version post optimization
162248
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
249+
self.verify_nop_memory_alloc(graph_module)
163250

251+
def test_no_optimize_cat_non_outermost(self):
164252
class OptimizeCatInfeasible1(torch.nn.Module):
165253
def forward(self, x, y):
166254
x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +272,9 @@ def forward(self, x, y):
184272
# Assert that cat op is not optimized away, since the concat is not
185273
# along the outermost dim
186274
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
275+
self.verify_nop_memory_alloc(graph_module)
187276

277+
def test_no_optimize_cat_non_outermost1(self):
188278
class OptimizeCatInfeasible2(torch.nn.Module):
189279
def forward(self, x, y):
190280
x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +299,7 @@ def forward(self, x, y):
209299
# offsets are not multiple of 8 bytes, and the cat is not the output
210300
# of the graph.
211301
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
302+
self.verify_nop_memory_alloc(graph_module)
212303

213304
def test_optimize_cat_with_slice(self):
214305
class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +328,7 @@ def forward(self, x):
237328
graph_module.graph.eliminate_dead_code()
238329
# Assert that cat op is optimized away
239330
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
331+
self.verify_nop_memory_alloc(graph_module)
240332

241333
def test_optimize_cat_with_slice_infeasible(self):
242334
class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +354,7 @@ def forward(self, x, y):
262354
graph_module.graph.eliminate_dead_code()
263355
# Assert that cat op is not optimized away
264356
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
357+
self.verify_nop_memory_alloc(graph_module)
265358

266359
def test_optimize_slice_Tensor(self):
267360
class SliceTensor(torch.nn.Module):
@@ -323,6 +416,7 @@ def forward(self, x, y, z):
323416
self.assertEqual(
324417
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
325418
)
419+
self.verify_nop_memory_alloc(graph_module)
326420

327421
def test_optimize_select_Tensor(self):
328422
class SelectTensor(torch.nn.Module):
@@ -387,6 +481,7 @@ def forward(self, x, y, z):
387481
self.assertEqual(
388482
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
389483
)
484+
self.verify_nop_memory_alloc(graph_module)
390485

391486
# TODO: Test fails due to memory planning
392487
@unittest.expectedFailure
@@ -416,6 +511,32 @@ def forward(self, x, y):
416511
graph_module.graph.eliminate_dead_code()
417512
# Assert that cat op is not optimized away
418513
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
514+
self.verify_nop_memory_alloc(graph_module)
515+
516+
def test_optimize_cat_then_slice_on_mutable_buffer(self):
    """A cat feeding a slice that is copied back into a mutable buffer becomes a nop cat."""

    class CatWithPadding(torch.nn.Module):
        def __init__(self, padding_shape):
            super().__init__()
            self.register_buffer("padding", torch.zeros(padding_shape))

        def forward(self, x, y):
            x = x.view(3, 5)
            cat = torch.ops.aten.cat((x, self.padding.clone()))
            slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
            self.padding.copy_(slice_copy)
            return cat.view(-1) + y

    example_inputs = (torch.ones(15), torch.ones(1))
    model = CatWithPadding((1, 5))
    et_prog_manager = compiler.export_to_executorch_gen_etrecord(
        model, example_inputs, opt_level=3
    )
    graph_module = et_prog_manager.exported_program().graph_module
    logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")

    # The regular cat op must be gone and replaced by exactly one nop cat.
    remaining_cats = count_node(graph_module, torch.ops.aten.cat.out)
    self.assertEqual(remaining_cats, 0)
    nop_cats = count_node(graph_module, torch.ops.aten._cat_nop.out)
    self.assertEqual(nop_cats, 1)
    self.verify_nop_memory_alloc(graph_module)
419540

420541
def test_optimize_cat_with_view(self):
421542
class CatViewFeasible(torch.nn.Module):
@@ -442,6 +563,7 @@ def forward(self, x, y):
442563
# Assert that cat op is optimized away
443564
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
444565
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
566+
self.verify_nop_memory_alloc(graph_module)
445567

446568
def test_no_optimize_cat_with_repeated_args(self):
447569
class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +587,7 @@ def forward(self, x):
465587
# Assert that cat op is not optimized away
466588
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
467589
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
590+
self.verify_nop_memory_alloc(graph_module)
468591

469592
def test_no_optimize_cat_with_placeholder(self):
470593
class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +615,7 @@ def forward(self, x, y):
492615
# Assert that cat op is not optimized away
493616
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
494617
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
618+
self.verify_nop_memory_alloc(graph_module)
495619

496620
def test_no_optimize_cat(self) -> None:
497621
class Model(torch.nn.Module):
@@ -522,6 +646,7 @@ def forward(self, x) -> torch.Tensor:
522646
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
523647
)
524648
self.assertEqual(count_node(graph_module, memory.view), 2)
649+
self.verify_nop_memory_alloc(graph_module)
525650

526651
def test_optimize_slice_copy(self) -> None:
527652
class Model(torch.nn.Module):
@@ -553,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
553678
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
554679
)
555680
self.assertEqual(count_node(graph_module, memory.view), 2)
681+
self.verify_nop_memory_alloc(graph_module)
556682

557683
def test_cat_then_cat(self) -> None:
558684
class Model(torch.nn.Module):
@@ -579,6 +705,7 @@ def forward(self, x) -> torch.Tensor:
579705
graph_module.print_readable()
580706
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
581707
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
708+
self.verify_nop_memory_alloc(graph_module)
582709

583710
def test_view_for_unallocated_output(self):
584711
class Model(torch.nn.Module):
@@ -602,3 +729,4 @@ def forward(self, x, y):
602729
.graph_module
603730
)
604731
self.assertEqual(count_node(graph_module, memory.view), 1)
732+
self.verify_nop_memory_alloc(graph_module)

0 commit comments

Comments
 (0)