
Commit 1819da5

ydwu4 authored and facebook-github-bot committed
Forward fix NluStellaCapExportTest
Summary: Forward fix NluStellaCapExportTest. Currently, to_backend works with a GraphModule and doesn't work well with an ExportedProgram, so we need to return the graph_module contained in the ExportedProgram to fix the test.

Separately, an assertion error was raised in call_delegate_autograd when it found inputs that require grad. In this diff, we delay that error until .backward() is called, which will probably never happen since we're in inference mode; that is fine because we only care about inference for call_delegate. I checked how the NLU models generate inputs and found that none of them require grad, so I suspect some transformations don't respect the requires_grad=False property. That deserves a separate investigation.

Reviewed By: angelayi

Differential Revision: D46995565

fbshipit-source-id: 3d735843b8db0fe2c003a74cb9895c68b6cdac51
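For context, here is a minimal sketch of the unwrapping described above, using today's torch.export API as a stand-in for the capture pipeline this commit actually ran against; the Add1 module, the example input, and the commented-out to_backend call site are hypothetical illustrations, not code from this commit.

    import torch
    from torch.export import export

    class Add1(torch.nn.Module):  # hypothetical toy module
        def forward(self, x):
            return x + 1

    # export() returns an ExportedProgram. Since to_backend works with a
    # GraphModule rather than an ExportedProgram, unwrap before lowering:
    exported_program = export(Add1(), (torch.randn(2),))
    graph_module = exported_program.graph_module  # a torch.fx.GraphModule
    # to_backend(graph_module, ...)  # hypothetical call site, not shown here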
1 parent 2fc4e7d commit 1819da5

File tree

1 file changed (+25 lines, -5 lines)

exir/delegate.py

Lines changed: 25 additions & 5 deletions
@@ -152,14 +152,34 @@ def call_delegate_cpu(lowered_module, *args):
 def call_delegate_autograd(lowered_module, *args):
     # TODO: support autograd
     flat_operands, _ = tree_flatten([lowered_module, *args])
-    assert all(
-        [not f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)]
+    requires_grad = any(
+        [f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)]
     )
 
-    _ = torch._C.ExcludeDispatchKeyGuard(
+    with torch._C._ExcludeDispatchKeyGuard(
         torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
-    )
-    return executorch_call_delegate(lowered_module, *args)
+    ):
+        res = executorch_call_delegate(lowered_module, *args)
+
+        if requires_grad:
+            err_fn = torch._C._functions.DelayedError(
+                b"NYI: call_delegate doesn't support autograd",
+                1,
+            )
+            # Create aliases of the output that has requires_grad=True. We need
+            # at least one of the inputs to err_fn to require grad so that the
+            # output will have a grad_fn.
+
+            # pyre-ignore
+            def fake_requires_grad(var):
+                if var is not None:
+                    var = var.detach()
+                    var.requires_grad = True
+                return err_fn(var)
+
+            return pytree.tree_map(fake_requires_grad, res)
+
+        return res
 
 
 @executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
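To see the delayed-error behavior in isolation, here is a minimal sketch of the DelayedError pattern the diff relies on; the tensor x and the variable names are illustrative, but the DelayedError construction mirrors the code above.

    import torch

    # DelayedError(msg, num_inputs) builds an autograd function whose outputs
    # alias its inputs; the error fires only if backward() ever reaches it.
    err_fn = torch._C._functions.DelayedError(
        b"NYI: call_delegate doesn't support autograd", 1
    )

    x = torch.randn(3).detach()
    x.requires_grad = True  # mimics fake_requires_grad above
    y = err_fn(x)           # y aliases x and now carries a grad_fn

    # The forward result is usable as-is; only differentiating through it
    # raises: y.sum().backward() -> RuntimeError with the NYI message.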
