Skip to content

Commit 521c17d

Browse files
authored
Merge pull request #30312 from apple/tensorflow-stage-wip
Fix AutoDiff for substituted SIL function types. The TF-1196 master issue tracks remaining improvements and fixes.
2 parents 647a276 + 83c7155 commit 521c17d

File tree

199 files changed

+287686
-282423
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

199 files changed

+287686
-282423
lines changed

include/swift/AST/Decl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4710,6 +4710,10 @@ class AbstractStorageDecl : public ValueDecl {
47104710
/// To ensure an accessor is always returned, use getSynthesizedAccessor().
47114711
AccessorDecl *getOpaqueAccessor(AccessorKind kind) const;
47124712

4713+
/// Collect all opaque accessors.
4714+
ArrayRef<AccessorDecl*>
4715+
getOpaqueAccessors(llvm::SmallVectorImpl<AccessorDecl*> &scratch) const;
4716+
47134717
/// Return an accessor that was written in source. Returns null if the
47144718
/// accessor was not explicitly defined by the user.
47154719
AccessorDecl *getParsedAccessor(AccessorKind kind) const;

include/swift/AST/DiagnosticsParse.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ ERROR(sil_box_expected_r_angle,none,
824824
ERROR(sil_function_subst_expected_l_angle,none,
825825
"expected '<' to begin SIL function type substitution list after 'for'", ())
826826
ERROR(sil_function_subst_expected_r_angle,none,
827-
"expected '>' to begin SIL function type substitution list after 'for'", ())
827+
"expected '>' to end SIL function type substitution list after 'for <...'", ())
828828

829829
// Opaque types
830830
ERROR(opaque_mid_composition,none,

include/swift/AST/Type.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class NormalProtocolConformance;
4747
class ProtocolConformanceRef;
4848
class ProtocolDecl;
4949
class ProtocolType;
50+
class SILModule;
5051
class StructDecl;
5152
class SubstitutableType;
5253
class SubstitutionMap;

include/swift/AST/Types.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -899,14 +899,26 @@ class alignas(1 << TypeAlignInBits) TypeBase {
899899
/// Visit this type and the argument type in parallel, invoking the callback
900900
/// function with each archetype-to-substituted-type binding. The callback
901901
/// may return a new type to substitute into the result type, or return
902-
/// CanType() to error out of the operation.
902+
/// CanType() to error out of the operation. Each invocation of the callback
903+
/// receives three arguments:
904+
/// - The `orig` archetype from a position in `this` type.
905+
/// - The `subst` type in the same structural position of `ty` that is trying to be bound
906+
/// to `orig`.
907+
/// - The `upperBound` archetype, which if set, indicates the minimum set of constraints
908+
/// that any type substituted in this structural position must conform to. May be null,
909+
/// indicating an unconstrained context.
910+
/// - If `upperBound` is set, then the `substConformances` array will contain the
911+
/// protocol conformances for `subst` to each of the protocol requirements
912+
/// on `upperBound` in `getConformsTo` order.
903913
///
904914
/// Returns the substituted type, or a null CanType() if this type
905915
/// is not bindable to the substituted type, or the callback returns
906916
/// CanType().
907917
CanType substituteBindingsTo(Type ty,
908-
llvm::function_ref<CanType(ArchetypeType*, CanType)> substFn);
909-
918+
llvm::function_ref<CanType(ArchetypeType *orig,
919+
CanType subst,
920+
ArchetypeType *upperBound,
921+
ArrayRef<ProtocolConformanceRef> substConformances)> substFn);
910922

911923
/// Determines whether this type is similar to \p other as defined by
912924
/// \p matchOptions.
@@ -1168,6 +1180,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
11681180
/// object type.
11691181
TypeTraitResult canBeClass();
11701182

1183+
Type replaceSubstitutedSILFunctionTypesWithUnsubstituted(SILModule &M) const; // in SILType.cpp
1184+
11711185
/// Return the tangent space of the given type, if it exists. Otherwise,
11721186
/// return `None`.
11731187
Optional<TangentSpace>
@@ -3977,7 +3991,7 @@ class SILResultInfo {
39773991
}
39783992

39793993
ValueOwnershipKind
3980-
getOwnershipKind(SILFunction &) const; // in SILType.cpp
3994+
getOwnershipKind(SILFunction &, CanSILFunctionType fTy) const; // in SILType.cpp
39813995

39823996
bool operator==(SILResultInfo rhs) const {
39833997
return TypeAndConvention == rhs.TypeAndConvention;

include/swift/Basic/Algorithm.h

Lines changed: 0 additions & 35 deletions
This file was deleted.

include/swift/Basic/FlaggedPointer.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717
#ifndef SWIFT_BASIC_FLAGGEDPOINTER_H
1818
#define SWIFT_BASIC_FLAGGEDPOINTER_H
1919

20+
#include <algorithm>
2021
#include <cassert>
2122

2223
#include "llvm/Support/Compiler.h"
2324
#include "llvm/Support/PointerLikeTypeTraits.h"
2425

25-
#include "Algorithm.h"
26-
2726
namespace swift {
2827

2928
/// This class implements a pair of a pointer and boolean flag.
@@ -170,7 +169,7 @@ struct llvm::PointerLikeTypeTraits<
170169
enum {
171170
NumLowBitsAvailable = (BitPosition >= PtrTraits::NumLowBitsAvailable)
172171
? PtrTraits::NumLowBitsAvailable
173-
: (swift::min(int(BitPosition + 1),
172+
: (std::min(int(BitPosition + 1),
174173
int(PtrTraits::NumLowBitsAvailable)) - 1)
175174
};
176175
};

include/swift/Basic/LangOptions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ namespace swift {
194194
bool BuildRequestDependencyGraph = false;
195195

196196
/// Enable SIL type lowering
197-
bool EnableSubstSILFunctionTypesForFunctionValues = false;
197+
bool EnableSubstSILFunctionTypesForFunctionValues = true;
198198

199199
/// Whether to diagnose an ephemeral to non-ephemeral conversion as an
200200
/// error.

include/swift/Basic/PrefixMap.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
#ifndef SWIFT_BASIC_PREFIXMAP_H
3535
#define SWIFT_BASIC_PREFIXMAP_H
3636

37-
#include "swift/Basic/Algorithm.h"
3837
#include "swift/Basic/Debug.h"
3938
#include "swift/Basic/LLVM.h"
4039
#include "swift/Basic/type_traits.h"
@@ -53,8 +52,8 @@ template <class KeyElementType> class PrefixMapKeyPrinter;
5352
/// A map whose keys are sequences of comparable values, optimized for
5453
/// finding a mapped value for the longest matching initial subsequence.
5554
template <class KeyElementType, class ValueType,
56-
size_t InlineKeyCapacity
57-
= max<size_t>((sizeof(void*) - 1) / sizeof(KeyElementType), 1)>
55+
size_t InlineKeyCapacity = std::max(
56+
(sizeof(void *) - 1) / sizeof(KeyElementType), size_t(1))>
5857
class PrefixMap {
5958
public:
6059
using KeyType = ArrayRef<KeyElementType>;

include/swift/SIL/ApplySite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ class ApplySite {
195195
SILType getSubstCalleeSILType() const {
196196
FOREACH_IMPL_RETURN(getSubstCalleeSILType());
197197
}
198+
void setSubstCalleeType(CanSILFunctionType t) {
199+
FOREACH_IMPL_RETURN(setSubstCalleeType(t));
200+
}
201+
198202
/// Get the conventions of the callee with the applied substitutions.
199203
SILFunctionConventions getSubstCalleeConv() const {
200204
return SILFunctionConventions(getSubstCalleeType(), getModule());

include/swift/SIL/SILCloner.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2859,10 +2859,13 @@ template<typename ImplClass>
28592859
void SILCloner<ImplClass>::
28602860
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst) {
28612861
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2862+
Optional<SILType> explicitExtracteeType = None;
2863+
if (Inst->hasExplicitExtracteeType())
2864+
explicitExtracteeType = Inst->getType();
28622865
recordClonedInstruction(
28632866
Inst, getBuilder().createDifferentiableFunctionExtract(
28642867
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
2865-
getOpValue(Inst->getFunctionOperand())));
2868+
getOpValue(Inst->getFunctionOperand()), explicitExtracteeType));
28662869
}
28672870

28682871
template<typename ImplClass>

include/swift/SIL/SILInstruction.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,11 @@ class ApplyInstBase<Impl, Base, false> : public Base {
19331933
SILType getSubstCalleeSILType() const {
19341934
return SubstCalleeType;
19351935
}
1936+
1937+
void setSubstCalleeType(CanSILFunctionType t) {
1938+
SubstCalleeType = SILType::getPrimitiveObjectType(t);
1939+
}
1940+
19361941
SILFunctionConventions getSubstCalleeConv() const {
19371942
return SILFunctionConventions(getSubstCalleeType(), this->getModule());
19381943
}
@@ -4352,6 +4357,11 @@ class ConvertFunctionInst final
43524357
bool withoutActuallyEscaping() const {
43534358
return SILInstruction::Bits.ConvertFunctionInst.WithoutActuallyEscaping;
43544359
}
4360+
4361+
/// Returns `true` if the function conversion is between types with the same
4362+
/// argument and return types, as well as all other attributes, after substitution,
4363+
/// such as converting `$<A, B> in (A) -> B for <Int, String>` to `(Int) -> String`.
4364+
bool onlyConvertsSubstitutions() const;
43554365
};
43564366

43574367
/// ConvertEscapeToNoEscapeInst - Change the type of a escaping function value

include/swift/SIL/TypeSubstCloner.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,65 @@ class TypeSubstCloner : public SILClonerWithScopes<ImplClass> {
315315
super::visitDestroyValueInst(Destroy);
316316
}
317317

318+
// SWIFT_ENABLE_TENSORFLOW
319+
void visitDifferentiableFunctionExtractInst(
320+
DifferentiableFunctionExtractInst *dfei) {
321+
// If the extractee is the original function, do regular cloning.
322+
if (dfei->getExtractee() ==
323+
NormalDifferentiableFunctionTypeComponent::Original) {
324+
super::visitDifferentiableFunctionExtractInst(dfei);
325+
return;
326+
}
327+
// If the extractee is a derivative function, check whether the *remapped
328+
// derivative function type* (BC) is equal to the *derivative remapped
329+
// function type* (AD).
330+
//
331+
// +----------------+ remap +-------------------------+
332+
// | orig. fn type | -------(A)------> | remapped orig. fn type |
333+
// +----------------+ +-------------------------+
334+
// | |
335+
// (B, SILGen) getAutoDiffDerivativeFunctionType (D, here)
336+
// V V
337+
// +----------------+ remap +-------------------------+
338+
// | deriv. fn type | -------(C)------> | remapped deriv. fn type |
339+
// +----------------+ +-------------------------+
340+
//
341+
// (AD) does not always commute with (BC):
342+
// - (AD) is the result of remapping, then computing the derivative type.
343+
// This is the default cloning behavior, but may break invariants in the
344+
// initial SIL generated by SILGen.
345+
// - (BC) is the result of computing the derivative type (SILGen), then
346+
// remapping. This is the expected type, preserving invariants from
347+
// earlier transforms.
348+
//
349+
// If (AD) is not equal to (BC), use (BC) as the explicit type.
350+
SILType remappedOrigType = getOpType(dfei->getFunctionOperand()->getType());
351+
auto remappedOrigFnType = remappedOrigType.castTo<SILFunctionType>();
352+
auto derivativeRemappedFnType =
353+
remappedOrigFnType
354+
->getAutoDiffDerivativeFunctionType(
355+
remappedOrigFnType->getDifferentiabilityParameterIndices(),
356+
/*resultIndex*/ 0, dfei->getDerivativeFunctionKind(),
357+
getBuilder().getModule().Types,
358+
LookUpConformanceInModule(SwiftMod))
359+
->getWithoutDifferentiability();
360+
SILType remappedDerivativeFnType = getOpType(dfei->getType());
361+
// If remapped derivative type and derivative remapped type are equal, do
362+
// regular cloning.
363+
if (SILType::getPrimitiveObjectType(derivativeRemappedFnType) ==
364+
remappedDerivativeFnType) {
365+
super::visitDifferentiableFunctionExtractInst(dfei);
366+
return;
367+
}
368+
// Otherwise, explicitly use the remapped derivative type.
369+
recordClonedInstruction(
370+
dfei,
371+
getBuilder().createDifferentiableFunctionExtract(
372+
getOpLocation(dfei->getLoc()), dfei->getExtractee(),
373+
getOpValue(dfei->getFunctionOperand()), remappedDerivativeFnType));
374+
}
375+
// SWIFT_ENABLE_TENSORFLOW END
376+
318377
/// One abstract function in the debug info can only have one set of variables
319378
/// and types. This function determines whether applying the substitutions in
320379
/// \p SubsMap on the generic signature \p Sig will change the generic type

include/swift/SILOptimizer/Utils/Differentiation/Thunk.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "swift/AST/AutoDiff.h"
2121
#include "swift/Basic/LLVM.h"
22+
#include "swift/SIL/SILBuilder.h"
2223

2324
namespace swift {
2425

@@ -113,6 +114,13 @@ getOrCreateSubsetParametersThunkForLinearMap(
113114
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
114115
SILAutoDiffIndices actualIndices);
115116

117+
/// Reabstracts the given function-typed value `fn` to the target type `toType`.
118+
/// Remaps substitutions using `remapSubstitutions`.
119+
SILValue reabstractFunction(
120+
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
121+
SILValue fn, CanSILFunctionType toType,
122+
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);
123+
116124
} // end namespace autodiff
117125

118126
} // end namespace swift

lib/AST/ASTContext.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,17 @@ using AssociativityCacheType =
101101

102102
struct OverrideSignatureKey {
103103
GenericSignature baseMethodSig;
104-
GenericSignature derivedClassSig;
105-
Type superclassTy;
104+
GenericSignature derivedMethodSig;
105+
Type superclassTy, subclassTy;
106106

107107
OverrideSignatureKey(GenericSignature baseMethodSignature,
108-
GenericSignature derivedClassSignature,
109-
Type superclassType)
110-
: baseMethodSig(baseMethodSignature),
111-
derivedClassSig(derivedClassSignature), superclassTy(superclassType) {}
108+
GenericSignature derivedMethodSignature,
109+
Type superclassType,
110+
Type subclassType)
111+
: baseMethodSig(baseMethodSignature),
112+
derivedMethodSig(derivedMethodSignature),
113+
superclassTy(superclassType),
114+
subclassTy(subclassType) {}
112115
};
113116

114117
namespace llvm {
@@ -119,28 +122,32 @@ template <> struct DenseMapInfo<OverrideSignatureKey> {
119122
static bool isEqual(const OverrideSignatureKey lhs,
120123
const OverrideSignatureKey rhs) {
121124
return lhs.baseMethodSig.getPointer() == rhs.baseMethodSig.getPointer() &&
122-
lhs.derivedClassSig.getPointer() == rhs.derivedClassSig.getPointer() &&
123-
lhs.superclassTy.getPointer() == rhs.superclassTy.getPointer();
125+
lhs.derivedMethodSig.getPointer() == rhs.derivedMethodSig.getPointer() &&
126+
lhs.superclassTy.getPointer() == rhs.superclassTy.getPointer() &&
127+
lhs.subclassTy.getPointer() == rhs.subclassTy.getPointer();
124128
}
125129

126130
static inline OverrideSignatureKey getEmptyKey() {
127131
return OverrideSignatureKey(DenseMapInfo<GenericSignature>::getEmptyKey(),
128132
DenseMapInfo<GenericSignature>::getEmptyKey(),
133+
DenseMapInfo<Type>::getEmptyKey(),
129134
DenseMapInfo<Type>::getEmptyKey());
130135
}
131136

132137
static inline OverrideSignatureKey getTombstoneKey() {
133138
return OverrideSignatureKey(
134139
DenseMapInfo<GenericSignature>::getTombstoneKey(),
135140
DenseMapInfo<GenericSignature>::getTombstoneKey(),
141+
DenseMapInfo<Type>::getTombstoneKey(),
136142
DenseMapInfo<Type>::getTombstoneKey());
137143
}
138144

139145
static unsigned getHashValue(const OverrideSignatureKey &Val) {
140146
return hash_combine(
141147
DenseMapInfo<GenericSignature>::getHashValue(Val.baseMethodSig),
142-
DenseMapInfo<GenericSignature>::getHashValue(Val.derivedClassSig),
143-
DenseMapInfo<Type>::getHashValue(Val.superclassTy));
148+
DenseMapInfo<GenericSignature>::getHashValue(Val.derivedMethodSig),
149+
DenseMapInfo<Type>::getHashValue(Val.superclassTy),
150+
DenseMapInfo<Type>::getHashValue(Val.subclassTy));
144151
}
145152
};
146153
} // namespace llvm
@@ -3563,7 +3570,9 @@ CanSILFunctionType SILFunctionType::get(
35633570
assert(coroutineKind == SILCoroutineKind::None || normalResults.empty());
35643571
assert(coroutineKind != SILCoroutineKind::None || yields.empty());
35653572
assert(!ext.isPseudogeneric() || genericSig);
3566-
3573+
3574+
substitutions = substitutions.getCanonical();
3575+
35673576
llvm::FoldingSetNodeID id;
35683577
SILFunctionType::Profile(id, genericSig, ext, coroutineKind, callee, params,
35693578
yields, normalResults, errorResult,
@@ -4679,8 +4688,9 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
46794688
unsigned derivedDepth = 0;
46804689

46814690
auto key = OverrideSignatureKey(baseGenericCtx->getGenericSignature(),
4682-
derivedClass->getGenericSignature(),
4683-
derivedClass->getSuperclass());
4691+
derivedGenericCtx->getGenericSignature(),
4692+
derivedClass->getSuperclass(),
4693+
derivedClass->getDeclaredInterfaceType());
46844694

46854695
if (getImpl().overrideSigCache.find(key) !=
46864696
getImpl().overrideSigCache.end()) {

0 commit comments

Comments
 (0)