Skip to content

[AutoDiff] Improve getConstrainedDerivativeGenericSignature helper. #29621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace swift {
class AnyFunctionType;
class TupleType;
struct SILAutoDiffIndices;

class SILFunctionType;

/// A function type differentiability kind.
enum class DifferentiabilityKind : uint8_t {
Expand Down Expand Up @@ -240,6 +240,18 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
SmallVectorImpl<Type> &results,
bool reverseCurryLevels = false);

/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///
/// Returns the "constrained" derivative generic signature given:
/// - An original SIL function type.
/// - Differentiability parameter indices.
/// - A possibly "unconstrained" derivative generic signature.
GenericSignature
getConstrainedDerivativeGenericSignature(SILFunctionType *originalFnTy,
IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig);

} // end namespace autodiff

} // end namespace swift
Expand Down Expand Up @@ -342,7 +354,6 @@ namespace swift {

class ASTContext;
class AnyFunctionType;
class SILFunctionType;
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
enum class SILLinkage : uint8_t;

Expand Down
27 changes: 27 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
//
//===----------------------------------------------------------------------===//

#include "swift/AST/ASTContext.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"

using namespace swift;
Expand Down Expand Up @@ -67,6 +69,31 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
}
}

GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig) {
if (!derivativeGenSig)
derivativeGenSig = originalFnTy->getSubstGenericSignature();
if (!derivativeGenSig)
return nullptr;
// Constrain all differentiability parameters to `Differentiable`.
auto &ctx = originalFnTy->getASTContext();
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
SmallVector<Requirement, 4> requirements;
for (unsigned paramIdx : diffParamIndices->getIndices()) {
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
Requirement req(RequirementKind::Conformance, paramType,
diffableProto->getDeclaredType());
requirements.push_back(req);
}
return evaluateOrDefault(
ctx.evaluator,
AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
/*addedGenericParams*/ {},
std::move(requirements)},
nullptr);
}

// SWIFT_ENABLE_TENSORFLOW
// Not-yet-upstreamed `tensorflow` branch additions are below.

Expand Down
39 changes: 6 additions & 33 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,35 +226,6 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
isGenericSignatureImplied(), getASTContext());
}

// Returns the canonical generic signature for an autodiff derivative function
// given an existing derivative function generic signature. All differentiation
// parameters are constrained to conform to `Differentiable`.
static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature(
CanGenericSignature derivativeFnGenSig,
ArrayRef<SILParameterInfo> originalParameters,
IndexSubset *parameterIndices, ModuleDecl *module) {
if (!derivativeFnGenSig)
return nullptr;
auto &ctx = module->getASTContext();
GenericSignatureBuilder builder(ctx);

// Add derivative function generic signature.
builder.addGenericSignature(derivativeFnGenSig);
// Constrain all wrt parameters to conform to `Differentiable`.
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
for (unsigned paramIdx : parameterIndices->getIndices()) {
auto paramType = originalParameters[paramIdx].getInterfaceType();
Requirement req(RequirementKind::Conformance, paramType,
diffableProto->getDeclaredType());
builder.addRequirement(req, source, module);
}
return std::move(builder)
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
->getCanonicalSignature();
}

CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
Expand Down Expand Up @@ -293,8 +264,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
// Get the canonical derivative function generic signature.
if (!derivativeFnGenSig)
derivativeFnGenSig = getSubstGenericSignature();
derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature(
derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
this, parameterIndices, derivativeFnGenSig)
.getCanonicalSignature();

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterInfoForOriginalResult =
Expand Down Expand Up @@ -456,8 +428,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
// Get the canonical transpose function generic signature.
if (!genSig)
genSig = getSubstGenericSignature();
genSig = getAutoDiffDerivativeFunctionGenericSignature(
genSig, getParameters(), parameterIndices, &TC.M);
genSig = autodiff::getConstrainedDerivativeGenericSignature(
this, parameterIndices, genSig)
.getCanonicalSignature();

// Given a type, returns its formal SIL parameter info.
auto getParameterInfoForOriginalResult =
Expand Down
35 changes: 1 addition & 34 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,39 +84,6 @@ template <typename T> static inline void debugDump(T &v) {
<< v << "\n==== END DEBUG DUMP ====\n");
}

/// Returns the "constrained" derivative generic signature given:
/// - An original SIL function type.
/// - A wrt parameter index subset.
/// - A possibly uncanonical derivative generic signature (optional).
/// - Additional derivative requirements (optional).
/// The constrained derivative generic signature constrains all wrt parameters
/// to conform to `Differentiable`.
static GenericSignature
getConstrainedDerivativeGenericSignature(CanSILFunctionType originalFnTy,
IndexSubset *paramIndexSet,
GenericSignature derivativeGenSig) {
if (!derivativeGenSig)
derivativeGenSig = originalFnTy->getSubstGenericSignature();
if (!derivativeGenSig)
return nullptr;
// Constrain all wrt parameters to `Differentiable`.
auto &ctx = derivativeGenSig->getASTContext();
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
SmallVector<Requirement, 4> requirements;
for (unsigned paramIdx : paramIndexSet->getIndices()) {
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
Requirement req(RequirementKind::Conformance, paramType,
diffableProto->getDeclaredType());
requirements.push_back(req);
}
return evaluateOrDefault(
ctx.evaluator,
AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
/*addedGenericParams*/ {},
std::move(requirements)},
nullptr);
}

namespace {

class DifferentiationTransformer {
Expand Down Expand Up @@ -597,7 +564,7 @@ emitDerivativeFunctionReference(
invoker.getIndirectDifferentiation()
.second->getDerivativeGenericSignature();
auto derivativeConstrainedGenSig =
getConstrainedDerivativeGenericSignature(
autodiff::getConstrainedDerivativeGenericSignature(
originalFn->getLoweredFunctionType(), desiredParameterIndices,
contextualDerivativeGenSig);
minimalWitness = SILDifferentiabilityWitness::createDefinition(
Expand Down