Use GraphBuilder in unit tests for ops removal. #11010

Merged
merged 1 commit on May 22, 2025
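This PR rewrites the ops-removal unit tests to construct edge-dialect graphs directly with GraphBuilder instead of exporting small nn.Module wrappers through export_to_edge. Below is a minimal sketch of the pattern the updated tests follow, assembled from the diff; the test class name, input shape, and the import paths for RemoveToOpsPass, count_node, PassResult, and exir_ops are assumptions for illustration, not part of this change:

    import unittest
    from typing import cast

    import executorch.backends.cadence.aot.ops_registrations  # noqa: register Cadence ops
    import torch
    from executorch.backends.cadence.aot.graph_builder import GraphBuilder
    from executorch.backends.cadence.aot.pass_utils import count_node  # assumed path
    from executorch.backends.cadence.aot.remove_ops import RemoveToOpsPass  # assumed path
    from executorch.exir.dialects._ops import ops as exir_ops  # assumed path
    from executorch.exir.pass_base import PassResult  # assumed path


    class ExampleRemoveToOpsTest(unittest.TestCase):
        def test_remove_to_ops(self) -> None:
            # Build the edge-dialect graph directly: placeholder -> aten.to.dtype -> output.
            builder = GraphBuilder()
            x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
            to_node = builder.call_operator(
                op=exir_ops.edge.aten.to.dtype,
                args=(x, torch.float32),
            )
            builder.output([to_node])
            original = builder.get_graph_module()

            # Run the pass under test and verify the redundant `to` op was removed.
            graph_after_passes = cast(
                PassResult, RemoveToOpsPass()(original)
            ).graph_module
            self.assertEqual(
                count_node(graph_after_passes, exir_ops.edge.aten.to.dtype), 0
            )

Each test in the diff repeats this shape: build the graph with GraphBuilder, run a single removal pass, and count the remaining nodes of the targeted op.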
264 changes: 99 additions & 165 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -13,8 +13,6 @@
import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
@@ -53,16 +51,15 @@ class TestRemoveOpsPasses(unittest.TestCase):
)
@torch.no_grad()
def test_remove_to_ops(self, shape: Tuple[int]):
class M(torch.nn.Module):
def forward(self, x: torch.Tensor):
return exir_ops.edge.aten.to(x, dtype=torch.float32)

model = M()
x = torch.randn(shape)
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
p = RemoveToOpsPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
x = builder.call_operator(
op=exir_ops.edge.aten.to.dtype,
args=(x, torch.float32),
)
builder.output([x])
original = builder.get_graph_module()
graph_after_passes = cast(PassResult, RemoveToOpsPass()(original)).graph_module

self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.to.dtype),
@@ -83,31 +80,24 @@ def forward(self, x: torch.Tensor):
)
@torch.no_grad()
def test_remove_nop_add_op_pass(self, shape: Tuple[int]):
class FullX(torch.nn.Module):
def forward(self, t: torch.Tensor):
return torch.add(torch.full(shape, 0), t)

class FullY(torch.nn.Module):
def forward(self, t: torch.Tensor):
return torch.add(t, torch.full(shape, 0))

model = FullX()
t = torch.full(shape, 3)
graph_module = export_to_edge(model, (t,)).exported_program().graph_module

p = RemoveNopAddOpPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
0,
)

model = FullY()
graph_module = export_to_edge(model, (t,)).exported_program().graph_module

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
zeros = builder.call_operator(
op=exir_ops.edge.aten.full.default, args=(shape, 0)
)
left_add = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(zeros, x),
)
right_add = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(left_add, zeros),
)
builder.output([right_add])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveNopAddOpPass()(original)
).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
0,
@@ -122,31 +112,24 @@ def forward(self, t: torch.Tensor):
)
@torch.no_grad()
def test_remove_nop_mul_op_pass(self, shape: Tuple[int]):
class FullX(torch.nn.Module):
def forward(self, t: torch.Tensor):
return torch.mul(torch.full(shape, 0), t)

class FullY(torch.nn.Module):
def forward(self, t: torch.Tensor):
return torch.mul(t, torch.full(shape, 0))

model = FullX()
t = torch.full(shape, 3)
graph_module = export_to_edge(model, (t,)).exported_program().graph_module

p = RemoveNopMulOpPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
0,
)

model = FullY()
graph_module = export_to_edge(model, (t,)).exported_program().graph_module

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
zeros = builder.call_operator(
op=exir_ops.edge.aten.full.default, args=(shape, 0)
)
left_mul = builder.call_operator(
op=exir_ops.edge.aten.mul.Tensor,
args=(zeros, x),
)
right_mul = builder.call_operator(
op=exir_ops.edge.aten.mul.Tensor,
args=(left_mul, zeros),
)
builder.output([right_mul])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveNopMulOpPass()(original)
).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
0,
@@ -159,18 +142,16 @@ def forward(self, t: torch.Tensor):
)
@torch.no_grad()
def test_remove_alias_copy(self, shape: Tuple[int]):
class M(torch.nn.Module):
def forward(self, x: torch.Tensor):
return exir_ops.edge.aten.alias_copy(x)

model = M()
x = torch.randn(shape)
graph_module = export_to_edge(model, (x,)).exported_program().graph_module

p = RemoveAliasCopyOpPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
alias = builder.call_operator(
op=exir_ops.edge.aten.alias_copy.default, args=(x,)
)
builder.output([alias])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveAliasCopyOpPass()(original)
).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default),
0,
@@ -183,19 +164,16 @@ def forward(self, x: torch.Tensor):
)
@torch.no_grad()
def test_remove_detach_copy(self, shape: Tuple[int]):
# aten::detach is converted to aten::alias_copy after functionalization & decomposition.
class M(torch.nn.Module):
def forward(self, x: torch.Tensor):
return exir_ops.edge.aten.detach_copy(x)

model = M()
x = torch.randn(shape)
graph_module = export_to_edge(model, (x,)).exported_program().graph_module

p = RemoveDetachCopyPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
detach = builder.call_operator(
op=exir_ops.edge.aten.detach_copy.default, args=(x,)
)
builder.output([detach])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveDetachCopyPass()(original)
).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default),
0,
@@ -210,95 +188,51 @@ def forward(self, x: torch.Tensor):
def test_remove_zero_sized_constant_pad_nd(
self, shape: Tuple[int], padding: Tuple[int]
):
# F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
class Padding(torch.nn.Module):
def __init__(self):
super().__init__()
self.padding = padding

def forward(self, x: torch.Tensor):
return F.pad(x, self.padding)

model = Padding()
x = torch.randn(shape)
graph_module = export_to_edge(model, (x,)).exported_program().graph_module

p = RemoveZeroSizedConstantPadNd()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
pad = builder.call_operator(
op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, padding)
)
builder.output([pad])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveZeroSizedConstantPadNd()(original)
).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
0,
)

def test_remove_expand(self):
class Expand(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.expand_copy(x, [2, 3, 5])

x = torch.ones(2, 3, 5)
p = RemoveNopExpandOpPass()
graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module
graph_module = p(graph_module).graph_module
# Assert that expand op is optimized away, since it is a nop
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn([2, 3, 5], dtype=torch.float32))
expand = builder.call_operator(
op=exir_ops.edge.aten.expand_copy.default, args=(x, [2, 3, 5])
)
builder.output([expand])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveNopExpandOpPass()(original)
).graph_module
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0
count_node(graph_after_passes, exir_ops.edge.aten.expand_copy.default), 0
)

def test_remove_zero_arg_cat(self):
class Cat(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.cat((x, y), 0)

x = torch.ones(1, 0, 3, 5)
y = torch.ones(2, 0, 3, 5)
graph_module = (
compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module
)
# Assert that cat op is optimized away, since it concatenates
# two zero-sized tensors
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

def test_remove_single_arg_cat(self):
class Cat(torch.nn.Module):
def forward(self, x, y):
z = torch.ones(0, 5)
# z is an empty tensor, and concatenation of x with z will
# be x. So we can safely eliminate the following cat op.
x1 = torch.ops.aten.cat((x, z))
x2 = torch.add(x1, 2.4, 3.1)
y1 = torch.add(y, 1, 2)
return torch.add(x2, y1)

x = torch.ones(3, 5)
y = torch.ones(3, 5)
graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module
new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
new_graph_module.graph.eliminate_dead_code()
# Assert that x1 is optimized away
self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)

def test_remove_zero_sized_cat(self):
class Cat(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim

def forward(self, tensors):
return torch.cat(tensors, self.dim)

shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127

in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes]

model = Cat(dim)
graph_module = (
export_to_edge(model, (in_tensors,)).exported_program().graph_module
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn([1, 0, 3, 5], dtype=torch.float32))
y = builder.placeholder("y", torch.randn([2, 0, 3, 5], dtype=torch.float32))
concat = builder.call_operator(
op=exir_ops.edge.aten.cat.default, args=([x, y], 0)
)
builder.output([concat])
original = builder.get_graph_module()
graph_after_passes = cast(
PassResult, RemoveZeroSizedCatArgsPass()(original)
).graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0
)
new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
new_graph_module.graph.eliminate_dead_code()
self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)

def test_remove_clone(self):
class Clone(torch.nn.Module):