Skip to content

: constant fold None #10762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions exir/passes/constant_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def is_const(
)
elif isinstance(arg, _PRIMITIVE_TYPES):
return True
elif arg is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was the problem before that const prop would skip nodes with None in the args?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think the issue is None is not a constant, and so if an op has args that include constants and None, it doesn't get constant folded.

return True
elif not isinstance(arg, torch.fx.Node):
return False
elif arg in const_node_to_tensor:
Expand Down
31 changes: 31 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,3 +1823,34 @@ def _do_checks(
self.assertTrue(
torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
)

def test_constant_prop_pass_none(self) -> None:
"""
This checks that None arguments are treated as constants in constant_prop_pass.
"""

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.cst = torch.ones(3, 3, 3, dtype=torch.int8)
self.w = torch.ones(3, 3, 3, dtype=torch.int8)

def forward(self, x):
# Note: using e.g aten.linear would not work as None is not in the graph
a = torch.ops.aten.convolution.default(
self.cst, self.w, None, [1], [0], [1], False, [0], 1
)
return a + x

mod = M()
x = torch.randn([3, 3, 3])
mod(x)
edge = to_edge(
export(mod, (x,), strict=True),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
# 2 constants: self.w and self.cst
self.assertEqual(2, len(edge.exported_program().constants))
pass_result = constant_prop_pass(edge.exported_program())
# 1 constant: a (= self.w @ self.cst)
self.assertEqual(1, len(pass_result.constants))