Skip to content

Commit eb21417

Browse files
ThomasJannaudfacebook-github-bot
authored andcommitted
: constant fold None
Summary: Constant folding should fold 'None' and consider it a constant This goes with D74350331 and D74349918 but keeping things separate Differential Revision: D74350331
1 parent b875a7a commit eb21417

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def is_const(
6161
)
6262
elif isinstance(arg, _PRIMITIVE_TYPES):
6363
return True
64+
elif arg is None:
65+
return True
6466
elif not isinstance(arg, torch.fx.Node):
6567
return False
6668
elif arg in const_node_to_tensor:

exir/tests/test_passes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,3 +1823,31 @@ def _do_checks(
18231823
self.assertTrue(
18241824
torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
18251825
)
1826+
1827+
def test_constant_prop_pass_none(self) -> None:
1828+
"""
1829+
This checks that None arguments are treated as constants in constant_prop_pass.
1830+
"""
1831+
class M(torch.nn.Module):
1832+
def __init__(self):
1833+
super().__init__()
1834+
self.cst = torch.ones(3 , 3, 3, dtype=torch.int8)
1835+
self.w = torch.ones(3, 3, 3, dtype=torch.int8)
1836+
1837+
def forward(self, x):
1838+
# Note: using e.g aten.linear would not work as None is not in the graph
1839+
a = torch.ops.aten.convolution.default(self.cst, self.w, None, [1], [0], [1], False, [0], 1)
1840+
return a+x
1841+
1842+
mod = M()
1843+
x = torch.randn([3, 3, 3])
1844+
mod(x)
1845+
edge = to_edge(
1846+
export(mod, (x,), strict=True),
1847+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1848+
)
1849+
# 2 constants: self.w and self.cst
1850+
self.assertEqual(2, len(edge.exported_program().constants))
1851+
pass_result = constant_prop_pass(edge.exported_program())
1852+
# 1 constant: a (= self.w @ self.cst)
1853+
self.assertEqual(1, len(pass_result.constants))

0 commit comments

Comments
 (0)