Skip to content

Commit b21b165

Browse files
bdhirsh
authored and facebook-github-bot committed
internal fixes for FunctionalTensorMode usage in AOTAutograd (#538)
Summary: Fixes needed to properly land pytorch/pytorch#110079 internally (1) executorch has a higher order op that requires a functionalization rule (2) s-curve export still has an internal flow that calls some AOTAutograd API's, but also manually makes some calls to C++ functionalization. I changed them to use python functionalization. Reviewed By: zou3519, tugsbayasgalan Differential Revision: D49657241
1 parent 84b333d commit b21b165

File tree

1 file changed

+5
-36
lines changed

1 file changed

+5
-36
lines changed

exir/delegate.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -123,44 +123,13 @@ def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
123123
return lowered_module.original_module(*args)
124124

125125

126-
@executorch_call_delegate.py_impl(torch._C.DispatchKey.Functionalize)
126+
@executorch_call_delegate.py_functionalize_impl
127127
# pyre-ignore
128-
def call_delegate_func(lowered_module, *args):
129-
reapply_views = torch._C._functionalization_reapply_views_tls()
130-
# At this point, we will see functionalized tensors, so need to unwrap them first
131-
unwrapped_args = tuple(
132-
_unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views)
133-
for arg in args
134-
)
135-
guard = torch._C.ExcludeDispatchKeyGuard(
136-
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
137-
)
138-
try:
139-
delegate_return = executorch_call_delegate(lowered_module, *unwrapped_args)
140-
return _wrap_all_tensors_to_functional(delegate_return, level=0)
141-
finally:
142-
del guard
143-
144-
145-
# pyre-ignore
146-
@executorch_call_delegate.py_impl(torch._C._functorch.TransformType.Functionalize)
147-
# pyre-ignore
148-
def call_delegate_functionalize(interpreter, lowered_module, *args):
149-
"""
150-
Functionalization implementation for torch.ops.executorch_call_delegate. We
151-
don't need to do anything since the delegated program is controlled by
152-
users.
153-
"""
154-
reapply_views = interpreter.functionalize_add_back_views()
155-
# At this point, we will see functionalized tensors, so need to unwrap them first
156-
unwrapped_args = tuple(
157-
_unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views)
158-
for arg in args
159-
)
160-
161-
with interpreter.lower():
128+
def call_delegate_functionalize(ctx, lowered_module, *args):
129+
unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
130+
with ctx.redispatch_to_next():
162131
res = executorch_call_delegate(lowered_module, *unwrapped_args)
163-
return _wrap_all_tensors_to_functional(res, level=interpreter.level())
132+
return ctx.wrap_tensors(res)
164133

165134

166135
# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre

0 commit comments

Comments (0)