Skip to content

Commit 73599f4

Browse files
cccclaifacebook-github-bot
authored andcommitted
support emit sym value from delegate (#3103)
Summary: Pull Request resolved: #3103 For dynamic shape, if delegate output is dynamic shape, the return might be something like `(s0, x, y)`, and `s0` is a sym type while others are fake tensor. In this case, we will emit the sym value (including `SymFloat`, `SymBool`, `SymInt`) to a unique Evalue. Since the sym type node will have an empty spec, we use the `node.meta['val']` to find out it's a sym type node. Reviewed By: mcr229 Differential Revision: D56176100 fbshipit-source-id: a4ddc7225ed014c59ceb9fa8ba4a9cb394af00e5
1 parent ebc38b2 commit 73599f4

File tree

1 file changed

+66
-8
lines changed

1 file changed

+66
-8
lines changed

exir/emit/_emitter.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
TensorSpec,
8888
)
8989
from executorch.exir.types import LeafValueSpec, ValueSpec
90+
from torch._subclasses.fake_tensor import FakeTensor
9091

9192
from torch.export.exported_program import ExportedProgram
9293
from torch.utils import _pytree as pytree
@@ -933,6 +934,35 @@ def _emit_argument(
933934
return arg
934935
return self._emit_evalue(self._constant_to_evalue(arg, arg_type))
935936

937+
def _get_sym_ret(
938+
self,
939+
val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
940+
) -> Optional[_AbstractValue]:
941+
"""
942+
Returns the emit ret for sym value.
943+
"""
944+
ret = None
945+
if isinstance(val, torch.SymInt):
946+
ret = self._emit_evalue(EValue(Int(0)))
947+
elif isinstance(val, torch.BoolType):
948+
ret = self._emit_evalue(EValue(Bool(False)))
949+
elif isinstance(val, torch.FloatType):
950+
ret = self._emit_evalue(EValue(Double(0)))
951+
return ret
952+
953+
def _get_sym_and_fake_tensor_ret(
954+
self,
955+
val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
956+
spec: TensorSpec,
957+
) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]:
958+
# Try to get the ret if it's a sym value.
959+
ret = self._get_sym_ret(val)
960+
# If the ret is None, it means that the val is not a sym value, but a regular tensor
961+
if ret is None:
962+
ret = self._emit_spec(spec)
963+
assert ret is not None, "Can't have a None ret"
964+
return ret
965+
936966
def _emit_delegate(
937967
self,
938968
lowered_module: "LoweredBackendModule", # noqa
@@ -944,7 +974,40 @@ def _emit_delegate(
944974
processed_bytes = lowered_module.processed_bytes
945975

946976
delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
947-
delegate_ret = self._emit_spec(self.node.meta["spec"])
977+
delegate_ret = None
978+
979+
if isinstance(self.node.meta["spec"], list):
980+
delegate_ret = []
981+
for index, _ in enumerate(self.node.meta["val"]):
982+
ret = self._get_sym_and_fake_tensor_ret(
983+
self.node.meta["val"][index], self.node.meta["spec"][index]
984+
)
985+
delegate_ret.append(ret)
986+
elif isinstance(self.node.meta["spec"], tuple):
987+
if isinstance(self.node.meta["val"], FakeTensor):
988+
# There is a case when node.meta["spec"] is (TensorSpec, ) while node.meta["val"] is FakeTensor
989+
ret = self._get_sym_and_fake_tensor_ret(
990+
self.node.meta["val"], self.node.meta["spec"][0]
991+
)
992+
delegate_ret = (ret,)
993+
else:
994+
delegate_ret = []
995+
for index, _ in enumerate(self.node.meta["val"]):
996+
ret = self._get_sym_and_fake_tensor_ret(
997+
self.node.meta["val"][index], self.node.meta["spec"][index]
998+
)
999+
delegate_ret.append(ret)
1000+
delegate_ret = tuple(delegate_ret)
1001+
elif isinstance(self.node.meta["spec"], TensorSpec):
1002+
ret = self._get_sym_and_fake_tensor_ret(
1003+
self.node.meta["val"], self.node.meta["spec"]
1004+
)
1005+
delegate_ret = ret
1006+
else:
1007+
raise NotImplementedError(
1008+
f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported"
1009+
)
1010+
assert delegate_ret is not None, "Can't have a None delegate_ret"
9481011
if delegate_index is None:
9491012
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
9501013
# present.
@@ -1062,13 +1125,8 @@ def _get_empty_tensor_evalue() -> EValue:
10621125
torch.BoolType,
10631126
torch.NumberType,
10641127
), f"Only symbolic ops that return a Int Bool Float are supported currently got {type(target._schema.returns[0].type)}."
1065-
if type(target._schema.returns[0].type) == torch.IntType:
1066-
ret = self._emit_evalue(EValue(Int(0)))
1067-
elif type(target._schema.returns[0].type) == torch.BoolType:
1068-
ret = self._emit_evalue(EValue(Bool(False)))
1069-
elif type(target._schema.returns[0].type) == torch.FloatType:
1070-
ret = self._emit_evalue(EValue(Double(0)))
1071-
else: # type(target._schema.returns[0].type) == torch.NumberType:
1128+
ret = self._get_sym_ret(target._schema.returns[0])
1129+
if ret is None: # type(target._schema.returns[0].type) == torch.NumberType:
10721130
# Cant definitively say what type this is, the runtime operator just overrides the EValue completely
10731131
# though so we can just serialize whatever as a placeholder.
10741132
ret = self._emit_evalue(EValue(Int(0)))

0 commit comments

Comments
 (0)