Skip to content

Commit 3acbbb1

Browse files
yanboliangfacebook-github-bot
authored andcommitted
Should automatically pop modes
Summary: Fixes pytorch/pytorch#108282 X-link: pytorch/pytorch#109157 Reviewed By: zou3519 Differential Revision: D49359181 Pulled By: yanboliang fbshipit-source-id: f91d4aec67db980f3c66b7e0232106b89e34aef3
1 parent 0f3d42f commit 3acbbb1

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

exir/delegate.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,16 @@ def fake_requires_grad(var):
114114

115115
@executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
116116
# pyre-ignore
117-
def call_delegate_proxy_torch_dispatch_mode(lowered_module, *args):
118-
mode = _get_current_dispatch_mode()
119-
assert mode is not None, "Mode should always be enabled for python fallback key"
120-
with _pop_mode_temporarily() as mode:
121-
res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
117+
def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
118+
res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
122119
return res
123120

124121

125122
@executorch_call_delegate.py_impl(FakeTensorMode)
126123
# pyre-ignore
127-
def call_delegate_fake_tensor_mode(lowered_module, *args):
128-
return lowered_module.original_module(*args)
124+
def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
125+
with mode:
126+
return lowered_module.original_module(*args)
129127

130128

131129
@executorch_call_delegate.py_impl(torch._C.DispatchKey.Functionalize)

0 commit comments

Comments
 (0)