Skip to content

Commit 38e211f

Browse files
hsharma35 authored and facebook-github-bot committed
Do not constant prop for mutable buffers. (#7779)
Summary: Before this change, operations on mutable buffers would be constant-propagated and replaced with a constant tensor. This change avoids constant propagation for mutable buffers. Differential Revision: D68371513
1 parent 948fba6 commit 38e211f

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
get_buffer,
1717
get_lifted_tensor_constant,
1818
get_param,
19-
is_buffer,
2019
is_lifted_tensor_constant,
2120
is_param,
2221
)
@@ -77,6 +76,15 @@ def get_data(
7776
return const_node_to_tensor[arg]
7877
return None
7978

79+
def is_constant_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is a buffer placeholder that is never mutated.

    Only constant (non-mutated) buffers are safe to fold during constant
    propagation; a buffer that appears as a mutation output can change value
    between program invocations, so it must be kept as a placeholder.
    """
    buffer_fqn = program.graph_signature.inputs_to_buffers.get(node.target)
    if buffer_fqn is None:
        # Not a buffer input at all.
        return False
    # A buffer listed as a mutation output is mutable, hence not constant.
    mutated_fqns = program.graph_signature.buffers_to_mutate.values()
    return buffer_fqn not in mutated_fqns
87+
8088

8189
def get_constant_placeholder_dict(
8290
exported_program: ExportedProgram,
@@ -85,15 +93,12 @@ def get_constant_placeholder_dict(
8593
Returns a dictionary of placeholder node -> constant tensor.
8694
"""
8795
const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
88-
for node in exported_program.graph.nodes:
89-
if node.op != "placeholder":
90-
continue
91-
96+
for node in exported_program.graph.find_nodes(op="placeholder"):
9297
if is_param(exported_program, node):
9398
const_node_to_tensor[node] = cast(
9499
torch.Tensor, get_param(exported_program, node)
95100
)
96-
elif is_buffer(exported_program, node):
101+
elif is_constant_buffer(exported_program, node):
97102
const_node_to_tensor[node] = cast(
98103
torch.Tensor, get_buffer(exported_program, node)
99104
)

exir/tests/test_passes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,34 @@ def forward(self, x):
15941594
gm.code
15951595
)
15961596

1597+
def test_constant_prop_pass_for_mutable_buffers(self) -> None:
    """Constant prop must leave operations on mutable buffers untouched."""

    def num_add_nodes(graph_module: torch.fx.GraphModule) -> int:
        # Count aten.add.Tensor call_function nodes in the graph.
        nodes = graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.add.Tensor
        )
        return len(nodes)

    class MutableStateModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("state", torch.zeros(1))

        def forward(self, x):
            x = x + self.state
            # Add 1 (constant) to state.
            self.state.add_(1)
            return x

    exported = export(MutableStateModule(), (torch.zeros(1),), strict=True)
    edge_manager = to_edge(exported)
    # Two adds before the pass: `x + state` and the in-place `state.add_(1)`.
    self.assertEqual(num_add_nodes(edge_manager.exported_program().graph_module), 2)
    edge_manager._edge_programs["forward"] = constant_prop_pass(
        edge_manager._edge_programs["forward"]
    )
    # The mutable buffer must not be folded away: both adds survive the pass.
    self.assertEqual(num_add_nodes(edge_manager.exported_program().graph_module), 2)
1624+
15971625
def test_constant_prop_pass_for_no_grad(self) -> None:
15981626
class LSTM(torch.nn.Module):
15991627
def __init__(self, input_size, hidden_size, num_layers):

0 commit comments

Comments
 (0)