Skip to content

[AutoDiff] Fix derivative generic signature same-type requirements. #28772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4466,7 +4466,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
/// This "constrained derivative generic signature" is used for
/// parameter/result type lowering. It is used as the actual generic signature
/// of the derivative function type iff the original function type has a
/// generic signature; otherwise, no derivative generic signature is used.
/// generic signature and not all generic parameters are bound to concrete
/// types. Otherwise, no derivative generic signature is used.
///
/// Other properties of the original function type are copied exactly:
/// `ExtInfo`, coroutine kind, callee convention, yields, optional error
Expand Down
4 changes: 3 additions & 1 deletion include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class SILDifferentiabilityWitness
/// The original function.
SILFunction *OriginalFunction;
/// The autodiff configuration: parameter indices, result indices, derivative
/// generic signature (optional).
/// generic signature (optional). The derivative generic signature may contain
/// same-type requirements such that all generic parameters are bound to
/// concrete types.
AutoDiffConfig Config;
/// The JVP (Jacobian-vector products) derivative function.
SILFunction *JVP;
Expand Down
8 changes: 0 additions & 8 deletions include/swift/SILOptimizer/Utils/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,6 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
void forEachApplyDirectResult(
ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback);

/// Returns the canonical derivative generic signature for the given witness
/// and original function.
/// - Return the witness derivative generic signature if it exists.
/// - Otherwise, return the original function's generic signature.
CanGenericSignature
getDerivativeGenericSignature(SILDifferentiabilityWitness *witness,
SILFunction *original);

/// Given a function, gathers all of its formal results (both direct and
/// indirect) in an order defined by its result type. Note that "formal results"
/// refer to result values in the body of the function, not at call sites.
Expand Down
11 changes: 8 additions & 3 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(

SmallVector<SILParameterInfo, 4> newParameters;
newParameters.reserve(getNumParameters());
newParameters.append(getParameters().begin(), getParameters().end());
for (auto &param : getParameters()) {
newParameters.push_back(param.getWithInterfaceType(
param.getInterfaceType()->getCanonicalType(derivativeFnGenSig)));
}
// Reabstraction thunks have a function-typed parameter (the function to
// reabstract) as their last parameter. Reabstraction thunk JVPs/VJPs have a
// `@differentiable` function-typed last parameter instead.
Expand All @@ -414,9 +417,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
ResultConvention::Owned});
// Derivative function type has a generic signature only if the original
// function type does.
// function type does, and if `derivativeFnGenSig` does not have all concrete
// generic parameters.
CanGenericSignature canGenSig;
if (getSubstGenericSignature())
if (getSubstGenericSignature() && derivativeFnGenSig &&
!derivativeFnGenSig->areAllParamsConcrete())
canGenSig = derivativeFnGenSig;
return SILFunctionType::get(canGenSig, getExtInfo(), getCoroutineKind(),
getCalleeConvention(), newParameters, getYields(),
Expand Down
7 changes: 4 additions & 3 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
if (auto *vjpDecl = diffAttr->getVJPFunction())
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
assert((!AFD->getGenericSignature() || diffAttr->getDerivativeGenericSignature()) &&
"type-checking should resolve derivative generic signatures for "
"all functions with generic signatures");
assert((!F->getLoweredFunctionType()->getSubstGenericSignature() ||
diffAttr->getDerivativeGenericSignature()) &&
"Type-checking should resolve derivative generic signatures for "
"all original SIL functions with generic signatures");
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature());
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
Expand Down
5 changes: 4 additions & 1 deletion lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3734,10 +3734,13 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
auto origFnType = original->getLoweredFunctionType();
assert(config.resultIndices->getNumIndices() == 1 &&
"Only single result index is currently supported");
CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = config.derivativeGenericSignature)
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType(
config.parameterIndices, *config.resultIndices->getIndices().begin(),
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
derivativeFnType->getSubstGenericSignature());
derivativeCanGenSig);
assert(!origDerivativeFnType->getExtInfo().hasContext());

auto loc = derivativeFn->getLocation();
Expand Down
54 changes: 37 additions & 17 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,29 @@ static SILValue reapplyFunctionConversion(
context.addDifferentiableFunctionInstToWorklist(dfi);
newArgs.back() = dfi;
}
// If new function's generic signature is specified, use it to create
// substitution map for reapplied `partial_apply` instruction.
auto substMap = !newFuncGenSig
? pai->getSubstitutionMap()
: SubstitutionMap::get(
newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
// Compute substitution map for reapplying `partial_apply`.
// - If reapplied functoin is not polymorphic, use empty substitution map
// regardless of the original `partial_apply`'s substitution map.
// - This case is triggered for reapplying `partial_apply` where `newFunc`
// is a `differentiability_witness_function` where the witness generic
// signature has all concrete parameters while the original function's
// generic signature does not. In this case, the original function type
// is polymorphic while derivative function types are not (specialized
// with concrete types from same-type requirements).
// - Otherwise, if `newFuncGenSig` is not specified, use the original
// `partial_apply`'s substitution map.
// - Otherwise, if `newFuncGenSig` is specified, combine it with the
// original `partial_apply`'s substitution map.
SubstitutionMap substMap;
if (innerNewFunc->getType().castTo<SILFunctionType>()->isPolymorphic()) {
if (!newFuncGenSig) {
substMap = pai->getSubstitutionMap();
} else {
substMap = SubstitutionMap::get(
newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
}
}
return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs,
ParameterConvention::Direct_Guaranteed);
}
Expand Down Expand Up @@ -796,14 +812,16 @@ static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
original->getName(), AutoDiffDerivativeFunctionKind::VJP,
witness->getConfig()))
.str();
auto vjpGenericSig = getDerivativeGenericSignature(witness, original);
auto *vjpGenericEnv = vjpGenericSig
? vjpGenericSig->getGenericEnvironment()
: nullptr;
CanGenericSignature vjpCanGenSig;
if (auto jvpGenSig = witness->getDerivativeGenericSignature())
vjpCanGenSig = jvpGenSig->getCanonicalSignature();
GenericEnvironment *vjpGenericEnv = nullptr;
if (vjpCanGenSig && !vjpCanGenSig->areAllParamsConcrete())
vjpGenericEnv = vjpCanGenSig->getGenericEnvironment();
auto vjpType = originalTy->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::VJP,
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
vjpGenericSig,
vjpCanGenSig,
/*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk);

SILOptFunctionBuilder fb(context.getTransform());
Expand Down Expand Up @@ -839,14 +857,16 @@ static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
original->getName(), AutoDiffDerivativeFunctionKind::JVP,
witness->getConfig()))
.str();
auto jvpGenericSig = getDerivativeGenericSignature(witness, original);
auto *jvpGenericEnv = jvpGenericSig
? jvpGenericSig->getGenericEnvironment()
: nullptr;
CanGenericSignature jvpCanGenSig;
if (auto jvpGenSig = witness->getDerivativeGenericSignature())
jvpCanGenSig = jvpGenSig->getCanonicalSignature();
GenericEnvironment *jvpGenericEnv = nullptr;
if (jvpCanGenSig && !jvpCanGenSig->areAllParamsConcrete())
jvpGenericEnv = jvpCanGenSig->getGenericEnvironment();
auto jvpType = originalTy->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::JVP,
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
jvpGenericSig,
jvpCanGenSig,
/*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk);

SILOptFunctionBuilder fb(context.getTransform());
Expand Down
12 changes: 0 additions & 12 deletions lib/SILOptimizer/Utils/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,6 @@ void forEachApplyDirectResult(
resultCallback(result);
}

/// Returns the canonical derivative generic signature for the given witness
/// and original function.
/// - Return the witness derivative generic signature if it exists.
/// - Otherwise, return the original function's generic signature.
CanGenericSignature
getDerivativeGenericSignature(SILDifferentiabilityWitness *witness,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this helper is removed because the witness derivative generic signature can be used directly, now that derivative generic signatures are propagated correctly. There's no need to fallback to the original generic signature.

SILFunction *original) {
if (auto sig = witness->getDerivativeGenericSignature())
return sig->getCanonicalSignature();
return original->getLoweredFunctionType()->getSubstGenericSignature();
}

void collectAllFormalResultsInTypeOrder(SILFunction &function,
SmallVectorImpl<SILValue> &results) {
SILFunctionConventions convs(function.getLoweredFunctionType(),
Expand Down
27 changes: 23 additions & 4 deletions lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ SILType JVPEmitter::remapSILTypeInDifferential(SILType ty) {
}

Optional<VectorSpace> JVPEmitter::getTangentSpace(CanType type) {
// Use witness generic signature to remap types.
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
type = witnessGenSig->getCanonicalTypeInContext(type);
return type->getAutoDiffAssociatedTangentSpace(
LookUpConformanceInModule(getModule().getSwiftModule()));
}
Expand Down Expand Up @@ -1015,6 +1018,14 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
auto *original = witness->getOriginalFunction();
auto *jvp = witness->getJVP();
auto origTy = original->getLoweredFunctionType();
// Get witness generic signature for remapping types.
// Witness generic signature may have more requirements than JVP generic
// signature: when witness generic signature has same-type requirements
// binding all generic parameters to concrete types, JVP function type uses
// all the concrete types and JVP generic signature is null.
CanGenericSignature witnessCanGenSig;
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
witnessCanGenSig = witnessGenSig->getCanonicalSignature();
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

// Parameters of the differential are:
Expand All @@ -1028,16 +1039,20 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
auto indices = witness->getSILAutoDiffIndices();

// Add differential results.
auto origResInfo = origTy->getResults()[indices.source];
auto origResult = origTy->getResults()[indices.source];
origResult = origResult.getWithInterfaceType(
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
dfResults.push_back(
SILResultInfo(origResInfo.getInterfaceType()
SILResultInfo(origResult.getInterfaceType()
->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getCanonicalType(),
origResInfo.getConvention()));
origResult.getConvention()));

// Add differential parameters for the requested wrt parameters.
for (auto i : indices.parameters->getIndices()) {
auto origParam = origParams[i];
origParam = origParam.getWithInterfaceType(
origParam.getInterfaceType()->getCanonicalType(witnessCanGenSig));
dfParams.push_back(SILParameterInfo(
origParam.getInterfaceType()
->getAutoDiffAssociatedTangentSpace(lookupConformance)
Expand All @@ -1059,7 +1074,11 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
original->getName(), AutoDiffLinearMapKind::Differential,
witness->getConfig()))
.str();
auto diffGenericSig = getDerivativeGenericSignature(witness, original);
// Set differential generic signature equal to JVP generic signature.
// Do not use witness generic signature, which may have same-type requirements
// binding all generic parameters to concrete types.
auto diffGenericSig =
jvp->getLoweredFunctionType()->getSubstGenericSignature();
auto *diffGenericEnv =
diffGenericSig ? diffGenericSig->getGenericEnvironment() : nullptr;
auto diffType = SILFunctionType::get(
Expand Down
3 changes: 3 additions & 0 deletions lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ SILType PullbackEmitter::remapType(SILType ty) {
}

Optional<VectorSpace> PullbackEmitter::getTangentSpace(CanType type) {
// Use witness generic signature to remap types.
if (auto witnessGenSig = getWitness()->getDerivativeGenericSignature())
type = witnessGenSig->getCanonicalTypeInContext(type);
return type->getAutoDiffAssociatedTangentSpace(
LookUpConformanceInModule(getModule().getSwiftModule()));
}
Expand Down
29 changes: 21 additions & 8 deletions lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,20 @@ VJPEmitter::VJPEmitter(ADContext &context, SILFunction *original,
SILFunction *VJPEmitter::createEmptyPullback() {
auto &module = context.getModule();
auto origTy = original->getLoweredFunctionType();
// Get witness generic signature for remapping types.
// Witness generic signature may have more requirements than VJP generic
// signature: when witness generic signature has same-type requirements
// binding all generic parameters to concrete types, VJP function type uses
// all the concrete types and VJP generic signature is null.
CanGenericSignature witnessCanGenSig;
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
witnessCanGenSig = witnessGenSig->getCanonicalSignature();
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterInfoForOriginalResult =
[&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
Lowering::AbstractionPattern pattern(
vjp->getLoweredFunctionType()->getSubstGenericSignature(), tanType);
Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
auto &tl = context.getTypeConverter().getTypeLowering(
pattern, tanType, TypeExpansionContext::minimal());
ParameterConvention conv;
Expand All @@ -105,8 +112,7 @@ SILFunction *VJPEmitter::createEmptyPullback() {
// Given a type, returns its formal SIL result info.
auto getTangentResultInfoForOriginalParameter =
[&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
Lowering::AbstractionPattern pattern(
vjp->getLoweredFunctionType()->getSubstGenericSignature(), tanType);
Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
auto &tl = context.getTypeConverter().getTypeLowering(
pattern, tanType, TypeExpansionContext::minimal());
ResultConvention conv;
Expand Down Expand Up @@ -139,12 +145,14 @@ SILFunction *VJPEmitter::createEmptyPullback() {
auto indices = witness->getSILAutoDiffIndices();

// Add pullback parameter for the seed.
auto origResInfo = origTy->getResults()[indices.source];
auto origResult = origTy->getResults()[indices.source];
origResult = origResult.getWithInterfaceType(
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
pbParams.push_back(getTangentParameterInfoForOriginalResult(
origResInfo.getInterfaceType()
origResult.getInterfaceType()
->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getCanonicalType(),
origResInfo.getConvention()));
origResult.getConvention()));

// Accept a pullback struct in the pullback parameter list. This is the
// returned pullback's closure context.
Expand All @@ -156,6 +164,8 @@ SILFunction *VJPEmitter::createEmptyPullback() {
// Add pullback results for the requested wrt parameters.
for (auto i : indices.parameters->getIndices()) {
auto origParam = origParams[i];
origParam = origParam.getWithInterfaceType(
origParam.getInterfaceType()->getCanonicalType(witnessCanGenSig));
adjResults.push_back(getTangentResultInfoForOriginalParameter(
origParam.getInterfaceType()
->getAutoDiffAssociatedTangentSpace(lookupConformance)
Expand All @@ -169,7 +179,10 @@ SILFunction *VJPEmitter::createEmptyPullback() {
original->getName(), AutoDiffLinearMapKind::Pullback,
witness->getConfig()))
.str();
auto pbGenericSig = getDerivativeGenericSignature(witness, original);
// Set pullback generic signature equal to VJP generic signature.
// Do not use witness generic signature, which may have same-type requirements
// binding all generic parameters to concrete types.
auto pbGenericSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
auto *pbGenericEnv =
pbGenericSig ? pbGenericSig->getGenericEnvironment() : nullptr;
auto pbType = SILFunctionType::get(
Expand Down
Loading