Skip to content

[DNM] [AutoDiff upstream] @differentiable attribute type-checking. #29091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace swift {
class DeclContext;
class DefaultArgumentInitializer;
class DerivativeAttr;
class DifferentiableAttr;
class ExtensionDecl;
class ForeignRepresentationInfo;
class FuncDecl;
Expand Down Expand Up @@ -284,6 +285,14 @@ class ASTContext final {
/// across invocations of both the parser and the type-checker.
unsigned NextAutoClosureDiscriminator = 0;

/// Cached mapping from types to their associated tangent spaces.
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;

/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
/// diagnose duplicate `@differentiable` attributes for the same key.
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
DifferentiableAttrs;

/// Cache of `@derivative` attributes keyed by parameter indices and
/// derivative function kind. Used to diagnose duplicate `@derivative`
/// attributes for the same key.
Expand Down
35 changes: 35 additions & 0 deletions include/swift/AST/ASTScope.h
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,41 @@ class SpecializeAttributeScope final : public ASTScopeImpl {
DeclConsumer) const override;
};

/// A `@differentiable` attribute scope.
///
/// This exists because `@differentiable` attribute may have a `where` clause
/// referring to generic parameters from some generic context.
class DifferentiableAttributeScope final : public ASTScopeImpl {
public:
DifferentiableAttr *const differentiableAttr;
ValueDecl *const attributedDeclaration;

DifferentiableAttributeScope(DifferentiableAttr *diffAttr, ValueDecl *decl)
: differentiableAttr(diffAttr), attributedDeclaration(decl) {}
virtual ~DifferentiableAttributeScope() {}

std::string getClassName() const override;
SourceRange
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
NullablePtr<const void> addressForPrinting() const override {
return differentiableAttr;
}

NullablePtr<AbstractStorageDecl>
getEnclosingAbstractStorageDecl() const override;

NullablePtr<DeclAttribute> getDeclAttributeIfAny() const override {
return differentiableAttr;
}
NullablePtr<const void> getReferrent() const override;

protected:
ASTScopeImpl *expandSpecifically(ScopeCreator &) override;
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
DeclConsumer) const override;
bool doesContextMatchStartingContext(const DeclContext *) const override;
};

class SubscriptDeclScope final : public ASTScopeImpl {
public:
SubscriptDecl *const decl;
Expand Down
35 changes: 24 additions & 11 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1647,7 +1647,12 @@ class DifferentiableAttr final
private llvm::TrailingObjects<DifferentiableAttr,
ParsedAutoDiffParameter> {
friend TrailingObjects;
friend class DifferentiableAttributeTypeCheckRequest;

/// The declaration on which the `@differentiable` attribute is declared.
/// May not be a valid declaration for `@differentiable` attributes.
/// Resolved during parsing and deserialization.
Decl *OriginalDeclaration = nullptr;
/// Whether this function is linear (optional).
bool Linear;
/// The number of parsed differentiability parameters specified in 'wrt:'.
Expand All @@ -1663,7 +1668,8 @@ class DifferentiableAttr final
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The bit stores whether the parameter indices have been computed.
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff associated functions. Resolved by the
Expand Down Expand Up @@ -1703,6 +1709,12 @@ class DifferentiableAttr final
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig);

Decl *getOriginalDeclaration() const { return OriginalDeclaration; }

/// Sets the original declaration on which this attribute is declared.
/// Should only be used by parsing and deserialization.
void setOriginalDeclaration(Decl *originalDeclaration);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Expand All @@ -1713,12 +1725,14 @@ class DifferentiableAttr final
/// registered VJP.
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}
private:
/// Returns true if the given `@differentiable` attribute has been
/// type-checked.
bool hasBeenTypeChecked() const;

public:
IndexSubset *getParameterIndices() const;
void setParameterIndices(IndexSubset *parameterIndices);

/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
Expand Down Expand Up @@ -1755,10 +1769,9 @@ class DifferentiableAttr final

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitAssociatedFunctions` is true, omit printing associated functions.
void print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause = false,
bool omitAssociatedFunctions = false) const;
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false,
bool omitDerivativeFunctions = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
Expand Down
156 changes: 156 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <cstdint>

#include "swift/AST/GenericSignature.h"
#include "swift/AST/Identifier.h"
#include "swift/AST/IndexSubset.h"
#include "swift/AST/Type.h"
Expand All @@ -28,6 +29,7 @@
namespace swift {

class AnyFunctionType;
class TupleType;

/// A function type differentiability kind.
enum class DifferentiabilityKind : uint8_t {
Expand Down Expand Up @@ -70,6 +72,26 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
struct AutoDiffConfig {
IndexSubset *parameterIndices;
IndexSubset *resultIndices;
GenericSignature derivativeGenericSignature;

/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
IndexSubset *resultIndices,
GenericSignature derivativeGenericSignature)
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

void print(llvm::raw_ostream &s = llvm::outs()) const;
LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
"only for use within the debugger");
};

class ParsedAutoDiffParameter {
public:
enum class Kind { Named, Ordered, Self };
Expand Down Expand Up @@ -133,6 +155,91 @@ class ParsedAutoDiffParameter {
}
};

/// The tangent space of a type.
///
/// For `Differentiable`-conforming types:
/// - The tangent space is the `TangentVector` associated type.
///
/// For tuple types:
/// - The tangent space is a tuple of the elements' tangent space types, for the
/// elements that have a tangent space.
///
/// For function types whose innermost result type has a tangent space:
/// - The tangent space is the same function type, replacing the innermost
/// result type with its tangent space type.
///
/// Other types have no tangent space.
class TangentSpace {
public:
/// A tangent space kind.
enum class Kind {
/// The `TangentVector` associated type of a `Differentiable`-conforming
/// type.
TangentVector,
/// A product of tangent spaces as a tuple.
Tuple,
/// A function type whose innermost result is the `TangentVector`
/// associated type of a `Differentiable`-conforming type.
Function
};

private:
Kind kind;
union Value {
// TangentVector
Type tangentVectorType;
// Tuple
TupleType *tupleType;
// Function
AnyFunctionType *functionType;

Value(Type tangentVectorType) : tangentVectorType(tangentVectorType) {}
Value(TupleType *tupleType) : tupleType(tupleType) {}
Value(AnyFunctionType *functionType) : functionType(functionType) {}
} value;

TangentSpace(Kind kind, Value value) : kind(kind), value(value) {}

public:
TangentSpace() = delete;

static TangentSpace getTangentVector(Type tangentVectorType) {
return {Kind::TangentVector, tangentVectorType};
}
static TangentSpace getTuple(TupleType *tupleTy) {
return {Kind::Tuple, tupleTy};
}
static TangentSpace getFunction(AnyFunctionType *fnTy) {
return {Kind::Function, fnTy};
}

bool isTangentVector() const { return kind == Kind::TangentVector; }
bool isTuple() const { return kind == Kind::Tuple; }

Kind getKind() const { return kind; }
Type getTangentVector() const {
assert(kind == Kind::TangentVector);
return value.tangentVectorType;
}
TupleType *getTuple() const {
assert(kind == Kind::Tuple);
return value.tupleType;
}
AnyFunctionType *getFunction() const {
assert(kind == Kind::Function);
return value.functionType;
}

/// Get the tangent space type.
Type getType() const;

/// Get the tangent space canonical type.
CanType getCanonicalType() const;

/// Get the underlying
NominalTypeDecl *getNominal() const;
};

/// Automatic differentiation utility namespace.
namespace autodiff {

Expand All @@ -148,10 +255,59 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,

namespace llvm {

using swift::AutoDiffConfig;
using swift::AutoDiffDerivativeFunctionKind;
using swift::GenericSignature;
using swift::IndexSubset;

template <typename T> struct DenseMapInfo;

template <> struct DenseMapInfo<AutoDiffConfig> {
static AutoDiffConfig getEmptyKey() {
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
// The `derivativeGenericSignature` component must be `nullptr` so that
// `getHashValue` and `isEqual` do not try to `getCanonicalSignature()` on
// an invalid pointer.
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
nullptr};
}

static AutoDiffConfig getTombstoneKey() {
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
// The `derivativeGenericSignature` component must be `nullptr` so that
// `getHashValue` and `isEqual` do not try to `getCanonicalSignature()` on
// an invalid pointer.
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
nullptr};
}

static unsigned getHashValue(const AutoDiffConfig &Val) {
auto canGenSig =
Val.derivativeGenericSignature
? Val.derivativeGenericSignature->getCanonicalSignature()
: nullptr;
unsigned combinedHash = hash_combine(
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
DenseMapInfo<GenericSignature>::getHashValue(canGenSig));
return combinedHash;
}

static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
auto lhsCanGenSig =
LHS.derivativeGenericSignature
? LHS.derivativeGenericSignature->getCanonicalSignature()
: nullptr;
auto rhsCanGenSig =
RHS.derivativeGenericSignature
? RHS.derivativeGenericSignature->getCanonicalSignature()
: nullptr;
return LHS.parameterIndices == RHS.parameterIndices &&
LHS.resultIndices == RHS.resultIndices &&
DenseMapInfo<GenericSignature>::isEqual(lhsCanGenSig, rhsCanGenSig);
}
};

template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
static AutoDiffDerivativeFunctionKind getEmptyKey() {
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
Expand Down
20 changes: 20 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5734,6 +5734,20 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
private:
ParameterList *Params;

/// The generation at which we last loaded derivative function configurations.
unsigned DerivativeFunctionConfigGeneration = 0;
/// Prepare to traverse the list of derivative function configurations.
void prepareDerivativeFunctionConfigurations();

/// A uniqued list of derivative function configurations.
/// - `@differentiable` and `@derivative` attribute type-checking is
/// responsible for populating derivative function configurations specified
/// in the current module.
/// - Module loading is responsible for populating derivative function
/// configurations from imported modules.
struct DerivativeFunctionConfigurationList;
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;

protected:
// If a function has a body at all, we have either a parsed body AST node or
// we have saved the end location of the unparsed body.
Expand Down Expand Up @@ -6058,6 +6072,12 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// constructor.
bool hasDynamicSelfResult() const;

/// Get all derivative function configurations.
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(AutoDiffConfig config);

using DeclContext::operator new;
using Decl::getASTContext;
};
Expand Down
Loading