Skip to content

support emit sym value from delegate #3103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 66 additions & 8 deletions exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
TensorSpec,
)
from executorch.exir.types import LeafValueSpec, ValueSpec
from torch._subclasses.fake_tensor import FakeTensor

from torch.export.exported_program import ExportedProgram
from torch.utils import _pytree as pytree
Expand Down Expand Up @@ -933,6 +934,35 @@ def _emit_argument(
return arg
return self._emit_evalue(self._constant_to_evalue(arg, arg_type))

def _get_sym_ret(
self,
val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
) -> Optional[_AbstractValue]:
"""
Returns the emit ret for sym value.
"""
ret = None
if isinstance(val, torch.SymInt):
ret = self._emit_evalue(EValue(Int(0)))
elif isinstance(val, torch.BoolType):
ret = self._emit_evalue(EValue(Bool(False)))
elif isinstance(val, torch.FloatType):
ret = self._emit_evalue(EValue(Double(0)))
return ret

def _get_sym_and_fake_tensor_ret(
self,
val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
spec: TensorSpec,
) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]:
# Try to get the ret if it's a sym value.
ret = self._get_sym_ret(val)
# If the ret is None, it means that the val is not a sym value, but a regular tensor
if ret is None:
ret = self._emit_spec(spec)
assert ret is not None, "Can't have a None ret"
return ret

def _emit_delegate(
self,
lowered_module: "LoweredBackendModule", # noqa
Expand All @@ -944,7 +974,40 @@ def _emit_delegate(
processed_bytes = lowered_module.processed_bytes

delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
delegate_ret = self._emit_spec(self.node.meta["spec"])
delegate_ret = None

if isinstance(self.node.meta["spec"], list):
delegate_ret = []
for index, _ in enumerate(self.node.meta["val"]):
ret = self._get_sym_and_fake_tensor_ret(
self.node.meta["val"][index], self.node.meta["spec"][index]
)
delegate_ret.append(ret)
elif isinstance(self.node.meta["spec"], tuple):
if isinstance(self.node.meta["val"], FakeTensor):
# There is a case when node.meta["spec"] is (TensorSpec, ) while node.meta["val"] is FakeTensor
ret = self._get_sym_and_fake_tensor_ret(
self.node.meta["val"], self.node.meta["spec"][0]
)
delegate_ret = (ret,)
else:
delegate_ret = []
for index, _ in enumerate(self.node.meta["val"]):
ret = self._get_sym_and_fake_tensor_ret(
self.node.meta["val"][index], self.node.meta["spec"][index]
)
delegate_ret.append(ret)
delegate_ret = tuple(delegate_ret)
elif isinstance(self.node.meta["spec"], TensorSpec):
ret = self._get_sym_and_fake_tensor_ret(
self.node.meta["val"], self.node.meta["spec"]
)
delegate_ret = ret
else:
raise NotImplementedError(
f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported"
)
assert delegate_ret is not None, "Can't have a None delegate_ret"
if delegate_index is None:
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
# present.
Expand Down Expand Up @@ -1062,13 +1125,8 @@ def _get_empty_tensor_evalue() -> EValue:
torch.BoolType,
torch.NumberType,
), f"Only symbolic ops that return a Int Bool Float are supported currently got {type(target._schema.returns[0].type)}."
if type(target._schema.returns[0].type) == torch.IntType:
ret = self._emit_evalue(EValue(Int(0)))
elif type(target._schema.returns[0].type) == torch.BoolType:
ret = self._emit_evalue(EValue(Bool(False)))
elif type(target._schema.returns[0].type) == torch.FloatType:
ret = self._emit_evalue(EValue(Double(0)))
else: # type(target._schema.returns[0].type) == torch.NumberType:
ret = self._get_sym_ret(target._schema.returns[0])
if ret is None: # type(target._schema.returns[0].type) == torch.NumberType:
# Cant definitively say what type this is, the runtime operator just overrides the EValue completely
# though so we can just serialize whatever as a placeholder.
ret = self._emit_evalue(EValue(Int(0)))
Expand Down