Skip to content

Commit 534d72c

Browse files
author
marcrasi
authored
[AutoDiff] TF-123: fix reabstraction to opaque (#29394)
Add new abstraction patterns to support properly abstracting `@differentiable` function values to `AP::Opaque`: - `AP::OpaqueFunction` - `AP::OpaqueDerivativeFunction` Comments in `AbstractionPattern.h` describe the solution in more detail. Resolves TF-123. Even though TF-123 is technically fixed (all its test cases compile and run successfully), many simple operations involving opaque `@differentiable` function values are still broken due to a separate issue TF-1122. Reabstraction tests are also added: - The last 3 tests in `reabstraction_e2e.swift` are the ones that have been fixed by this PR. - The last 3 tests include some commented-out code that currently segfaults at runtime due to TF-1122. - All the tests in `reabstraction_e2e.swift` except for the last 3 already pass before this PR. I just used this PR as an opportunity to test reabstraction more thoroughly.
1 parent 9ee45c4 commit 534d72c

File tree

8 files changed

+410
-30
lines changed

8 files changed

+410
-30
lines changed

include/swift/SIL/AbstractionPattern.h

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,67 @@ class AbstractionPattern {
188188
/// The partially-applied curried imported type of a C++ method. OrigType is
189189
/// valid and is a function type. CXXMethod is valid.
190190
PartialCurriedCXXMethodType,
191+
// SWIFT_ENABLE_TENSORFLOW
192+
/// A Swift function whose parameters and results are opaque. This is
193+
/// like `AP::Type<T>((T) -> T)`, except that the number of parameters is
194+
/// unspecified.
195+
///
196+
/// This is used to construct the abstraction pattern for the
197+
/// derivative function of a function with opaque abstraction pattern. See
198+
/// `OpaqueDerivativeFunction`.
199+
OpaqueFunction,
200+
/// A Swift function whose parameters are opaque and whose result is the
201+
/// tuple abstraction pattern `(AP::Opaque, AP::OpaqueFunction)`.
202+
///
203+
/// Purpose: when we reabstract `@differentiable` function-typed values
204+
/// using the`AP::Opaque` pattern, we use `AP::Opaque` to reabstract the
205+
/// original function in the bundle and `AP::OpaqueDerivativeFunction` to
206+
/// reabstract the derivative functions in the bundle. This preserves the
207+
/// `@differentiable` function invariant that the derivative type
208+
/// (`SILFunctionType::getAutoDiffDerivativeFunctionType()`) of the original
209+
/// function is equal to the type of the derivative function. For example:
210+
///
211+
/// differentiable_function
212+
/// [parameters 0]
213+
/// %0 : $@callee_guaranteed (Float) -> Float
214+
/// with_derivative {
215+
/// %1 : $@callee_guaranteed (Float) -> (
216+
/// Float,
217+
/// @owned @callee_guaranteed (Float) -> Float
218+
/// ),
219+
/// %2 : $@callee_guaranteed (Float) -> (
220+
/// Float,
221+
/// @owned @callee_guaranteed (Float) -> Float
222+
/// )
223+
/// }
224+
///
225+
/// The invariant-respecting abstraction of this value to `AP::Opaque` is:
226+
///
227+
/// differentiable_function
228+
/// [parameters 0]
229+
/// %3 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
230+
/// with_derivative {
231+
/// %4 : $@callee_guaranteed (@in_guaranteed Float) -> (
232+
/// @out Float,
233+
/// @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
234+
/// ),
235+
/// %5 : $@callee_guaranteed (@in_guaranteed Float) -> (
236+
/// @out Float,
237+
/// @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
238+
/// )
239+
/// }
240+
///
241+
/// In particular:
242+
///
243+
/// - The reabstraction %0 => %3 uses pattern `AP::Opaque`.
244+
/// - The reabstraction %1 => %4 uses pattern
245+
/// `AP::OpaqueDerivativeFunction`, which maximally abstracts all the
246+
/// parameters, and abstracts the result as the tuple
247+
/// `(AP::Opaque, AP::OpaqueFunction)`.
248+
/// - The reabstraction %2 => %5 similarly uses pattern
249+
/// `AP::OpaqueDerivativeFunction`.
250+
OpaqueDerivativeFunction,
251+
// SWIFT_ENABLE_TENSORFLOW END
191252
};
192253

193254
class EncodedForeignErrorInfo {
@@ -238,7 +299,9 @@ class AbstractionPattern {
238299
static constexpr const unsigned NumOtherDataBits = 28;
239300
static constexpr const unsigned MaxOtherData = (1 << NumOtherDataBits) - 1;
240301

241-
unsigned TheKind : 32 - NumOtherDataBits;
302+
// SWIFT_ENABLE_TENSORFLOW
303+
unsigned TheKind : 33 - NumOtherDataBits;
304+
// SWIFT_ENABLE_TENSORFLOW END
242305
unsigned OtherData : NumOtherDataBits;
243306
CanType OrigType;
244307
union {
@@ -382,6 +445,16 @@ class AbstractionPattern {
382445
return AbstractionPattern(Kind::Invalid);
383446
}
384447

448+
// SWIFT_ENABLE_TENSORFLOW
449+
static AbstractionPattern getOpaqueFunction() {
450+
return AbstractionPattern(Kind::OpaqueFunction);
451+
}
452+
453+
static AbstractionPattern getOpaqueDerivativeFunction() {
454+
return AbstractionPattern(Kind::OpaqueDerivativeFunction);
455+
}
456+
// SWIFT_ENABLE_TENSORFLOW END
457+
385458
bool hasGenericSignature() const {
386459
switch (getKind()) {
387460
case Kind::Type:
@@ -400,6 +473,10 @@ class AbstractionPattern {
400473
case Kind::Invalid:
401474
case Kind::Opaque:
402475
case Kind::Tuple:
476+
// SWIFT_ENABLE_TENSORFLOW
477+
case Kind::OpaqueFunction:
478+
case Kind::OpaqueDerivativeFunction:
479+
// SWIFT_ENABLE_TENSORFLOW END
403480
return false;
404481
}
405482
llvm_unreachable("Unhandled AbstractionPatternKind in switch");
@@ -728,6 +805,12 @@ class AbstractionPattern {
728805
llvm_unreachable("opaque pattern has no type");
729806
case Kind::Tuple:
730807
llvm_unreachable("open-coded tuple pattern has no type");
808+
// SWIFT_ENABLE_TENSORFLOW
809+
case Kind::OpaqueFunction:
810+
llvm_unreachable("opaque function pattern has no type");
811+
case Kind::OpaqueDerivativeFunction:
812+
llvm_unreachable("opaque derivative function pattern has no type");
813+
// SWIFT_ENABLE_TENSORFLOW END
731814
case Kind::ClangType:
732815
case Kind::CurriedObjCMethodType:
733816
case Kind::PartialCurriedObjCMethodType:
@@ -761,6 +844,10 @@ class AbstractionPattern {
761844
case Kind::Invalid:
762845
case Kind::Opaque:
763846
case Kind::Tuple:
847+
// SWIFT_ENABLE_TENSORFLOW
848+
case Kind::OpaqueFunction:
849+
case Kind::OpaqueDerivativeFunction:
850+
// SWIFT_ENABLE_TENSORFLOW END
764851
llvm_unreachable("type cannot be replaced on pattern without type");
765852
case Kind::ClangType:
766853
case Kind::CurriedObjCMethodType:
@@ -796,6 +883,10 @@ class AbstractionPattern {
796883
case Kind::Tuple:
797884
case Kind::Type:
798885
case Kind::Discard:
886+
// SWIFT_ENABLE_TENSORFLOW
887+
case Kind::OpaqueFunction:
888+
case Kind::OpaqueDerivativeFunction:
889+
// SWIFT_ENABLE_TENSORFLOW END
799890
return false;
800891
case Kind::ClangType:
801892
case Kind::PartialCurriedObjCMethodType:
@@ -852,6 +943,13 @@ class AbstractionPattern {
852943
return CXXMethod;
853944
}
854945

946+
// SWIFT_ENABLE_TENSORFLOW
947+
bool isOpaqueFunctionOrOpaqueDerivativeFunction() const {
948+
return (getKind() == Kind::OpaqueFunction ||
949+
getKind() == Kind::OpaqueDerivativeFunction);
950+
}
951+
// SWIFT_ENABLE_TENSORFLOW END
952+
855953
EncodedForeignErrorInfo getEncodedForeignErrorInfo() const {
856954
assert(hasStoredForeignErrorInfo());
857955
return EncodedForeignErrorInfo::fromOpaqueValue(OtherData);
@@ -874,6 +972,10 @@ class AbstractionPattern {
874972
case Kind::CXXMethodType:
875973
case Kind::CurriedCXXMethodType:
876974
case Kind::PartialCurriedCXXMethodType:
975+
// SWIFT_ENABLE_TENSORFLOW
976+
case Kind::OpaqueFunction:
977+
case Kind::OpaqueDerivativeFunction:
978+
// SWIFT_ENABLE_TENSORFLOW END
877979
return false;
878980
case Kind::PartialCurriedObjCMethodType:
879981
case Kind::CurriedObjCMethodType:
@@ -895,6 +997,11 @@ class AbstractionPattern {
895997
return typename CanTypeWrapperTraits<TYPE>::type();
896998
case Kind::Tuple:
897999
return typename CanTypeWrapperTraits<TYPE>::type();
1000+
// SWIFT_ENABLE_TENSORFLOW
1001+
case Kind::OpaqueFunction:
1002+
case Kind::OpaqueDerivativeFunction:
1003+
// SWIFT_ENABLE_TENSORFLOW END
1004+
return typename CanTypeWrapperTraits<TYPE>::type();
8981005
case Kind::ClangType:
8991006
case Kind::PartialCurriedObjCMethodType:
9001007
case Kind::CurriedObjCMethodType:
@@ -933,6 +1040,10 @@ class AbstractionPattern {
9331040
case Kind::CXXMethodType:
9341041
case Kind::CurriedCXXMethodType:
9351042
case Kind::PartialCurriedCXXMethodType:
1043+
// SWIFT_ENABLE_TENSORFLOW
1044+
case Kind::OpaqueFunction:
1045+
case Kind::OpaqueDerivativeFunction:
1046+
// SWIFT_ENABLE_TENSORFLOW END
9361047
// We assume that the Clang type might provide additional structure.
9371048
return false;
9381049
case Kind::Type:
@@ -960,6 +1071,10 @@ class AbstractionPattern {
9601071
case Kind::CXXMethodType:
9611072
case Kind::CurriedCXXMethodType:
9621073
case Kind::PartialCurriedCXXMethodType:
1074+
// SWIFT_ENABLE_TENSORFLOW
1075+
case Kind::OpaqueFunction:
1076+
case Kind::OpaqueDerivativeFunction:
1077+
// SWIFT_ENABLE_TENSORFLOW END
9631078
return false;
9641079
case Kind::Tuple:
9651080
return true;
@@ -985,6 +1100,10 @@ class AbstractionPattern {
9851100
case Kind::CXXMethodType:
9861101
case Kind::CurriedCXXMethodType:
9871102
case Kind::PartialCurriedCXXMethodType:
1103+
// SWIFT_ENABLE_TENSORFLOW
1104+
case Kind::OpaqueFunction:
1105+
case Kind::OpaqueDerivativeFunction:
1106+
// SWIFT_ENABLE_TENSORFLOW END
9881107
llvm_unreachable("pattern is not a tuple");
9891108
case Kind::Tuple:
9901109
return getNumTupleElements_Stored();
@@ -1020,6 +1139,20 @@ class AbstractionPattern {
10201139
/// it.
10211140
AbstractionPattern getReferenceStorageReferentType() const;
10221141

1142+
// SWIFT_ENABLE_TENSORFLOW
1143+
/// Given that the value being abstracted is a function type, return the
1144+
/// abstraction pattern for the derivative function.
1145+
///
1146+
/// The arguments are the same as the arguments to
1147+
/// `AnyFunctionType::getAutoDiffDerivativeFunctionType()`.
1148+
AbstractionPattern getAutoDiffDerivativeFunctionType(
1149+
IndexSubset *indices, unsigned resultIndex,
1150+
AutoDiffDerivativeFunctionKind kind,
1151+
LookupConformanceFn lookupConformance,
1152+
GenericSignature derivativeGenericSignature = GenericSignature(),
1153+
bool makeSelfParamFirst = false);
1154+
// SWIFT_ENABLE_TENSORFLOW END
1155+
10231156
void dump() const LLVM_ATTRIBUTE_USED;
10241157
void print(raw_ostream &OS) const;
10251158
};

lib/SIL/AbstractionPattern.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ AbstractionPattern::getOptional(AbstractionPattern object) {
176176
case Kind::CXXMethodType:
177177
case Kind::CurriedCXXMethodType:
178178
case Kind::PartialCurriedCXXMethodType:
179+
// SWIFT_ENABLE_TENSORFLOW
180+
case Kind::OpaqueFunction:
181+
case Kind::OpaqueDerivativeFunction:
182+
// SWIFT_ENABLE_TENSORFLOW END
179183
llvm_unreachable("cannot add optionality to non-type abstraction");
180184
case Kind::Opaque:
181185
return AbstractionPattern::getOpaque();
@@ -267,6 +271,10 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) {
267271
case Kind::CXXMethodType:
268272
case Kind::CurriedCXXMethodType:
269273
case Kind::PartialCurriedCXXMethodType:
274+
// SWIFT_ENABLE_TENSORFLOW
275+
case Kind::OpaqueFunction:
276+
case Kind::OpaqueDerivativeFunction:
277+
// SWIFT_ENABLE_TENSORFLOW END
270278
return false;
271279
case Kind::Opaque:
272280
return true;
@@ -332,6 +340,10 @@ AbstractionPattern::getTupleElementType(unsigned index) const {
332340
case Kind::CXXMethodType:
333341
case Kind::CurriedCXXMethodType:
334342
case Kind::PartialCurriedCXXMethodType:
343+
// SWIFT_ENABLE_TENSORFLOW
344+
case Kind::OpaqueFunction:
345+
case Kind::OpaqueDerivativeFunction:
346+
// SWIFT_ENABLE_TENSORFLOW END
335347
llvm_unreachable("function types are not tuples");
336348
case Kind::Opaque:
337349
return *this;
@@ -459,7 +471,15 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
459471
return AbstractionPattern(getGenericSignatureForFunctionComponent(),
460472
getResultType(getType()),
461473
getObjCMethod()->getReturnType().getTypePtr());
462-
}
474+
// SWIFT_ENABLE_TENSORFLOW
475+
case Kind::OpaqueFunction:
476+
return getOpaque();
477+
case Kind::OpaqueDerivativeFunction:
478+
static SmallVector<AbstractionPattern, 2> elements{getOpaque(),
479+
getOpaqueFunction()};
480+
return getTuple(elements);
481+
}
482+
// SWIFT_ENABLE_TENSORFLOW END
463483
llvm_unreachable("bad kind");
464484
}
465485

@@ -588,6 +608,12 @@ AbstractionPattern::getFunctionParamType(unsigned index) const {
588608
params[index].getParameterType(),
589609
getClangFunctionParameterType(getClangType(), index));
590610
}
611+
// SWIFT_ENABLE_TENSORFLOW
612+
case Kind::OpaqueFunction:
613+
return getOpaque();
614+
case Kind::OpaqueDerivativeFunction:
615+
return getOpaque();
616+
// SWIFT_ENABLE_TENSORFLOW END
591617
default:
592618
llvm_unreachable("does not have function parameters");
593619
}
@@ -617,6 +643,10 @@ AbstractionPattern AbstractionPattern::getOptionalObjectType() const {
617643
case Kind::CurriedCXXMethodType:
618644
case Kind::PartialCurriedCXXMethodType:
619645
case Kind::Tuple:
646+
// SWIFT_ENABLE_TENSORFLOW
647+
case Kind::OpaqueFunction:
648+
case Kind::OpaqueDerivativeFunction:
649+
// SWIFT_ENABLE_TENSORFLOW END
620650
llvm_unreachable("pattern for function or tuple cannot be for optional");
621651

622652
case Kind::Opaque:
@@ -658,7 +688,11 @@ AbstractionPattern AbstractionPattern::getReferenceStorageReferentType() const {
658688
case Kind::CurriedCXXMethodType:
659689
case Kind::PartialCurriedCXXMethodType:
660690
case Kind::Tuple:
691+
// SWIFT_ENABLE_TENSORFLOW
692+
case Kind::OpaqueFunction:
693+
case Kind::OpaqueDerivativeFunction:
661694
return *this;
695+
// SWIFT_ENABLE_TENSORFLOW END
662696
case Kind::Type:
663697
return AbstractionPattern(getGenericSignature(),
664698
getType().getReferenceStorageReferent());
@@ -687,6 +721,14 @@ void AbstractionPattern::print(raw_ostream &out) const {
687721
case Kind::Opaque:
688722
out << "AP::Opaque";
689723
return;
724+
// SWIFT_ENABLE_TENSORFLOW
725+
case Kind::OpaqueFunction:
726+
out << "AP::OpaqueFunction";
727+
return;
728+
case Kind::OpaqueDerivativeFunction:
729+
out << "AP::OpaqueDerivativeFunction";
730+
return;
731+
// SWIFT_ENABLE_TENSORFLOW END
690732
case Kind::Type:
691733
case Kind::Discard:
692734
out << (getKind() == Kind::Type
@@ -850,6 +892,14 @@ const {
850892
case Kind::Tuple:
851893
llvm_unreachable("should not have a tuple pattern matching a struct/enum "
852894
"type");
895+
// SWIFT_ENABLE_TENSORFLOW
896+
case Kind::OpaqueFunction:
897+
llvm_unreachable("should not have an opaque function pattern matching a "
898+
"struct/enum type");
899+
case Kind::OpaqueDerivativeFunction:
900+
llvm_unreachable("should not have an opaque derivative function pattern "
901+
"matching a struct/enum type");
902+
// SWIFT_ENABLE_TENSORFLOW END
853903
case Kind::PartialCurriedObjCMethodType:
854904
case Kind::CurriedObjCMethodType:
855905
case Kind::PartialCurriedCFunctionAsMethodType:
@@ -869,3 +919,26 @@ const {
869919
return AbstractionPattern(getGenericSignature(), memberTy);
870920
}
871921
}
922+
923+
AbstractionPattern AbstractionPattern::getAutoDiffDerivativeFunctionType(
924+
IndexSubset *indices, unsigned resultIndex,
925+
AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance,
926+
GenericSignature derivativeGenericSignature, bool makeSelfParamFirst) {
927+
switch (getKind()) {
928+
case Kind::Type: {
929+
auto fnTy = dyn_cast<AnyFunctionType>(getType());
930+
if (!fnTy)
931+
return getOpaqueDerivativeFunction();
932+
auto derivativeFnTy = fnTy->getAutoDiffDerivativeFunctionType(
933+
indices, resultIndex, kind, lookupConformance,
934+
derivativeGenericSignature, makeSelfParamFirst);
935+
assert(derivativeFnTy);
936+
return AbstractionPattern(getGenericSignature(),
937+
derivativeFnTy->getCanonicalType());
938+
}
939+
case Kind::Opaque:
940+
return getOpaqueDerivativeFunction();
941+
default:
942+
llvm_unreachable("called on unsupported abstraction pattern kind");
943+
}
944+
}

lib/SIL/SILFunctionType.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,12 @@ class DestructureResults {
913913
|| substTL.isAddressOnly()) {
914914
return true;
915915

916+
// SWIFT_ENABLE_TENSORFLOW
917+
// Functions are always returned directly.
918+
} else if (origType.isOpaqueFunctionOrOpaqueDerivativeFunction()) {
919+
return false;
920+
// SWIFT_ENABLE_TENSORFLOW END
921+
916922
// If the substitution didn't change the type, then a negative
917923
// response to the above is determinative as well.
918924
} else if (origType.getType() == substType &&

0 commit comments

Comments
 (0)