Skip to content

Commit 77d10b0

Browse files
hsharma35 and facebook-github-bot
authored and committed
Cleanup memory passes tests. (#7788)
Summary: Add verifiers for memory allocation. Reviewed By: zonglinpeng, mcremon-meta Differential Revision: D68446633
1 parent 466d98f commit 77d10b0

File tree

1 file changed

+159
-1
lines changed

1 file changed

+159
-1
lines changed

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

3+
import logging
34
import math
45
import unittest
6+
from typing import cast
57

68
import executorch.backends.cadence.aot.ops_registrations # noqa
79
import torch
@@ -110,7 +112,119 @@ def forward(self, x):
110112

111113

112114
class TestMemTransform(unittest.TestCase):
113-
def test_optimize_cat(self):
115+
def verify_cat_nop_memory_alloc(self, cat_node: torch.fx.Node) -> None:
    """Assert that the inputs of a `_cat_nop` node are laid out inside its output.

    Reads the `spec` entry of `cat_node.meta` and of every input node's
    `meta` (presumably the memory-planning tensor spec -- confirm) and checks
    that:

    * all dimensions before the concatenation dim have a combined size of 1,
    * every input has the same `mem_id` as the output, and
    * every input's `mem_offset` equals the output's `mem_offset` plus the
      number of bytes occupied by the inputs that precede it.

    Args:
        cat_node: Node whose `args` are `(inputs, dim)`; `dim` defaults to 0
            when it is not present and may be negative.
    """
    spec = cat_node.meta.get("spec", None)
    self.assertIsNotNone(spec)
    dim: int = cast(int, cat_node.args[1]) if len(cat_node.args) > 1 else 0
    if dim < 0:
        # Negative dims count from the end; normalise so the slices below
        # select the intended outer/inner dimensions.
        dim += len(spec.shape)
    cat_outer_size = math.prod(spec.shape[:dim])
    self.assertEqual(
        cat_outer_size,
        1,
        f"{cat_node=} has wrong outer size: {cat_outer_size=}, expected 1.",
    )
    # Bytes covered by one step along `dim`.
    inner_dim_bytes = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
    dim_offset = 0
    for arg in cast(list[torch.fx.Node], cat_node.args[0]):
        arg_spec = arg.meta.get("spec", None)
        # Fail with a readable assertion instead of an AttributeError below.
        self.assertIsNotNone(arg_spec, f"{arg=} for node {cat_node=} has no spec.")
        self.assertEqual(arg_spec.mem_id, spec.mem_id)
        self.assertEqual(
            arg_spec.mem_offset,
            spec.mem_offset + dim_offset * inner_dim_bytes,
            f"{arg=} for node {cat_node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
        )
        dim_offset += arg_spec.shape[dim]
136+
137+
def verify_slice_nop_memory_alloc(self, slice_node: torch.fx.Node) -> None:
    """Assert that a `_slice_copy_nop` node's output aliases its input buffer.

    Reads the `spec` entry of `slice_node.meta` and of the input node's
    `meta` (presumably the memory-planning tensor spec -- confirm) and checks
    that:

    * all dimensions before the sliced dim have a combined size of 1,
    * the output has the same `mem_id` as the input, and
    * the output's `mem_offset` equals the input's `mem_offset` plus `start`
      steps along the sliced dim, measured in bytes.

    Args:
        slice_node: Node whose `args` are `(input, dim, start, ...)`; `dim`
            and `start` default to 0 when absent (or `None` for `start`) and
            may be negative.
    """
    spec = slice_node.meta.get("spec", None)
    self.assertIsNotNone(spec)
    arg = cast(torch.fx.Node, slice_node.args[0])
    arg_spec = arg.meta.get("spec", None)
    # Fail with a readable assertion instead of an AttributeError below.
    self.assertIsNotNone(arg_spec, f"{arg=} for node {slice_node=} has no spec.")

    dim: int = cast(int, slice_node.args[1]) if len(slice_node.args) > 1 else 0
    if dim < 0:
        # Negative dims count from the end.
        dim += len(spec.shape)
    slice_outer_size = math.prod(spec.shape[:dim])
    self.assertEqual(
        slice_outer_size,
        1,
        f"{slice_node=} has wrong outer size: {slice_outer_size=}, expected 1.",
    )
    # Bytes covered by one step along `dim`.
    inner_dim_bytes = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
    start: int = (
        cast(int, slice_node.args[2])
        if (len(slice_node.args) > 2 and slice_node.args[2] is not None)
        else 0
    )
    # Slice semantics: a negative start wraps around the input's size along
    # `dim`, and the result is clamped to [0, size].
    dim_size = arg_spec.shape[dim]
    if start < 0:
        start += dim_size
    start = min(max(start, 0), dim_size)

    self.assertEqual(arg_spec.mem_id, spec.mem_id)
    self.assertEqual(
        spec.mem_offset,
        arg_spec.mem_offset + start * inner_dim_bytes,
        f"{arg=} for node {slice_node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
    )
161+
162+
def verify_select_nop_memory_alloc(self, select_node: torch.fx.Node) -> None:
    """Assert that a `_select_copy_nop` node's output aliases its input buffer.

    Reads the `spec` entry of `select_node.meta` and of the input node's
    `meta` (presumably the memory-planning tensor spec -- confirm) and checks
    that:

    * all dimensions before the selected dim have a combined size of 1,
    * the output has the same `mem_id` as the input, and
    * the output's `mem_offset` equals the input's `mem_offset` plus `index`
      steps along the selected dim, measured in bytes.

    Args:
        select_node: Node whose `args` are `(input, dim, index)`; `dim` and
            `index` default to 0 when absent (or `None` for `index`) and may
            be negative.
    """
    spec = select_node.meta.get("spec", None)
    self.assertIsNotNone(spec)
    arg = cast(torch.fx.Node, select_node.args[0])
    arg_spec = arg.meta.get("spec", None)
    # Fail with a readable assertion instead of an AttributeError below.
    self.assertIsNotNone(arg_spec, f"{arg=} for node {select_node=} has no spec.")

    dim: int = cast(int, select_node.args[1]) if len(select_node.args) > 1 else 0
    if dim < 0:
        # Negative dims count from the end of the *input* shape.
        dim += len(arg_spec.shape)
    # Select drops `dim` from the output, so all sizes are taken from the
    # input shape, where `dim` still exists.
    select_outer_size = math.prod(arg_spec.shape[:dim])
    self.assertEqual(
        select_outer_size,
        1,
        f"{select_node=} has wrong outer size: {select_outer_size=}, expected 1.",
    )
    # Bytes covered by one step along `dim` of the input.
    inner_dim_bytes = math.prod(arg_spec.shape[dim + 1 :]) * spec.dtype.itemsize
    index: int = (
        cast(int, select_node.args[2])
        if (len(select_node.args) > 2 and select_node.args[2] is not None)
        else 0
    )
    if index < 0:
        # Negative indices count from the end of the selected dim.
        index += arg_spec.shape[dim]

    self.assertEqual(arg_spec.mem_id, spec.mem_id)
    self.assertEqual(
        spec.mem_offset,
        arg_spec.mem_offset + index * inner_dim_bytes,
        f"{arg=} for node {select_node=} has wrong memory offset: {arg_spec.mem_offset=} {index=} for select on {dim=}, but output has {spec.mem_offset=}",
    )

def verify_nop_memory_alloc(self, graph_module):
    """Run the matching layout verifier on every nop cat/slice/select node.

    Looks up `call_function` nodes in `graph_module.graph` by target and
    passes each one to the corresponding `verify_*_nop_memory_alloc` method.
    A graph without any such nodes passes trivially.
    """
    for cat_node in graph_module.graph.find_nodes(
        op="call_function", target=torch.ops.aten._cat_nop.out
    ):
        self.verify_cat_nop_memory_alloc(cat_node)

    for slice_node in graph_module.graph.find_nodes(
        op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
    ):
        self.verify_slice_nop_memory_alloc(slice_node)

    # The select nop op uses the `int_out` overload, the same one the select
    # tests in this file count.
    for select_node in graph_module.graph.find_nodes(
        op="call_function", target=torch.ops.aten._select_copy_nop.int_out
    ):
        self.verify_select_nop_memory_alloc(select_node)
202+
203+
def test_optimize_cat_on_placeholders(self):
    """A cat whose operands are both graph inputs is replaced by `_cat_nop`."""

    class Cat(torch.nn.Module):
        def forward(self, x, y):
            return torch.ops.aten.cat((x, y))

    example_inputs = (torch.ones(3, 6), torch.ones(2, 6))
    # The cat optimization only kicks in at opt_level 2 or higher, and it
    # needs the memory planning pass to have run.
    program_manager = compiler.export_to_executorch_gen_etrecord(
        Cat(), example_inputs, opt_level=2, mem_algo=1
    )
    graph_module = program_manager.exported_program().graph_module
    logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
    graph_module.graph.eliminate_dead_code()

    # No regular cat may survive the optimization ...
    remaining_cat_count = count_node(graph_module, torch.ops.aten.cat.out)
    self.assertEqual(remaining_cat_count, 0)
    # ... and exactly one nop cat must have taken its place.
    nop_cat_count = count_node(graph_module, torch.ops.aten._cat_nop.out)
    self.assertEqual(nop_cat_count, 1)
    self.verify_nop_memory_alloc(graph_module)
226+
227+
def test_optimize_cat_outermost(self):
114228
class OptimizeCatFeasible1(torch.nn.Module):
115229
def forward(self, x, y):
116230
x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +249,9 @@ def forward(self, x, y):
135249
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
136250
# Assert that cat op is replaced by its nop version post optimization
137251
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
252+
self.verify_nop_memory_alloc(graph_module)
138253

254+
def test_optimize_cat_non_outermost(self):
139255
class OptimizeCatFeasible2(torch.nn.Module):
140256
def forward(self, x, y):
141257
x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +276,9 @@ def forward(self, x, y):
160276
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
161277
# Assert that cat op is replaced by its nop version post optimization
162278
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
279+
self.verify_nop_memory_alloc(graph_module)
163280

281+
def test_no_optimize_cat_non_outermost(self):
164282
class OptimizeCatInfeasible1(torch.nn.Module):
165283
def forward(self, x, y):
166284
x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +302,9 @@ def forward(self, x, y):
184302
# Assert that cat op is not optimized away, since the concat is not
185303
# along the outermost dim
186304
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
305+
self.verify_nop_memory_alloc(graph_module)
187306

307+
def test_no_optimize_cat_non_outermost1(self):
188308
class OptimizeCatInfeasible2(torch.nn.Module):
189309
def forward(self, x, y):
190310
x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +329,7 @@ def forward(self, x, y):
209329
# offsets are not multiple of 8 bytes, and the cat is not the output
210330
# of the graph.
211331
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
332+
self.verify_nop_memory_alloc(graph_module)
212333

213334
def test_optimize_cat_with_slice(self):
214335
class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +358,7 @@ def forward(self, x):
237358
graph_module.graph.eliminate_dead_code()
238359
# Assert that cat op is optimized away
239360
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
361+
self.verify_nop_memory_alloc(graph_module)
240362

241363
def test_optimize_cat_with_slice_infeasible(self):
242364
class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +384,7 @@ def forward(self, x, y):
262384
graph_module.graph.eliminate_dead_code()
263385
# Assert that cat op is not optimized away
264386
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
387+
self.verify_nop_memory_alloc(graph_module)
265388

266389
def test_optimize_slice_Tensor(self):
267390
class SliceTensor(torch.nn.Module):
@@ -323,6 +446,7 @@ def forward(self, x, y, z):
323446
self.assertEqual(
324447
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
325448
)
449+
self.verify_nop_memory_alloc(graph_module)
326450

327451
def test_optimize_select_Tensor(self):
328452
class SelectTensor(torch.nn.Module):
@@ -387,6 +511,7 @@ def forward(self, x, y, z):
387511
self.assertEqual(
388512
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
389513
)
514+
self.verify_nop_memory_alloc(graph_module)
390515

391516
# TODO: Test fails due to memory planning
392517
@unittest.expectedFailure
@@ -416,6 +541,32 @@ def forward(self, x, y):
416541
graph_module.graph.eliminate_dead_code()
417542
# Assert that cat op is not optimized away
418543
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
544+
self.verify_nop_memory_alloc(graph_module)
545+
546+
def test_optimize_cat_then_slice_on_mutable_buffer(self):
    """A cat feeding a slice that is copied back into a buffer becomes `_cat_nop`."""

    class CatWithPadding(torch.nn.Module):
        def __init__(self, padding_shape):
            super().__init__()
            self.register_buffer("padding", torch.zeros(padding_shape))

        def forward(self, x, y):
            rows = x.view(3, 5)
            padded = torch.ops.aten.cat((rows, self.padding.clone()))
            # Everything after the original rows is written back into the
            # registered buffer.
            tail = torch.ops.aten.slice(padded, dim=0, start=rows.shape[0])
            self.padding.copy_(tail)
            return padded.view(-1) + y

    example_inputs = (torch.ones(15), torch.ones(1))
    graph_module = (
        compiler.export_to_executorch_gen_etrecord(
            CatWithPadding((1, 5)), example_inputs, opt_level=3
        )
        .exported_program()
        .graph_module
    )
    logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")

    # The regular cat must be gone ...
    remaining_cat_count = count_node(graph_module, torch.ops.aten.cat.out)
    self.assertEqual(remaining_cat_count, 0)
    # ... and replaced by exactly one nop cat.
    nop_cat_count = count_node(graph_module, torch.ops.aten._cat_nop.out)
    self.assertEqual(nop_cat_count, 1)
    self.verify_nop_memory_alloc(graph_module)
419570

420571
def test_optimize_cat_with_view(self):
421572
class CatViewFeasible(torch.nn.Module):
@@ -442,6 +593,7 @@ def forward(self, x, y):
442593
# Assert that cat op is optimized away
443594
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
444595
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
596+
self.verify_nop_memory_alloc(graph_module)
445597

446598
def test_no_optimize_cat_with_repeated_args(self):
447599
class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +617,7 @@ def forward(self, x):
465617
# Assert that cat op is not optimized away
466618
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
467619
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
620+
self.verify_nop_memory_alloc(graph_module)
468621

469622
def test_no_optimize_cat_with_placeholder(self):
470623
class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +645,7 @@ def forward(self, x, y):
492645
# Assert that cat op is not optimized away
493646
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
494647
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
648+
self.verify_nop_memory_alloc(graph_module)
495649

496650
def test_no_optimize_cat(self) -> None:
497651
class Model(torch.nn.Module):
@@ -522,6 +676,7 @@ def forward(self, x) -> torch.Tensor:
522676
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
523677
)
524678
self.assertEqual(count_node(graph_module, memory.view), 2)
679+
self.verify_nop_memory_alloc(graph_module)
525680

526681
def test_optimize_slice_copy(self) -> None:
527682
class Model(torch.nn.Module):
@@ -553,6 +708,7 @@ def forward(self, x) -> torch.Tensor:
553708
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
554709
)
555710
self.assertEqual(count_node(graph_module, memory.view), 2)
711+
self.verify_nop_memory_alloc(graph_module)
556712

557713
def test_cat_then_cat(self) -> None:
558714
class Model(torch.nn.Module):
@@ -579,6 +735,7 @@ def forward(self, x) -> torch.Tensor:
579735
graph_module.print_readable()
580736
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
581737
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
738+
self.verify_nop_memory_alloc(graph_module)
582739

583740
def test_view_for_unallocated_output(self):
584741
class Model(torch.nn.Module):
@@ -602,3 +759,4 @@ def forward(self, x, y):
602759
.graph_module
603760
)
604761
self.assertEqual(count_node(graph_module, memory.view), 1)
762+
self.verify_nop_memory_alloc(graph_module)

0 commit comments

Comments
 (0)