87
87
TensorSpec ,
88
88
)
89
89
from executorch .exir .types import LeafValueSpec , ValueSpec
90
+ from torch ._subclasses .fake_tensor import FakeTensor
90
91
91
92
from torch .export .exported_program import ExportedProgram
92
93
from torch .utils import _pytree as pytree
@@ -933,6 +934,35 @@ def _emit_argument(
933
934
return arg
934
935
return self ._emit_evalue (self ._constant_to_evalue (arg , arg_type ))
935
936
937
+ def _get_sym_ret (
938
+ self ,
939
+ val : Tuple [Union [torch .SymInt , torch .BoolType , torch .FloatType , FakeTensor ]],
940
+ ) -> Optional [_AbstractValue ]:
941
+ """
942
+ Returns the emit ret for sym value.
943
+ """
944
+ ret = None
945
+ if isinstance (val , torch .SymInt ):
946
+ ret = self ._emit_evalue (EValue (Int (0 )))
947
+ elif isinstance (val , torch .BoolType ):
948
+ ret = self ._emit_evalue (EValue (Bool (False )))
949
+ elif isinstance (val , torch .FloatType ):
950
+ ret = self ._emit_evalue (EValue (Double (0 )))
951
+ return ret
952
+
953
+ def _get_sym_and_fake_tensor_ret (
954
+ self ,
955
+ val : Tuple [Union [torch .SymInt , torch .BoolType , torch .FloatType , FakeTensor ]],
956
+ spec : TensorSpec ,
957
+ ) -> Union [List [_AbstractValue ], _AbstractValue , Tuple [_AbstractValue , ...]]:
958
+ # Try to get the ret if it's a sym value.
959
+ ret = self ._get_sym_ret (val )
960
+ # If the ret is None, it means that the val is not a sym value, but a regular tensor
961
+ if ret is None :
962
+ ret = self ._emit_spec (spec )
963
+ assert ret is not None , "Can't have a None ret"
964
+ return ret
965
+
936
966
def _emit_delegate (
937
967
self ,
938
968
lowered_module : "LoweredBackendModule" , # noqa
@@ -944,7 +974,40 @@ def _emit_delegate(
944
974
processed_bytes = lowered_module .processed_bytes
945
975
946
976
delegate_index = self .emitter_state .delegate_cache .get (processed_bytes )
947
- delegate_ret = self ._emit_spec (self .node .meta ["spec" ])
977
+ delegate_ret = None
978
+
979
+ if isinstance (self .node .meta ["spec" ], list ):
980
+ delegate_ret = []
981
+ for index , _ in enumerate (self .node .meta ["val" ]):
982
+ ret = self ._get_sym_and_fake_tensor_ret (
983
+ self .node .meta ["val" ][index ], self .node .meta ["spec" ][index ]
984
+ )
985
+ delegate_ret .append (ret )
986
+ elif isinstance (self .node .meta ["spec" ], tuple ):
987
+ if isinstance (self .node .meta ["val" ], FakeTensor ):
988
+ # There is a case when node.meta["spec"] is (TensorSpec, ) while node.meta["val"] is FakeTensor
989
+ ret = self ._get_sym_and_fake_tensor_ret (
990
+ self .node .meta ["val" ], self .node .meta ["spec" ][0 ]
991
+ )
992
+ delegate_ret = (ret ,)
993
+ else :
994
+ delegate_ret = []
995
+ for index , _ in enumerate (self .node .meta ["val" ]):
996
+ ret = self ._get_sym_and_fake_tensor_ret (
997
+ self .node .meta ["val" ][index ], self .node .meta ["spec" ][index ]
998
+ )
999
+ delegate_ret .append (ret )
1000
+ delegate_ret = tuple (delegate_ret )
1001
+ elif isinstance (self .node .meta ["spec" ], TensorSpec ):
1002
+ ret = self ._get_sym_and_fake_tensor_ret (
1003
+ self .node .meta ["val" ], self .node .meta ["spec" ]
1004
+ )
1005
+ delegate_ret = ret
1006
+ else :
1007
+ raise NotImplementedError (
1008
+ f"self.node.meta['spec'] { type (self .node .meta ['spec' ])} is not supported"
1009
+ )
1010
+ assert delegate_ret is not None , "Can't have a None delegate_ret"
948
1011
if delegate_index is None :
949
1012
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
950
1013
# present.
@@ -1062,13 +1125,8 @@ def _get_empty_tensor_evalue() -> EValue:
1062
1125
torch .BoolType ,
1063
1126
torch .NumberType ,
1064
1127
), f"Only symbolic ops that return a Int Bool Float are supported currently got { type (target ._schema .returns [0 ].type )} ."
1065
- if type (target ._schema .returns [0 ].type ) == torch .IntType :
1066
- ret = self ._emit_evalue (EValue (Int (0 )))
1067
- elif type (target ._schema .returns [0 ].type ) == torch .BoolType :
1068
- ret = self ._emit_evalue (EValue (Bool (False )))
1069
- elif type (target ._schema .returns [0 ].type ) == torch .FloatType :
1070
- ret = self ._emit_evalue (EValue (Double (0 )))
1071
- else : # type(target._schema.returns[0].type) == torch.NumberType:
1128
+ ret = self ._get_sym_ret (target ._schema .returns [0 ])
1129
+ if ret is None : # type(target._schema.returns[0].type) == torch.NumberType:
1072
1130
# Cant definitively say what type this is, the runtime operator just overrides the EValue completely
1073
1131
# though so we can just serialize whatever as a placeholder.
1074
1132
ret = self ._emit_evalue (EValue (Int (0 )))
0 commit comments