@@ -188,6 +188,67 @@ class AbstractionPattern {
188
188
// / The partially-applied curried imported type of a C++ method. OrigType is
189
189
// / valid and is a function type. CXXMethod is valid.
190
190
PartialCurriedCXXMethodType,
191
+ // SWIFT_ENABLE_TENSORFLOW
192
+ // / A Swift function whose parameters and results are opaque. This is
193
+ // / like `AP::Type<T>((T) -> T)`, except that the number of parameters is
194
+ // / unspecified.
195
+ // /
196
+ // / This is used to construct the abstraction pattern for the
197
+ // / derivative function of a function with opaque abstraction pattern. See
198
+ // / `OpaqueDerivativeFunction`.
199
+ OpaqueFunction,
200
+ // / A Swift function whose parameters are opaque and whose result is the
201
+ // / tuple abstraction pattern `(AP::Opaque, AP::OpaqueFunction)`.
202
+ // /
203
+ // / Purpose: when we reabstract `@differentiable` function-typed values
204
+ // / using the`AP::Opaque` pattern, we use `AP::Opaque` to reabstract the
205
+ // / original function in the bundle and `AP::OpaqueDerivativeFunction` to
206
+ // / reabstract the derivative functions in the bundle. This preserves the
207
+ // / `@differentiable` function invariant that the derivative type
208
+ // / (`SILFunctionType::getAutoDiffDerivativeFunctionType()`) of the original
209
+ // / function is equal to the type of the derivative function. For example:
210
+ // /
211
+ // / differentiable_function
212
+ // / [parameters 0]
213
+ // / %0 : $@callee_guaranteed (Float) -> Float
214
+ // / with_derivative {
215
+ // / %1 : $@callee_guaranteed (Float) -> (
216
+ // / Float,
217
+ // / @owned @callee_guaranteed (Float) -> Float
218
+ // / ),
219
+ // / %2 : $@callee_guaranteed (Float) -> (
220
+ // / Float,
221
+ // / @owned @callee_guaranteed (Float) -> Float
222
+ // / )
223
+ // / }
224
+ // /
225
+ // / The invariant-respecting abstraction of this value to `AP::Opaque` is:
226
+ // /
227
+ // / differentiable_function
228
+ // / [parameters 0]
229
+ // / %3 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
230
+ // / with_derivative {
231
+ // / %4 : $@callee_guaranteed (@in_guaranteed Float) -> (
232
+ // / @out Float,
233
+ // / @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
234
+ // / ),
235
+ // / %5 : $@callee_guaranteed (@in_guaranteed Float) -> (
236
+ // / @out Float,
237
+ // / @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
238
+ // / )
239
+ // / }
240
+ // /
241
+ // / In particular:
242
+ // /
243
+ // / - The reabstraction %0 => %3 uses pattern `AP::Opaque`.
244
+ // / - The reabstraction %1 => %4 uses pattern
245
+ // / `AP::OpaqueDerivativeFunction`, which maximally abstracts all the
246
+ // / parameters, and abstracts the result as the tuple
247
+ // / `(AP::Opaque, AP::OpaqueFunction)`.
248
+ // / - The reabstraction %2 => %5 similarly uses pattern
249
+ // / `AP::OpaqueDerivativeFunction`.
250
+ OpaqueDerivativeFunction,
251
+ // SWIFT_ENABLE_TENSORFLOW END
191
252
};
192
253
193
254
class EncodedForeignErrorInfo {
@@ -238,7 +299,9 @@ class AbstractionPattern {
238
299
static constexpr const unsigned NumOtherDataBits = 28 ;
239
300
static constexpr const unsigned MaxOtherData = (1 << NumOtherDataBits) - 1 ;
240
301
241
- unsigned TheKind : 32 - NumOtherDataBits;
302
+ // SWIFT_ENABLE_TENSORFLOW
303
+ unsigned TheKind : 33 - NumOtherDataBits;
304
+ // SWIFT_ENABLE_TENSORFLOW END
242
305
unsigned OtherData : NumOtherDataBits;
243
306
CanType OrigType;
244
307
union {
@@ -382,6 +445,16 @@ class AbstractionPattern {
382
445
return AbstractionPattern (Kind::Invalid);
383
446
}
384
447
448
+ // SWIFT_ENABLE_TENSORFLOW
449
+ static AbstractionPattern getOpaqueFunction () {
450
+ return AbstractionPattern (Kind::OpaqueFunction);
451
+ }
452
+
453
+ static AbstractionPattern getOpaqueDerivativeFunction () {
454
+ return AbstractionPattern (Kind::OpaqueDerivativeFunction);
455
+ }
456
+ // SWIFT_ENABLE_TENSORFLOW END
457
+
385
458
bool hasGenericSignature () const {
386
459
switch (getKind ()) {
387
460
case Kind::Type:
@@ -400,6 +473,10 @@ class AbstractionPattern {
400
473
case Kind::Invalid:
401
474
case Kind::Opaque:
402
475
case Kind::Tuple:
476
+ // SWIFT_ENABLE_TENSORFLOW
477
+ case Kind::OpaqueFunction:
478
+ case Kind::OpaqueDerivativeFunction:
479
+ // SWIFT_ENABLE_TENSORFLOW END
403
480
return false ;
404
481
}
405
482
llvm_unreachable (" Unhandled AbstractionPatternKind in switch" );
@@ -728,6 +805,12 @@ class AbstractionPattern {
728
805
llvm_unreachable (" opaque pattern has no type" );
729
806
case Kind::Tuple:
730
807
llvm_unreachable (" open-coded tuple pattern has no type" );
808
+ // SWIFT_ENABLE_TENSORFLOW
809
+ case Kind::OpaqueFunction:
810
+ llvm_unreachable (" opaque function pattern has no type" );
811
+ case Kind::OpaqueDerivativeFunction:
812
+ llvm_unreachable (" opaque derivative function pattern has no type" );
813
+ // SWIFT_ENABLE_TENSORFLOW END
731
814
case Kind::ClangType:
732
815
case Kind::CurriedObjCMethodType:
733
816
case Kind::PartialCurriedObjCMethodType:
@@ -761,6 +844,10 @@ class AbstractionPattern {
761
844
case Kind::Invalid:
762
845
case Kind::Opaque:
763
846
case Kind::Tuple:
847
+ // SWIFT_ENABLE_TENSORFLOW
848
+ case Kind::OpaqueFunction:
849
+ case Kind::OpaqueDerivativeFunction:
850
+ // SWIFT_ENABLE_TENSORFLOW END
764
851
llvm_unreachable (" type cannot be replaced on pattern without type" );
765
852
case Kind::ClangType:
766
853
case Kind::CurriedObjCMethodType:
@@ -796,6 +883,10 @@ class AbstractionPattern {
796
883
case Kind::Tuple:
797
884
case Kind::Type:
798
885
case Kind::Discard:
886
+ // SWIFT_ENABLE_TENSORFLOW
887
+ case Kind::OpaqueFunction:
888
+ case Kind::OpaqueDerivativeFunction:
889
+ // SWIFT_ENABLE_TENSORFLOW END
799
890
return false ;
800
891
case Kind::ClangType:
801
892
case Kind::PartialCurriedObjCMethodType:
@@ -852,6 +943,13 @@ class AbstractionPattern {
852
943
return CXXMethod;
853
944
}
854
945
946
+ // SWIFT_ENABLE_TENSORFLOW
947
+ bool isOpaqueFunctionOrOpaqueDerivativeFunction () const {
948
+ return (getKind () == Kind::OpaqueFunction ||
949
+ getKind () == Kind::OpaqueDerivativeFunction);
950
+ }
951
+ // SWIFT_ENABLE_TENSORFLOW END
952
+
855
953
EncodedForeignErrorInfo getEncodedForeignErrorInfo () const {
856
954
assert (hasStoredForeignErrorInfo ());
857
955
return EncodedForeignErrorInfo::fromOpaqueValue (OtherData);
@@ -874,6 +972,10 @@ class AbstractionPattern {
874
972
case Kind::CXXMethodType:
875
973
case Kind::CurriedCXXMethodType:
876
974
case Kind::PartialCurriedCXXMethodType:
975
+ // SWIFT_ENABLE_TENSORFLOW
976
+ case Kind::OpaqueFunction:
977
+ case Kind::OpaqueDerivativeFunction:
978
+ // SWIFT_ENABLE_TENSORFLOW END
877
979
return false ;
878
980
case Kind::PartialCurriedObjCMethodType:
879
981
case Kind::CurriedObjCMethodType:
@@ -895,6 +997,11 @@ class AbstractionPattern {
895
997
return typename CanTypeWrapperTraits<TYPE>::type ();
896
998
case Kind::Tuple:
897
999
return typename CanTypeWrapperTraits<TYPE>::type ();
1000
+ // SWIFT_ENABLE_TENSORFLOW
1001
+ case Kind::OpaqueFunction:
1002
+ case Kind::OpaqueDerivativeFunction:
1003
+ // SWIFT_ENABLE_TENSORFLOW END
1004
+ return typename CanTypeWrapperTraits<TYPE>::type ();
898
1005
case Kind::ClangType:
899
1006
case Kind::PartialCurriedObjCMethodType:
900
1007
case Kind::CurriedObjCMethodType:
@@ -933,6 +1040,10 @@ class AbstractionPattern {
933
1040
case Kind::CXXMethodType:
934
1041
case Kind::CurriedCXXMethodType:
935
1042
case Kind::PartialCurriedCXXMethodType:
1043
+ // SWIFT_ENABLE_TENSORFLOW
1044
+ case Kind::OpaqueFunction:
1045
+ case Kind::OpaqueDerivativeFunction:
1046
+ // SWIFT_ENABLE_TENSORFLOW END
936
1047
// We assume that the Clang type might provide additional structure.
937
1048
return false ;
938
1049
case Kind::Type:
@@ -960,6 +1071,10 @@ class AbstractionPattern {
960
1071
case Kind::CXXMethodType:
961
1072
case Kind::CurriedCXXMethodType:
962
1073
case Kind::PartialCurriedCXXMethodType:
1074
+ // SWIFT_ENABLE_TENSORFLOW
1075
+ case Kind::OpaqueFunction:
1076
+ case Kind::OpaqueDerivativeFunction:
1077
+ // SWIFT_ENABLE_TENSORFLOW END
963
1078
return false ;
964
1079
case Kind::Tuple:
965
1080
return true ;
@@ -985,6 +1100,10 @@ class AbstractionPattern {
985
1100
case Kind::CXXMethodType:
986
1101
case Kind::CurriedCXXMethodType:
987
1102
case Kind::PartialCurriedCXXMethodType:
1103
+ // SWIFT_ENABLE_TENSORFLOW
1104
+ case Kind::OpaqueFunction:
1105
+ case Kind::OpaqueDerivativeFunction:
1106
+ // SWIFT_ENABLE_TENSORFLOW END
988
1107
llvm_unreachable (" pattern is not a tuple" );
989
1108
case Kind::Tuple:
990
1109
return getNumTupleElements_Stored ();
@@ -1020,6 +1139,20 @@ class AbstractionPattern {
1020
1139
// / it.
1021
1140
AbstractionPattern getReferenceStorageReferentType () const ;
1022
1141
1142
+ // SWIFT_ENABLE_TENSORFLOW
1143
+ // / Given that the value being abstracted is a function type, return the
1144
+ // / abstraction pattern for the derivative function.
1145
+ // /
1146
+ // / The arguments are the same as the arguments to
1147
+ // / `AnyFunctionType::getAutoDiffDerivativeFunctionType()`.
1148
+ AbstractionPattern getAutoDiffDerivativeFunctionType (
1149
+ IndexSubset *indices, unsigned resultIndex,
1150
+ AutoDiffDerivativeFunctionKind kind,
1151
+ LookupConformanceFn lookupConformance,
1152
+ GenericSignature derivativeGenericSignature = GenericSignature(),
1153
+ bool makeSelfParamFirst = false);
1154
+ // SWIFT_ENABLE_TENSORFLOW END
1155
+
1023
1156
void dump () const LLVM_ATTRIBUTE_USED;
1024
1157
void print (raw_ostream &OS) const ;
1025
1158
};
0 commit comments