Skip to content

Commit 83c7155

Browse files
committed
More cleanup.
Track TF-1197 and TF-1198.
1 parent a31a87a commit 83c7155

File tree

5 files changed

+70
-48
lines changed

5 files changed

+70
-48
lines changed

include/swift/SILOptimizer/Utils/Differentiation/Thunk.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
114114
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
115115
SILAutoDiffIndices actualIndices);
116116

117-
/// Reabstract the given function-typed value to the given target type.
117+
/// Reabstracts the given function-typed value `fn` to the target type `toType`.
118+
/// Remaps substitutions using `remapSubstitutions`.
118119
SILValue reabstractFunction(
119120
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
120121
SILValue fn, CanSILFunctionType toType,
121-
std::function<SubstitutionMap(SubstitutionMap)> remapSubstMap);
122+
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);
122123

123124
} // end namespace autodiff
124125

lib/SIL/SILFunctionType.cpp

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,49 +1029,56 @@ class DestructureInputs {
10291029

10301030
} // end anonymous namespace
10311031

1032-
static SmallVector<SILParameterInfo, 4>
1032+
/// Collects the differentiability parameters of the given original function
1033+
/// type in `diffParams`.
1034+
static void
10331035
getDifferentiabilityParameters(SILFunctionType *originalFnTy,
1034-
IndexSubset *parameterIndices) {
1036+
IndexSubset *parameterIndices,
1037+
SmallVectorImpl<SILParameterInfo> &diffParams) {
10351038
// Returns true if `index` is a differentiability parameter index.
10361039
auto isDiffParamIndex = [&](unsigned index) -> bool {
10371040
return index < parameterIndices->getCapacity() &&
10381041
parameterIndices->contains(index);
10391042
};
10401043
// Calculate differentiability parameter infos.
1041-
SmallVector<SILParameterInfo, 4> diffParams;
10421044
for (auto valueAndIndex : enumerate(originalFnTy->getParameters()))
10431045
if (isDiffParamIndex(valueAndIndex.index()))
10441046
diffParams.push_back(valueAndIndex.value());
1045-
return diffParams;
10461047
}
10471048

1048-
static SmallVector<SILResultInfo, 2>
1049-
getSemanticResults(SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
1049+
/// Collects the semantic results of the given function type in
1050+
/// `originalResults`. The semantic results are formal results followed by
1051+
/// `inout` parameters, in type order.
1052+
// TODO(TF-983): Generalize to support multiple `inout` parameters. The current
1053+
// singular `inoutParam` and `isWrtInoutParameter` are hacky.
1054+
static void
1055+
getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
10501056
Optional<SILParameterInfo> &inoutParam,
1051-
bool &isWrtInoutParameter) {
1052-
// Compute the original semantic results: original formal results, followed by
1053-
// `inout` parameters in type order.
1054-
SmallVector<SILResultInfo, 2> originalResults;
1057+
bool &isWrtInoutParameter,
1058+
SmallVectorImpl<SILResultInfo> &originalResults) {
10551059
inoutParam = None;
10561060
isWrtInoutParameter = false;
1057-
originalResults.append(originalFnTy->getResults().begin(),
1058-
originalFnTy->getResults().end());
1059-
for (auto i : range(originalFnTy->getNumParameters())) {
1060-
auto param = originalFnTy->getParameters()[i];
1061+
// Collect original formal results.
1062+
originalResults.append(functionType->getResults().begin(),
1063+
functionType->getResults().end());
1064+
// Collect original `inout` parameters.
1065+
for (auto i : range(functionType->getNumParameters())) {
1066+
auto param = functionType->getParameters()[i];
10611067
if (!param.isIndirectInOut())
10621068
continue;
10631069
inoutParam = param;
10641070
isWrtInoutParameter = parameterIndices->contains(i);
10651071
originalResults.push_back(
10661072
SILResultInfo(param.getInterfaceType(), ResultConvention::Indirect));
10671073
}
1068-
return originalResults;
10691074
}
10701075

1076+
/// Returns the differential type for the given original function type,
1077+
/// parameter indices, and result index.
10711078
static CanSILFunctionType
1072-
getDifferentialType(SILFunctionType *originalFnTy,
1073-
IndexSubset *parameterIndices, unsigned resultIndex,
1074-
LookupConformanceFn lookupConformance) {
1079+
getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
1080+
IndexSubset *parameterIndices, unsigned resultIndex,
1081+
LookupConformanceFn lookupConformance) {
10751082
auto &ctx = originalFnTy->getASTContext();
10761083
SmallVector<GenericTypeParamType *, 4> substGenericParams;
10771084
SmallVector<Requirement, 4> substRequirements;
@@ -1080,12 +1087,14 @@ getDifferentialType(SILFunctionType *originalFnTy,
10801087

10811088
Optional<SILParameterInfo> inoutParam = None;
10821089
bool isWrtInoutParameter = false;
1083-
SmallVector<SILResultInfo, 2> originalResults = getSemanticResults(
1084-
originalFnTy, parameterIndices, inoutParam, isWrtInoutParameter);
1090+
SmallVector<SILResultInfo, 2> originalResults;
1091+
getSemanticResults(originalFnTy, parameterIndices, inoutParam,
1092+
isWrtInoutParameter, originalResults);
10851093

1094+
SmallVector<SILParameterInfo, 4> diffParams;
1095+
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
10861096
SmallVector<SILParameterInfo, 8> differentialParams;
1087-
for (auto &param :
1088-
getDifferentiabilityParameters(originalFnTy, parameterIndices)) {
1097+
for (auto &param : diffParams) {
10891098
auto paramTan =
10901099
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
10911100
assert(paramTan && "Parameter type does not have a tangent space?");
@@ -1136,11 +1145,13 @@ getDifferentialType(SILFunctionType *originalFnTy,
11361145
differentialResults, None, substitutions, impliedSignature, ctx);
11371146
}
11381147

1139-
static CanSILFunctionType getPullbackType(SILFunctionType *originalFnTy,
1140-
IndexSubset *parameterIndices,
1141-
unsigned resultIndex,
1142-
LookupConformanceFn lookupConformance,
1143-
TypeConverter &TC) {
1148+
/// Returns the pullback type for the given original function type, parameter
1149+
/// indices, and result index.
1150+
static CanSILFunctionType
1151+
getAutoDiffPullbackType(SILFunctionType *originalFnTy,
1152+
IndexSubset *parameterIndices, unsigned resultIndex,
1153+
LookupConformanceFn lookupConformance,
1154+
TypeConverter &TC) {
11441155
auto &ctx = originalFnTy->getASTContext();
11451156
SmallVector<GenericTypeParamType *, 4> substGenericParams;
11461157
SmallVector<Requirement, 4> substRequirements;
@@ -1149,8 +1160,9 @@ static CanSILFunctionType getPullbackType(SILFunctionType *originalFnTy,
11491160

11501161
Optional<SILParameterInfo> inoutParam = None;
11511162
bool isWrtInoutParameter = false;
1152-
SmallVector<SILResultInfo, 2> originalResults = getSemanticResults(
1153-
originalFnTy, parameterIndices, inoutParam, isWrtInoutParameter);
1163+
SmallVector<SILResultInfo, 2> originalResults;
1164+
getSemanticResults(originalFnTy, parameterIndices, inoutParam,
1165+
isWrtInoutParameter, originalResults);
11541166

11551167
// Given a type, returns its formal SIL parameter info.
11561168
auto getTangentParameterConventionForOriginalResult =
@@ -1251,9 +1263,10 @@ static CanSILFunctionType getPullbackType(SILFunctionType *originalFnTy,
12511263
pullbackParams.push_back({gpType, paramTanConvention});
12521264
}
12531265
}
1266+
SmallVector<SILParameterInfo, 4> diffParams;
1267+
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
12541268
SmallVector<SILResultInfo, 8> pullbackResults;
1255-
for (auto &param :
1256-
getDifferentiabilityParameters(originalFnTy, parameterIndices)) {
1269+
for (auto &param : diffParams) {
12571270
if (param.isIndirectInOut())
12581271
continue;
12591272
auto paramTan =
@@ -1378,12 +1391,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
13781391
CanSILFunctionType closureType;
13791392
switch (kind) {
13801393
case AutoDiffDerivativeFunctionKind::JVP:
1381-
closureType = getDifferentialType(constrainedOriginalFnTy, parameterIndices,
1382-
resultIndex, lookupConformance);
1394+
closureType =
1395+
getAutoDiffDifferentialType(constrainedOriginalFnTy, parameterIndices,
1396+
resultIndex, lookupConformance);
13831397
break;
13841398
case AutoDiffDerivativeFunctionKind::VJP:
1385-
closureType = getPullbackType(constrainedOriginalFnTy, parameterIndices,
1386-
resultIndex, lookupConformance, TC);
1399+
closureType =
1400+
getAutoDiffPullbackType(constrainedOriginalFnTy, parameterIndices,
1401+
resultIndex, lookupConformance, TC);
13871402
break;
13881403
}
13891404

lib/SIL/SILVerifier.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4582,8 +4582,10 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
45824582

45834583
// SWIFT_ENABLE_TENSORFLOW
45844584
void checkDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
4585-
#warning We should re-enable `differentiable_function` verification before landing this
4585+
// FIXME(TF-1197): Re-enable verification after substituted SIL function
4586+
// types.
45864587
return;
4588+
#if 0
45874589
auto origTy =
45884590
dfi->getOriginalFunction()->getType().getAs<SILFunctionType>();
45894591
require(origTy, "The original function must have a function type");
@@ -4624,6 +4626,7 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
46244626
SILType::getPrimitiveObjectType(expectedVJPType),
46254627
"VJP type does not match expected VJP type");
46264628
}
4629+
#endif
46274630
}
46284631

46294632
void checkLinearFunctionInst(LinearFunctionInst *lfi) {
@@ -5455,8 +5458,10 @@ void SILGlobalVariable::verify() const {
54555458
// SWIFT_ENABLE_TENSORFLOW
54565459
/// Verify that a differentiability witness follows invariants.
54575460
void SILDifferentiabilityWitness::verify(const SILModule &M) const {
5458-
#warning We should re-enable `differentiable_function` verification before landing this
5459-
return;
5461+
// FIXME(TF-1197): Re-enable verification after substituted SIL function
5462+
// types.
5463+
return;
5464+
#if 0
54605465
#ifdef NDEBUG
54615466
if (!M.getOptions().VerifyAll)
54625467
return;
@@ -5510,6 +5515,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
55105515
requireSameType(vjp->getLoweredFunctionType(), expectedVJPType,
55115516
"VJP type does not match expected VJP type");
55125517
}
5518+
#endif
55135519
}
55145520
// SWIFT_ENABLE_TENSORFLOW END
55155521

lib/SILGen/SILGenPoly.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,9 +1596,12 @@ static ManagedValue applyTrivialConversions(SILGenFunction &SGF,
15961596
auto innerASTTy = innerValue.getType().getASTType();
15971597
auto outerASTTy = outerType.getASTType();
15981598
// SWIFT_ENABLE_TENSORFLOW
1599-
// Mapping out of context is necessary for
1599+
// Mapping out of context is necessary to avoid assertion failures for
16001600
// `SILGenModule::getOrCreateCustomDerivativeThunk`.
1601-
// Consider finding a robust fix.
1601+
// FIXME(TF-1198): Find a robust fix and remove this hack.
1602+
// Thunk type calculation in `SILGenModule::getOrCreateCustomDerivativeThunk`
1603+
// may be missing logic from `SILGenFunction::buildThunkType` that maps
1604+
// archetypes to interface types.
16021605
if (innerASTTy->hasArchetype())
16031606
innerASTTy = innerASTTy->mapTypeOutOfContext()->getCanonicalType();
16041607
if (outerASTTy->hasArchetype())

lib/SILOptimizer/Utils/Differentiation/Thunk.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -845,10 +845,7 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
845845
SILValue reabstractFunction(
846846
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
847847
SILValue fn, CanSILFunctionType toType,
848-
std::function<SubstitutionMap(SubstitutionMap)> remapSubstMap) {
849-
// TODO: I removed some calls to getOpType and getOpSubstitutionMap because we
850-
// don't have a cloner here. Maybe I should have a cloner so that I can call
851-
// them? Also remapSubstitutionMap from other callers.
848+
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
852849
auto &module = *fn->getModule();
853850
auto fromType = fn->getType().getAs<SILFunctionType>();
854851
auto unsubstFromType = fromType->getUnsubstitutedType(module);
@@ -865,8 +862,8 @@ SILValue reabstractFunction(
865862
/*withoutActuallyEscaping*/ false);
866863

867864
fn = builder.createPartialApply(
868-
loc, thunkRef, remapSubstMap(thunk->getForwardingSubstitutionMap()), {fn},
869-
fromType->getCalleeConvention());
865+
loc, thunkRef, remapSubstitutions(thunk->getForwardingSubstitutionMap()),
866+
{fn}, fromType->getCalleeConvention());
870867

871868
if (toType != unsubstToType)
872869
fn = builder.createConvertFunction(loc, fn,

0 commit comments

Comments
 (0)