Use GraphBuilder in memory passes unit tests. #2 (#11292)

Merged: 1 commit, Jun 6, 2025
315 changes: 213 additions & 102 deletions backends/cadence/aot/tests/test_memory_passes.py
@@ -224,11 +224,13 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:

# Initializes the nodes metadata and runs the GenerateMemoryViewConstraints,
# GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
def run_memory_planning(self, original, alloc_graph_input=True) -> GraphModule:
def run_memory_planning(
self, original, opt_level=2, alloc_graph_input=True
) -> GraphModule:
graph_module = SpecPropPass().call(original).graph_module
return CadenceMemoryPlanning(
get_default_memory_config(),
opt_level=2,
opt_level=opt_level,
mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy
alloc_graph_input=alloc_graph_input,
)(graph_module).graph_module
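
The assertions in the hunks below count occurrences of specific op overloads with a count_node helper that is defined in the Cadence AOT utilities and is not shown in this diff. A minimal sketch of its presumed behavior, purely for reference (the signature and implementation here are assumptions, not part of this change):

# Hypothetical reconstruction of count_node; the real helper lives in the
# Cadence AOT utilities and is only imported by this test file.
from torch.fx import GraphModule

def count_node(graph_module: GraphModule, target) -> int:
    # Count call_function nodes whose target matches the given op overload.
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and node.target == target
    )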
@@ -535,130 +537,239 @@ def test_optimize_cat_with_slice_infeasible(self) -> None:
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_Tensor(self) -> None:
class SliceTensor(torch.nn.Module):
def forward(self, x, y, z):
x1 = torch.add(x, 2.4, 3.1)
# This slice should always be optimized, since x1 is not placeholder
# and the slice is along the outermost dim
t1 = torch.ops.aten.slice(x1, 0, 1, 2)
# This slice should not be optimized when alloc_graph_input=False,
# since y is a placeholder node
t2 = torch.ops.aten.slice(y, 0, 0, 1)
# This slice should be always optimized, since the dims before
# sliced dims are 1
z1 = torch.add(z, 2.4, 3.1)
t3 = torch.ops.aten.slice(z1, 1, 4, 5)
return (t1 + t2) * t3

x = torch.ones(3, 6)
y = torch.ones(2, 6)
z = torch.ones(1, 6)
# Run the memory planning pass and get the graph module
graph_module = (
compiler.export_to_executorch_gen_etrecord(
SliceTensor(),
(x, y, z),
opt_level=2,
mem_algo=1,
alloc_graph_input=False,
)
.exported_program()
.graph_module
def test_optimize_slice_outermost(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
to_add_to_x = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([3, 6], 123.0),
kwargs={"dtype": torch.float32},
)
add_x = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x, to_add_to_x),
)
slice_out = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 6], 0.0),
kwargs={"dtype": torch.float32},
)
# This slice should always be optimized, since add_x is not a placeholder
# and the slice is along the outermost dim
slice_result = builder.call_operator(
op=torch.ops.aten.slice_copy.Tensor_out,
args=(
add_x,
0, # dim
1, # start
2, # end
1, # step
),
kwargs={"out": slice_out},
)
builder.output([slice_result])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(original, alloc_graph_input=False)
graph_module.graph.eliminate_dead_code()
# Assert that t2 is not optimized away
self.assertEqual(
count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out), 1
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_non_outermost(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(1, 6, dtype=torch.float32))
to_add_to_x = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 6], 123.0),
kwargs={"dtype": torch.float32},
)
add_x = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x, to_add_to_x),
)
slice_out = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 2], 0.0),
kwargs={"dtype": torch.float32},
)
# This slice should always be optimized, since the dims before
# the sliced dim are all 1.
slice_result = builder.call_operator(
op=torch.ops.aten.slice_copy.Tensor_out,
args=(
add_x,
1, # dim
4, # start
6, # end
1, # step
),
kwargs={"out": slice_out},
)
# Assert that t1 and t3 are optimized to the slice_copy_nop version
builder.output([slice_result])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(original, alloc_graph_input=False)
graph_module.graph.eliminate_dead_code()
self.assertEqual(
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_depending_on_opt_level(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(2, 6, dtype=torch.float32))
slice_out = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 6], 0.0),
kwargs={"dtype": torch.float32},
)
# This slice should not be optimized when alloc_graph_input=False,
# since x is a placeholder node
slice_result = builder.call_operator(
op=torch.ops.aten.slice_copy.Tensor_out,
args=(
x,
0, # dim
0, # start
1, # end
1, # step
),
kwargs={"out": slice_out},
)
builder.output([slice_result])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(
original, opt_level=2, alloc_graph_input=False
)
graph_module.graph.eliminate_dead_code()
self.assertEqual(
count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out), 1
)
self.verify_nop_memory_alloc(graph_module)

# When we compile with alloc_graph_input=True, all the slice ops must
# be optimized.
# Optimizing cat ops is only at opt_level 2+, and requires the memory planning
# pass to run:
graph_module = (
compiler.export_to_executorch_gen_etrecord(
SliceTensor(),
(x, y, z),
opt_level=3,
mem_algo=1,
alloc_graph_input=True,
)
.exported_program()
.graph_module
# be optimized, which is available only at opt_level 2+.
graph_module = self.run_memory_planning(
original, opt_level=3, alloc_graph_input=True
)
graph_module.graph.eliminate_dead_code()
self.assertFalse(count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out))
self.assertEqual(
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_select_Tensor(self) -> None:
class SelectTensor(torch.nn.Module):
def forward(self, x, y, z):
x1 = torch.add(x, 2.4, 3.1)
# This select should always be optimized, since x1 is not
# placeholder, and the select is along the outermost dim
t1 = torch.select_copy(x1, 0, 1)
# This select should not be optimized if alloc_graph_input=False,
# since y is a placeholder node.
t2 = torch.select_copy(y, 0, 0)
# This select should always be optimized, since the dims before
# select dims are 1
z1 = torch.add(z, 2.4, 3.1)
t3 = torch.select(z1, 1, 4)
return (t1 + t2) * t3

x = torch.ones(3, 6)
y = torch.ones(2, 6)
z = torch.ones(1, 6)
# Optimizing select ops is only at opt_level 2+, and requires the memory planning
# pass to run:
graph_module = (
compiler.export_to_executorch_gen_etrecord(
SelectTensor(),
(x, y, z),
opt_level=2,
mem_algo=1,
alloc_graph_input=False,
)
.exported_program()
.graph_module
def test_optimize_select_outermost(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
to_add_to_x = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([3, 6], 123.0),
kwargs={"dtype": torch.float32},
)
add_x = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x, to_add_to_x),
)
slice_out = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 6], 0.0),
kwargs={"dtype": torch.float32},
)
# This select should always be optimized, since add_x is not a placeholder
# and the select is along the outermost dim
slice_result = builder.call_operator(
op=torch.ops.aten.select_copy.int_out,
args=(
add_x,
0, # dim
1, # index
),
kwargs={"out": slice_out},
)
builder.output([slice_result])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(original, alloc_graph_input=False)
graph_module.graph.eliminate_dead_code()
# Assert that t2 is not optimized away
self.assertEqual(
count_node(graph_module, torch.ops.aten.select_copy.int_out), 1
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_select_non_outermost(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(1, 6, dtype=torch.float32))
to_add_to_x = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 6], 123.0),
kwargs={"dtype": torch.float32},
)
add_x = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x, to_add_to_x),
)
slice_out = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 2], 0.0),
kwargs={"dtype": torch.float32},
)
# This select should always be optimized, since the dims before
# the selected dim are all 1
slice_result = builder.call_operator(
op=torch.ops.aten.select_copy.int_out,
args=(
add_x,
1, # dim
4, # index
),
kwargs={"out": slice_out},
)
# Assert that t1 and t3 are optimized to the select_copy_nop version
builder.output([slice_result])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(original, alloc_graph_input=False)
graph_module.graph.eliminate_dead_code()
self.assertEqual(
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 2
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
)
# When we compile with alloc_graph_input=True, all the select ops must
# be optimized.
# Optimizing select ops is only at opt_level 2+, and requires the memory planning
# pass to run:
graph_module = (
compiler.export_to_executorch_gen_etrecord(
SelectTensor(),
(x, y, z),
opt_level=3,
mem_algo=1,
alloc_graph_input=True,
)
.exported_program()
.graph_module
self.verify_nop_memory_alloc(graph_module)

def test_optimize_select_depending_on_opt_level(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(2, 6, dtype=torch.float32))
slice_out = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1, 6], 0.0),
kwargs={"dtype": torch.float32},
)
# This select should not be optimized if alloc_graph_input=False,
# since x is a placeholder node.
slice_result = builder.call_operator(
op=torch.ops.aten.select_copy.int_out,
args=(
x,
0, # dim
0, # index
),
kwargs={"out": slice_out},
)
builder.output([slice_result])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(
original, opt_level=2, alloc_graph_input=False
)
graph_module.graph.eliminate_dead_code()
self.assertEqual(
count_node(graph_module, torch.ops.aten.select_copy.int_out), 0
count_node(graph_module, torch.ops.aten.select_copy.int_out), 1
)
self.verify_nop_memory_alloc(graph_module)

# When we compile with alloc_graph_input=True, all the slice ops must
# be optimized, which is available only at opt_level 2+.
graph_module = self.run_memory_planning(
original, opt_level=3, alloc_graph_input=True
)
graph_module.graph.eliminate_dead_code()
self.assertEqual(
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
)
self.verify_nop_memory_alloc(graph_module)

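To run just the slice/select memory-planning tests touched by this change, an invocation along these lines should work from the ExecuTorch repo root (assuming pytest and the Cadence AOT dependencies are installed; adjust to your environment):

python -m pytest backends/cadence/aot/tests/test_memory_passes.py -k "optimize_slice or optimize_select" -v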