Skip to content

Commit d39eaf4

Browse files
committed
[AutoDiff] Fix derivative generic signature same-type requirements.
Fix derivative generic signature calculation when same-type requirements bind all generic parameters to concrete types, i.e. when all generic parameters are concrete. Declarations whose generic signature have all concrete generic parameters are lowered as SIL functions with no generic signature: they are specialized with the concrete types from the same-type requirements. For `@differentiable` attributes: when the original generic signature and the derivative generic signature are equal and all generic parameters are concrete, do not set the attribute's derivative generic signature. Update SIL infrastructure to handle derivative generic signatures with all concrete generic parameters. In such cases: - SIL derivative function types are specialized with concrete types and have no generic signature. - SIL differentiability witnesses have a derivative generic signature iff it differs from the original generic signature. Witness generic signatures should be used for remapping types during the differentiation transform. Resolves TF-1059 and TF-1062.
1 parent e5915f7 commit d39eaf4

File tree

16 files changed

+349
-103
lines changed

16 files changed

+349
-103
lines changed

include/swift/AST/Types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4466,7 +4466,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
44664466
/// This "constrained derivative generic signature" is used for
44674467
/// parameter/result type lowering. It is used as the actual generic signature
44684468
/// of the derivative function type iff the original function type has a
4469-
/// generic signature; otherwise, no derivative generic signature is used.
4469+
/// generic signature and not all generic parameters are bound to concrete
4470+
/// types. Otherwise, no derivative generic signature is used.
44704471
///
44714472
/// Other properties of the original function type are copied exactly:
44724473
/// `ExtInfo`, coroutine kind, callee convention, yields, optional error

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class SILDifferentiabilityWitness
4949
/// The original function.
5050
SILFunction *OriginalFunction;
5151
/// The autodiff configuration: parameter indices, result indices, derivative
52-
/// generic signature (optional).
52+
/// generic signature (optional). The derivative generic signature may contain
53+
/// same-type requirements such that all generic parameters are bound to
54+
/// concrete types.
5355
AutoDiffConfig Config;
5456
/// The JVP (Jacobian-vector products) derivative function.
5557
SILFunction *JVP;

include/swift/SILOptimizer/Utils/Differentiation/Common.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,6 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
133133
void forEachApplyDirectResult(
134134
ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback);
135135

136-
/// Returns the canonical derivative generic signature for the given witness
137-
/// and original function.
138-
/// - Return the witness derivative generic signature if it exists.
139-
/// - Otherwise, return the original function's generic signature.
140-
CanGenericSignature
141-
getDerivativeGenericSignature(SILDifferentiabilityWitness *witness,
142-
SILFunction *original);
143-
144136
/// Given a function, gathers all of its formal results (both direct and
145137
/// indirect) in an order defined by its result type. Note that "formal results"
146138
/// refer to result values in the body of the function, not at call sites.

lib/SIL/SILFunctionType.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
391391

392392
SmallVector<SILParameterInfo, 4> newParameters;
393393
newParameters.reserve(getNumParameters());
394-
newParameters.append(getParameters().begin(), getParameters().end());
394+
for (auto &param : getParameters()) {
395+
newParameters.push_back(param.getWithInterfaceType(
396+
param.getInterfaceType()->getCanonicalType(derivativeFnGenSig)));
397+
}
395398
// Reabstraction thunks have a function-typed parameter (the function to
396399
// reabstract) as their last parameter. Reabstraction thunk JVPs/VJPs have a
397400
// `@differentiable` function-typed last parameter instead.
@@ -414,9 +417,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
414417
newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
415418
ResultConvention::Owned});
416419
// Derivative function type has a generic signature only if the original
417-
// function type does.
420+
// function type does, and if `derivativeFnGenSig` does not have all concrete
421+
// generic parameters.
418422
CanGenericSignature canGenSig;
419-
if (getSubstGenericSignature())
423+
if (getSubstGenericSignature() && derivativeFnGenSig &&
424+
!derivativeFnGenSig->areAllParamsConcrete())
420425
canGenSig = derivativeFnGenSig;
421426
return SILFunctionType::get(canGenSig, getExtInfo(), getCoroutineKind(),
422427
getCalleeConvention(), newParameters, getYields(),

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
776776
if (auto *vjpDecl = diffAttr->getVJPFunction())
777777
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
778778
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
779-
assert((!AFD->getGenericSignature() || diffAttr->getDerivativeGenericSignature()) &&
780-
"type-checking should resolve derivative generic signatures for "
781-
"all functions with generic signatures");
779+
assert((!F->getLoweredFunctionType()->getSubstGenericSignature() ||
780+
diffAttr->getDerivativeGenericSignature()) &&
781+
"Type-checking should resolve derivative generic signatures for "
782+
"all original SIL functions with generic signatures");
782783
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
783784
diffAttr->getDerivativeGenericSignature());
784785
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);

lib/SILGen/SILGenPoly.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3734,10 +3734,13 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
37343734
auto origFnType = original->getLoweredFunctionType();
37353735
assert(config.resultIndices->getNumIndices() == 1 &&
37363736
"Only single result index is currently supported");
3737+
CanGenericSignature derivativeCanGenSig;
3738+
if (auto derivativeGenSig = config.derivativeGenericSignature)
3739+
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
37373740
auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType(
37383741
config.parameterIndices, *config.resultIndices->getIndices().begin(),
37393742
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
3740-
derivativeFnType->getSubstGenericSignature());
3743+
derivativeCanGenSig);
37413744
assert(!origDerivativeFnType->getExtInfo().hasContext());
37423745

37433746
auto loc = derivativeFn->getLocation();

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -461,13 +461,29 @@ static SILValue reapplyFunctionConversion(
461461
context.addDifferentiableFunctionInstToWorklist(dfi);
462462
newArgs.back() = dfi;
463463
}
464-
// If new function's generic signature is specified, use it to create
465-
// substitution map for reapplied `partial_apply` instruction.
466-
auto substMap = !newFuncGenSig
467-
? pai->getSubstitutionMap()
468-
: SubstitutionMap::get(
469-
newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
470-
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
464+
// Compute substitution map for reapplying `partial_apply`.
465+
// - If reapplied functoin is not polymorphic, use empty substitution map
466+
// regardless of the original `partial_apply`'s substitution map.
467+
// - This case is triggered for reapplying `partial_apply` where `newFunc`
468+
// is a `differentiability_witness_function` where the witness generic
469+
// signature has all concrete parameters while the original function's
470+
// generic signature does not. In this case, the original function type
471+
// is polymorphic while derivative function types are not (specialized
472+
// with concrete types from same-type requirements).
473+
// - Otherwise, if `newFuncGenSig` is not specified, use the original
474+
// `partial_apply`'s substitution map.
475+
// - Otherwise, if `newFuncGenSig` is specified, combine it with the
476+
// original `partial_apply`'s substitution map.
477+
SubstitutionMap substMap;
478+
if (innerNewFunc->getType().castTo<SILFunctionType>()->isPolymorphic()) {
479+
if (!newFuncGenSig) {
480+
substMap = pai->getSubstitutionMap();
481+
} else {
482+
substMap = SubstitutionMap::get(
483+
newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
484+
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
485+
}
486+
}
471487
return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs,
472488
ParameterConvention::Direct_Guaranteed);
473489
}
@@ -796,14 +812,16 @@ static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
796812
original->getName(), AutoDiffDerivativeFunctionKind::VJP,
797813
witness->getConfig()))
798814
.str();
799-
auto vjpGenericSig = getDerivativeGenericSignature(witness, original);
800-
auto *vjpGenericEnv = vjpGenericSig
801-
? vjpGenericSig->getGenericEnvironment()
802-
: nullptr;
815+
CanGenericSignature vjpCanGenSig;
816+
if (auto jvpGenSig = witness->getDerivativeGenericSignature())
817+
vjpCanGenSig = jvpGenSig->getCanonicalSignature();
818+
GenericEnvironment *vjpGenericEnv = nullptr;
819+
if (vjpCanGenSig && !vjpCanGenSig->areAllParamsConcrete())
820+
vjpGenericEnv = vjpCanGenSig->getGenericEnvironment();
803821
auto vjpType = originalTy->getAutoDiffDerivativeFunctionType(
804822
indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::VJP,
805823
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
806-
vjpGenericSig,
824+
vjpCanGenSig,
807825
/*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk);
808826

809827
SILOptFunctionBuilder fb(context.getTransform());
@@ -839,14 +857,16 @@ static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
839857
original->getName(), AutoDiffDerivativeFunctionKind::JVP,
840858
witness->getConfig()))
841859
.str();
842-
auto jvpGenericSig = getDerivativeGenericSignature(witness, original);
843-
auto *jvpGenericEnv = jvpGenericSig
844-
? jvpGenericSig->getGenericEnvironment()
845-
: nullptr;
860+
CanGenericSignature jvpCanGenSig;
861+
if (auto jvpGenSig = witness->getDerivativeGenericSignature())
862+
jvpCanGenSig = jvpGenSig->getCanonicalSignature();
863+
GenericEnvironment *jvpGenericEnv = nullptr;
864+
if (jvpCanGenSig && !jvpCanGenSig->areAllParamsConcrete())
865+
jvpGenericEnv = jvpCanGenSig->getGenericEnvironment();
846866
auto jvpType = originalTy->getAutoDiffDerivativeFunctionType(
847867
indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::JVP,
848868
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
849-
jvpGenericSig,
869+
jvpCanGenSig,
850870
/*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk);
851871

852872
SILOptFunctionBuilder fb(context.getTransform());

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,6 @@ void forEachApplyDirectResult(
8181
resultCallback(result);
8282
}
8383

84-
/// Returns the canonical derivative generic signature for the given witness
85-
/// and original function.
86-
/// - Return the witness derivative generic signature if it exists.
87-
/// - Otherwise, return the original function's generic signature.
88-
CanGenericSignature
89-
getDerivativeGenericSignature(SILDifferentiabilityWitness *witness,
90-
SILFunction *original) {
91-
if (auto sig = witness->getDerivativeGenericSignature())
92-
return sig->getCanonicalSignature();
93-
return original->getLoweredFunctionType()->getSubstGenericSignature();
94-
}
95-
9684
void collectAllFormalResultsInTypeOrder(SILFunction &function,
9785
SmallVectorImpl<SILValue> &results) {
9886
SILFunctionConventions convs(function.getLoweredFunctionType(),

lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ SILType JVPEmitter::remapSILTypeInDifferential(SILType ty) {
285285
}
286286

287287
Optional<VectorSpace> JVPEmitter::getTangentSpace(CanType type) {
288+
// Use witness generic signature to remap types.
289+
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
290+
type = witnessGenSig->getCanonicalTypeInContext(type);
288291
return type->getAutoDiffAssociatedTangentSpace(
289292
LookUpConformanceInModule(getModule().getSwiftModule()));
290293
}
@@ -1015,6 +1018,14 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
10151018
auto *original = witness->getOriginalFunction();
10161019
auto *jvp = witness->getJVP();
10171020
auto origTy = original->getLoweredFunctionType();
1021+
// Get witness generic signature for remapping types.
1022+
// Witness generic signature may have more requirements than JVP generic
1023+
// signature: when witness generic signature has same-type requirements
1024+
// binding all generic parameters to concrete types, JVP function type uses
1025+
// all the concrete types and JVP generic signature is null.
1026+
CanGenericSignature witnessCanGenSig;
1027+
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
1028+
witnessCanGenSig = witnessGenSig->getCanonicalSignature();
10181029
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
10191030

10201031
// Parameters of the differential are:
@@ -1028,16 +1039,20 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
10281039
auto indices = witness->getSILAutoDiffIndices();
10291040

10301041
// Add differential results.
1031-
auto origResInfo = origTy->getResults()[indices.source];
1042+
auto origResult = origTy->getResults()[indices.source];
1043+
origResult = origResult.getWithInterfaceType(
1044+
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
10321045
dfResults.push_back(
1033-
SILResultInfo(origResInfo.getInterfaceType()
1046+
SILResultInfo(origResult.getInterfaceType()
10341047
->getAutoDiffAssociatedTangentSpace(lookupConformance)
10351048
->getCanonicalType(),
1036-
origResInfo.getConvention()));
1049+
origResult.getConvention()));
10371050

10381051
// Add differential parameters for the requested wrt parameters.
10391052
for (auto i : indices.parameters->getIndices()) {
10401053
auto origParam = origParams[i];
1054+
origParam = origParam.getWithInterfaceType(
1055+
origParam.getInterfaceType()->getCanonicalType(witnessCanGenSig));
10411056
dfParams.push_back(SILParameterInfo(
10421057
origParam.getInterfaceType()
10431058
->getAutoDiffAssociatedTangentSpace(lookupConformance)
@@ -1059,7 +1074,11 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
10591074
original->getName(), AutoDiffLinearMapKind::Differential,
10601075
witness->getConfig()))
10611076
.str();
1062-
auto diffGenericSig = getDerivativeGenericSignature(witness, original);
1077+
// Set differential generic signature equal to JVP generic signature.
1078+
// Do not use witness generic signature, which may have same-type requirements
1079+
// binding all generic parameters to concrete types.
1080+
auto diffGenericSig =
1081+
jvp->getLoweredFunctionType()->getSubstGenericSignature();
10631082
auto *diffGenericEnv =
10641083
diffGenericSig ? diffGenericSig->getGenericEnvironment() : nullptr;
10651084
auto diffType = SILFunctionType::get(

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ SILType PullbackEmitter::remapType(SILType ty) {
162162
}
163163

164164
Optional<VectorSpace> PullbackEmitter::getTangentSpace(CanType type) {
165+
// Use witness generic signature to remap types.
166+
if (auto witnessGenSig = getWitness()->getDerivativeGenericSignature())
167+
type = witnessGenSig->getCanonicalTypeInContext(type);
165168
return type->getAutoDiffAssociatedTangentSpace(
166169
LookUpConformanceInModule(getModule().getSwiftModule()));
167170
}

lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,20 @@ VJPEmitter::VJPEmitter(ADContext &context, SILFunction *original,
7575
SILFunction *VJPEmitter::createEmptyPullback() {
7676
auto &module = context.getModule();
7777
auto origTy = original->getLoweredFunctionType();
78+
// Get witness generic signature for remapping types.
79+
// Witness generic signature may have more requirements than VJP generic
80+
// signature: when witness generic signature has same-type requirements
81+
// binding all generic parameters to concrete types, VJP function type uses
82+
// all the concrete types and VJP generic signature is null.
83+
CanGenericSignature witnessCanGenSig;
84+
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
85+
witnessCanGenSig = witnessGenSig->getCanonicalSignature();
7886
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
7987

8088
// Given a type, returns its formal SIL parameter info.
8189
auto getTangentParameterInfoForOriginalResult =
8290
[&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
83-
Lowering::AbstractionPattern pattern(
84-
vjp->getLoweredFunctionType()->getSubstGenericSignature(), tanType);
91+
Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
8592
auto &tl = context.getTypeConverter().getTypeLowering(
8693
pattern, tanType, TypeExpansionContext::minimal());
8794
ParameterConvention conv;
@@ -105,8 +112,7 @@ SILFunction *VJPEmitter::createEmptyPullback() {
105112
// Given a type, returns its formal SIL result info.
106113
auto getTangentResultInfoForOriginalParameter =
107114
[&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
108-
Lowering::AbstractionPattern pattern(
109-
vjp->getLoweredFunctionType()->getSubstGenericSignature(), tanType);
115+
Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
110116
auto &tl = context.getTypeConverter().getTypeLowering(
111117
pattern, tanType, TypeExpansionContext::minimal());
112118
ResultConvention conv;
@@ -139,12 +145,14 @@ SILFunction *VJPEmitter::createEmptyPullback() {
139145
auto indices = witness->getSILAutoDiffIndices();
140146

141147
// Add pullback parameter for the seed.
142-
auto origResInfo = origTy->getResults()[indices.source];
148+
auto origResult = origTy->getResults()[indices.source];
149+
origResult = origResult.getWithInterfaceType(
150+
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
143151
pbParams.push_back(getTangentParameterInfoForOriginalResult(
144-
origResInfo.getInterfaceType()
152+
origResult.getInterfaceType()
145153
->getAutoDiffAssociatedTangentSpace(lookupConformance)
146154
->getCanonicalType(),
147-
origResInfo.getConvention()));
155+
origResult.getConvention()));
148156

149157
// Accept a pullback struct in the pullback parameter list. This is the
150158
// returned pullback's closure context.
@@ -156,6 +164,8 @@ SILFunction *VJPEmitter::createEmptyPullback() {
156164
// Add pullback results for the requested wrt parameters.
157165
for (auto i : indices.parameters->getIndices()) {
158166
auto origParam = origParams[i];
167+
origParam = origParam.getWithInterfaceType(
168+
origParam.getInterfaceType()->getCanonicalType(witnessCanGenSig));
159169
adjResults.push_back(getTangentResultInfoForOriginalParameter(
160170
origParam.getInterfaceType()
161171
->getAutoDiffAssociatedTangentSpace(lookupConformance)
@@ -169,7 +179,10 @@ SILFunction *VJPEmitter::createEmptyPullback() {
169179
original->getName(), AutoDiffLinearMapKind::Pullback,
170180
witness->getConfig()))
171181
.str();
172-
auto pbGenericSig = getDerivativeGenericSignature(witness, original);
182+
// Set pullback generic signature equal to VJP generic signature.
183+
// Do not use witness generic signature, which may have same-type requirements
184+
// binding all generic parameters to concrete types.
185+
auto pbGenericSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
173186
auto *pbGenericEnv =
174187
pbGenericSig ? pbGenericSig->getGenericEnvironment() : nullptr;
175188
auto pbType = SILFunctionType::get(

0 commit comments

Comments
 (0)