
Commit 224902c

perfreddan80 authored and committed
Set requires_grad to avoid check for differentiable

The node.args that come from nn.Parameters in the state_dict sometimes have the requires_grad property set to True. This seems to stem from the export stage already, and .eval() does not change the parameter. Address it here in the pass for now.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: Ie7425764f1f1865de0fc66e1d020f804c7e936b1
1 parent 99d5b80 commit 224902c
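The behaviour the commit message describes can be reproduced with plain torch; this sketch (the `Conv2d` module is an arbitrary stand-in, not code from the commit) shows that `.eval()` only toggles training mode and leaves `requires_grad` untouched on the parameters, which is why the pass has to clear it explicitly:

```python
import torch

# Hedged sketch, not the ExecuTorch pass itself: eval() flips the
# training flag but does not clear requires_grad on parameters.
m = torch.nn.Conv2d(1, 1, 3)
m.eval()

print(m.training)  # False: eval mode is set
print(all(p.requires_grad for p in m.parameters()))  # True: grads still requested
```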

File tree

1 file changed

+16
-0
lines changed


backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -98,6 +98,22 @@ def call(self, graph_module: GraphModule) -> PassResult:
             for i, arg in enumerate(n.args):
                 if not isinstance(arg, Node):
                     continue
+
+                # Make sure arg has requires_grad set to False
+                # For parameters that are not quantized, sometimes (i.e. convolution)
+                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
+                # causes the retracing of the graph to fail with:
+                #
+                # E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+                # E
+                # E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+                # E Original traceback:
+                # E   File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+                # E     x = conv(x)
+                #
+                if arg.op == "placeholder":
+                    arg.meta["val"].requires_grad = False
+
                 if arg.target != dq_op:
                     continue
```
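In isolation, the added guard amounts to the following sketch on a plain `torch.fx` graph. The module and the way an example value is attached to `meta["val"]` are assumptions for illustration (torch.export records FakeTensors there; a plain tensor stands in for one here):

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(M())

# Attach an example value to each placeholder, loosely mimicking what
# export stores in node.meta["val"]; requires_grad leaks through as True.
for n in gm.graph.nodes:
    if n.op == "placeholder":
        n.meta["val"] = torch.randn(2, requires_grad=True)

# The fix from the diff: for every node argument that is a placeholder,
# force requires_grad to False on its recorded example value.
for n in gm.graph.nodes:
    for arg in n.args:
        if isinstance(arg, torch.fx.Node) and arg.op == "placeholder":
            arg.meta["val"].requires_grad = False

print(all(not n.meta["val"].requires_grad
          for n in gm.graph.nodes if n.op == "placeholder"))  # True
```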

0 commit comments