Skip to content

[AutoDiff] Improve getConstrainedDerivativeGenericSignature helper. #29620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
namespace swift {

class AnyFunctionType;
class SILFunctionType;
class TupleType;

/// A function type differentiability kind.
Expand Down Expand Up @@ -231,6 +232,18 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
SmallVectorImpl<Type> &results,
bool reverseCurryLevels = false);

/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///
/// Returns the "constrained" derivative generic signature given:
/// - An original SIL function type.
/// - Differentiability parameter indices.
/// - A possibly "unconstrained" derivative generic signature.
GenericSignature
getConstrainedDerivativeGenericSignature(SILFunctionType *originalFnTy,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't there a version for ASTFunctionType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question - I guess an AnyFunctionType version hasn't been necessary for correctness.

I'm pretty sure the SILFunctionType version is necessary in the differentiation transform. So I don't believe an AnyFunctionType version can replace the SILFunctionType version.

IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig);

} // end namespace autodiff

} // end namespace swift
Expand Down
27 changes: 27 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//

#include "swift/AST/AutoDiff.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"

using namespace swift;
Expand Down Expand Up @@ -67,6 +69,31 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
}
}

GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig) {
if (!derivativeGenSig)
derivativeGenSig = originalFnTy->getSubstGenericSignature();
if (!derivativeGenSig)
return nullptr;
// Constrain all differentiability parameters to `Differentiable`.
auto &ctx = originalFnTy->getASTContext();
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
SmallVector<Requirement, 4> requirements;
for (unsigned paramIdx : diffParamIndices->getIndices()) {
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
Requirement req(RequirementKind::Conformance, paramType,
diffableProto->getDeclaredType());
requirements.push_back(req);
}
return evaluateOrDefault(
ctx.evaluator,
AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
/*addedGenericParams*/ {},
std::move(requirements)},
nullptr);
}

Type TangentSpace::getType() const {
switch (kind) {
case Kind::TangentVector:
Expand Down
34 changes: 2 additions & 32 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,36 +191,6 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
return nullptr;
}

// Returns the canonical generic signature for an autodiff derivative function
// given an existing derivative function generic signature. All
// differentiability parameters are required to conform to `Differentiable`.
static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature(
CanGenericSignature derivativeFnGenSig,
ArrayRef<SILParameterInfo> originalParameters,
IndexSubset *parameterIndices, ModuleDecl *module) {
if (!derivativeFnGenSig)
return nullptr;
auto &ctx = module->getASTContext();
GenericSignatureBuilder builder(ctx);
// Add derivative function generic signature.
builder.addGenericSignature(derivativeFnGenSig);
// All differentiability parameters are required to conform to
// `Differentiable`.
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
auto *differentiableProtocol =
ctx.getProtocol(KnownProtocolKind::Differentiable);
for (unsigned paramIdx : parameterIndices->getIndices()) {
auto paramType = originalParameters[paramIdx].getInterfaceType();
Requirement req(RequirementKind::Conformance, paramType,
differentiableProtocol->getDeclaredType());
builder.addRequirement(req, source, module);
}
return std::move(builder)
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
->getCanonicalSignature();
}

CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
Expand All @@ -243,8 +213,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
// Get the canonical derivative function generic signature.
if (!derivativeFnGenSig)
derivativeFnGenSig = getSubstGenericSignature();
derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature(
derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
this, parameterIndices, derivativeFnGenSig).getCanonicalSignature();

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterInfoForOriginalResult =
Expand Down