Skip to content

[AutoDiff] Plumb witness derivative generic signatures through SILGen. #40063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions include/swift/AST/Witness.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class Witness {
ConcreteDeclRef declRef;
GenericEnvironment *syntheticEnvironment;
SubstitutionMap reqToSyntheticEnvSubs;
/// The derivative generic signature, when the requirement is a derivative
/// function.
GenericSignature derivativeGenSig;
};

llvm::PointerUnion<ValueDecl *, StoredWitness *> storage;
Expand Down Expand Up @@ -124,7 +127,8 @@ class Witness {
static Witness forDeserialized(ValueDecl *decl,
SubstitutionMap substitutions) {
// TODO: It's probably a good idea to have a separate 'deserialized' bit.
return Witness(decl, substitutions, nullptr, SubstitutionMap());
return Witness(
decl, substitutions, nullptr, SubstitutionMap(), CanGenericSignature());
}

/// Create a witness that requires substitutions.
Expand All @@ -138,10 +142,14 @@ class Witness {
///
/// \param reqToSyntheticEnvSubs The mapping from the interface types of the
/// requirement into the interface types of the synthetic environment.
///
/// \param derivativeGenSig The derivative generic signature, when the
/// requirement is a derivative function.
Witness(ValueDecl *decl,
SubstitutionMap substitutions,
GenericEnvironment *syntheticEnv,
SubstitutionMap reqToSyntheticEnvSubs);
SubstitutionMap reqToSyntheticEnvSubs,
GenericSignature derivativeGenSig);

/// Retrieve the witness declaration reference, which includes the
/// substitutions needed to use the witness from the synthetic environment
Expand Down Expand Up @@ -183,6 +191,13 @@ class Witness {
return {};
}

/// Retrieve the derivative generic signature.
GenericSignature getDerivativeGenericSignature() const {
if (auto *storedWitness = storage.dyn_cast<StoredWitness *>())
return storedWitness->derivativeGenSig;
return GenericSignature();
}

SWIFT_DEBUG_DUMP;

void dump(llvm::raw_ostream &out) const;
Expand Down
9 changes: 6 additions & 3 deletions lib/AST/ProtocolConformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ using namespace swift;

Witness::Witness(ValueDecl *decl, SubstitutionMap substitutions,
GenericEnvironment *syntheticEnv,
SubstitutionMap reqToSynthesizedEnvSubs) {
SubstitutionMap reqToSynthesizedEnvSubs,
GenericSignature derivativeGenSig) {
if (!syntheticEnv && substitutions.empty() &&
reqToSynthesizedEnvSubs.empty()) {
storage = decl;
Expand All @@ -53,7 +54,8 @@ Witness::Witness(ValueDecl *decl, SubstitutionMap substitutions,
auto declRef = ConcreteDeclRef(decl, substitutions);
auto storedMem = ctx.Allocate(sizeof(StoredWitness), alignof(StoredWitness));
auto stored = new (storedMem) StoredWitness{declRef, syntheticEnv,
reqToSynthesizedEnvSubs};
reqToSynthesizedEnvSubs,
derivativeGenSig};

storage = stored;
}
Expand Down Expand Up @@ -892,7 +894,8 @@ NormalProtocolConformance::getWitnessUncached(ValueDecl *requirement) const {
}

Witness SelfProtocolConformance::getWitness(ValueDecl *requirement) const {
return Witness(requirement, SubstitutionMap(), nullptr, SubstitutionMap());
return Witness(requirement, SubstitutionMap(), nullptr, SubstitutionMap(),
GenericSignature());
}

ConcreteDeclRef
Expand Down
159 changes: 75 additions & 84 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
return buildGenericSignature(ctx, sig, {}, reqs).getCanonicalSignature();
}

/// Given an original type, computes its tangent type for the purpose of
/// building a linear map using this type. When the original type is an
/// archetype or contains a type parameter, appends a new generic parameter and
/// a corresponding replacement type to the given containers.
static CanType getAutoDiffTangentTypeForLinearMap(
Type originalType,
LookupConformanceFn lookupConformance,
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
SmallVectorImpl<Type> &substReplacements,
ASTContext &context
) {
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
assert(maybeTanType && "Type does not have a tangent space?");
auto tanType = maybeTanType->getCanonicalType();
// If concrete, the tangent type is concrete.
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
return tanType;
// Otherwise, the tangent type is a new generic parameter substituted for the
// tangent type.
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
substGenericParams.push_back(gpType);
substReplacements.push_back(tanType);
return gpType;
}

/// Returns the differential type for the given original function type,
/// parameter indices, and result index.
static CanSILFunctionType getAutoDiffDifferentialType(
Expand Down Expand Up @@ -484,45 +510,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
SmallVector<SILParameterInfo, 8> differentialParams;
for (auto &param : diffParams) {
auto paramTan =
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto paramTanType = paramTan->getCanonicalType();
auto paramConv = getTangentParameterConvention(paramTanType,
param.getConvention());
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
differentialParams.push_back(
{paramTan->getCanonicalType(), paramConv});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
differentialParams.push_back({gpType, paramConv});
}
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
param.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto paramConv = getTangentParameterConvention(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
param.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
param.getConvention());
differentialParams.push_back({paramTanType, paramConv});
}
SmallVector<SILResultInfo, 1> differentialResults;
for (auto resultIndex : resultIndices->getIndices()) {
// Handle formal original result.
if (resultIndex < originalFnTy->getNumResults()) {
auto &result = originalResults[resultIndex];
auto resultTan =
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(resultTan && "Result type does not have a tangent space?");
auto resultTanType = resultTan->getCanonicalType();
auto resultConv =
getTangentResultConvention(resultTanType, result.getConvention());
if (!resultTanType->hasArchetype() &&
!resultTanType->hasTypeParameter()) {
differentialResults.push_back(
{resultTan->getCanonicalType(), resultConv});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(resultTanType);
differentialResults.push_back({gpType, resultConv});
}
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
result.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto resultConv = getTangentResultConvention(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
result.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
result.getConvention());
differentialResults.push_back({resultTanType, resultConv});
continue;
}
// Handle original `inout` parameter.
Expand All @@ -537,11 +550,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
if (parameterIndices->contains(paramIndex))
continue;
auto inoutParam = originalFnTy->getParameters()[paramIndex];
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
inoutParam.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
differentialResults.push_back(
{paramTan->getCanonicalType(), ResultConvention::Indirect});
{inoutParamTanType, ResultConvention::Indirect});
}

SubstitutionMap substitutions;
Expand Down Expand Up @@ -648,23 +661,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
// Handle formal original result.
if (resultIndex < originalFnTy->getNumResults()) {
auto &origRes = originalResults[resultIndex];
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(resultTan && "Result type does not have a tangent space?");
auto resultTanType = resultTan->getCanonicalType();
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
resultTanType, origRes.getConvention());
if (!resultTanType->hasArchetype() &&
!resultTanType->hasTypeParameter()) {
auto resultTanType = resultTan->getCanonicalType();
pullbackParams.push_back({resultTanType, paramTanConvention});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(resultTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
origRes.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto paramConv = getTangentParameterConventionForOriginalResult(
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
origRes.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
origRes.getConvention());
pullbackParams.push_back({resultTanType, paramConv});
continue;
}
// Handle original `inout` parameter.
Expand All @@ -674,28 +680,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
auto inoutParam = originalFnTy->getParameters()[paramIndex];
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
// The pullback parameter convention depends on whether the original `inout`
// paramater is a differentiability parameter.
// - If yes, the pullback parameter convention is `@inout`.
// - If no, the pullback parameter convention is `@in_guaranteed`.
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
inoutParam.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
auto paramTanConvention = isWrtInoutParameter
? inoutParam.getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
auto paramTanType = paramTan->getCanonicalType();
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
pullbackParams.push_back(
SILParameterInfo(paramTanType, paramTanConvention));
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
? inoutParam.getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
}

// Collect pullback results.
Expand All @@ -707,21 +703,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
// and always appear as pullback parameters.
if (param.isIndirectInOut())
continue;
auto paramTan =
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto paramTanType = paramTan->getCanonicalType();
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
param.getInterfaceType(), lookupConformance,
substGenericParams, substReplacements, ctx);
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
paramTanType, param.getConvention());
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
pullbackResults.push_back({paramTanType, resultTanConvention});
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
pullbackResults.push_back({gpType, resultTanConvention});
}
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
param.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getCanonicalType(),
param.getConvention());
pullbackResults.push_back({paramTanType, resultTanConvention});
}
SubstitutionMap substitutions;
if (!substGenericParams.empty()) {
Expand Down
3 changes: 2 additions & 1 deletion lib/SIL/IR/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2631,7 +2631,8 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
derivativeId->getParameterIndices(), derivativeId->getKind(),
LookUpConformanceInModule(&M));
LookUpConformanceInModule(&M),
derivativeId->getDerivativeGenericSignature());
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
}

Expand Down
20 changes: 18 additions & 2 deletions lib/SILGen/SILGenType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
if (!reqAccessor) {
if (auto witness = asDerived().getWitness(reqDecl)) {
return addMethodImplementation(
requirementRef, requirementRef.withDecl(witness.getDecl()),
requirementRef, getWitnessRef(requirementRef, witness),
witness);
}

Expand Down Expand Up @@ -444,7 +444,8 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
witnessStorage->getSynthesizedAccessor(reqAccessor->getAccessorKind());

return addMethodImplementation(
requirementRef, requirementRef.withDecl(witnessAccessor), witness);
requirementRef, getWitnessRef(requirementRef, witnessAccessor),
witness);
}

private:
Expand All @@ -458,6 +459,21 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
asDerived().addMethodImplementation(requirementRef, witnessRef,
isFree, witness);
}

SILDeclRef getWitnessRef(SILDeclRef requirementRef, Witness witness) {
auto witnessRef = requirementRef.withDecl(witness.getDecl());
// If the requirement/witness is a derivative function, we need to
// substitute the witness's derivative generic signature in its derivative
// function identifier.
if (requirementRef.isAutoDiffDerivativeFunction()) {
auto *reqrRerivativeId = requirementRef.getDerivativeFunctionIdentifier();
auto *witnessDerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
reqrRerivativeId->getKind(), reqrRerivativeId->getParameterIndices(),
witness.getDerivativeGenericSignature(), witnessRef.getASTContext());
witnessRef = witnessRef.asAutoDiffDerivativeFunction(witnessDerivativeId);
}
return witnessRef;
}
};

static IsSerialized_t isConformanceSerialized(RootProtocolConformance *conf) {
Expand Down
Loading