
Commit 89a24e0

Don't emit non-mutable weights
Differential Revision: D61888453
Pull Request resolved: #4938

File tree

2 files changed (+8, −3 lines)

exir/passes/weights_to_outputs_pass.py

Lines changed: 4 additions & 2 deletions
@@ -53,11 +53,13 @@ def weights_to_outputs_pass(
             break
     assert output_node is not None
 
-    # Get place holder nodes with gradients
+    # Get input nodes that are weights with an associated gradient
     placeholder_nodes = [
         node
         for node in gm.graph.nodes
-        if node.op == "placeholder" and node.target in inputs_to_params.keys()
+        if node.op == "placeholder"
+        and node.target in inputs_to_params.keys()
+        and inputs_to_params[node.target] in grad_targets
     ]
 
     # Flag these placeholder nodes as having a gradient attached so that memory planning will operate on them.
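A minimal, self-contained sketch of the filter added above (not the ExecuTorch pass itself): placeholders that map to a parameter are kept only when that parameter also appears among the gradient targets, so frozen weights are no longer emitted. The placeholder objects and the literal entries in `inputs_to_params` and `grad_targets` below are made-up stand-ins that only mirror the variable names in the diff.

from types import SimpleNamespace

# Hypothetical stand-ins for graph placeholders and the pass' lookup tables.
placeholders = [SimpleNamespace(op="placeholder", target=t) for t in ("p_w1", "p_w2", "x")]
inputs_to_params = {"p_w1": "linear.weight", "p_w2": "frozen.weight"}  # input name -> param FQN
grad_targets = {"linear.weight"}  # params that actually receive a gradient

# Same shape as the updated list comprehension: the frozen weight
# ("frozen.weight") is filtered out because no gradient is produced for it.
placeholder_nodes = [
    node
    for node in placeholders
    if node.op == "placeholder"
    and node.target in inputs_to_params
    and inputs_to_params[node.target] in grad_targets
]

print([n.target for n in placeholder_nodes])  # ['p_w1']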

exir/tests/test_joint_graph.py

Lines changed: 4 additions & 1 deletion
@@ -26,10 +26,13 @@ class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
+                self.linear_no_train = torch.nn.Linear(3, 3)
+                for param in self.linear_no_train.parameters():
+                    param.requires_grad = False
                 self.loss = torch.nn.CrossEntropyLoss()
 
             def forward(self, x, y):
-                return self.loss(self.linear(x).softmax(dim=0), y)
+                return self.loss(self.linear_no_train(self.linear(x)).softmax(dim=0), y)
 
         m = Module()
         example_inputs = (torch.ones(3), torch.tensor([1.0, 0.0, 0.0]))
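A quick sketch, in plain PyTorch outside the ExecuTorch export flow, of the behavior the updated test exercises: a submodule frozen with requires_grad=False receives no gradient after backward, which is why the joint graph should not emit an output for its weights. Shapes and layer names mirror the test module above; the standalone layers here are just for illustration.

import torch

linear = torch.nn.Linear(3, 3)
linear_no_train = torch.nn.Linear(3, 3)
for param in linear_no_train.parameters():
    param.requires_grad = False  # freeze this layer, as in the test

loss_fn = torch.nn.CrossEntropyLoss()
x, y = torch.ones(3), torch.tensor([1.0, 0.0, 0.0])

loss = loss_fn(linear_no_train(linear(x)).softmax(dim=0), y)
loss.backward()

print(linear.weight.grad is None)           # False: trainable weight has a gradient
print(linear_no_train.weight.grad is None)  # True: frozen weight has none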
