Skip to content

[AutoDiff upstream] Add @differentiable declaration attribute type-checking. #29231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace swift {
class DeclContext;
class DefaultArgumentInitializer;
class DerivativeAttr;
class DifferentiableAttr;
class ExtensionDecl;
class ForeignRepresentationInfo;
class FuncDecl;
Expand Down Expand Up @@ -287,6 +288,11 @@ class ASTContext final {
/// Cached mapping from types to their associated tangent spaces.
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;

/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
/// diagnose duplicate `@differentiable` attributes for the same key.
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
DifferentiableAttrs;

/// Cache of `@derivative` attributes keyed by parameter indices and
/// derivative function kind. Used to diagnose duplicate `@derivative`
/// attributes for the same key.
Expand Down
22 changes: 15 additions & 7 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,7 @@ class DifferentiableAttr final
private llvm::TrailingObjects<DifferentiableAttr,
ParsedAutoDiffParameter> {
friend TrailingObjects;
friend class DifferentiableAttributeTypeCheckRequest;

/// The declaration on which the `@differentiable` attribute is declared.
/// May not be a valid declaration for `@differentiable` attributes.
Expand All @@ -1667,7 +1668,12 @@ class DifferentiableAttr final
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The bit stores whether the parameter indices have been computed.
///
/// Note: it is necessary to use a bit instead of `nullptr` parameter indices
/// to represent "parameter indices not yet type-checked" because invalid
/// attributes have `nullptr` parameter indices but have been type-checked.
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff associated functions. Resolved by the
Expand Down Expand Up @@ -1723,12 +1729,14 @@ class DifferentiableAttr final
/// registered VJP.
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}
private:
/// Returns true if the given `@differentiable` attribute has been
/// type-checked.
bool hasBeenTypeChecked() const;

public:
IndexSubset *getParameterIndices() const;
void setParameterIndices(IndexSubset *parameterIndices);

/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
Expand Down
20 changes: 20 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5736,6 +5736,20 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
private:
ParameterList *Params;

/// The generation at which we last loaded derivative function configurations.
unsigned DerivativeFunctionConfigGeneration = 0;
/// Prepare to traverse the list of derivative function configurations.
void prepareDerivativeFunctionConfigurations();

/// A uniqued list of derivative function configurations.
/// - `@differentiable` and `@derivative` attribute type-checking is
/// responsible for populating derivative function configurations specified
/// in the current module.
/// - Module loading is responsible for populating derivative function
/// configurations from imported modules.
struct DerivativeFunctionConfigurationList;
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;

protected:
// If a function has a body at all, we have either a parsed body AST node or
// we have saved the end location of the unparsed body.
Expand Down Expand Up @@ -6060,6 +6074,12 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// constructor.
bool hasDynamicSelfResult() const;

/// Get all derivative function configurations.
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(AutoDiffConfig config);

using DeclContext::operator new;
using Decl::getASTContext;
};
Expand Down
56 changes: 56 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2904,6 +2904,62 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
"containing type %0 does not conform to protocol %1",
(DeclName, DeclName))

// @differentiable
ERROR(differentiable_attr_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
"attribute for transpose registration instead", ())
ERROR(differentiable_attr_overload_not_found,none,
"%0 does not have expected type %1", (DeclNameRef, Type))
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
// mention "same generic requirements".
ERROR(differentiable_attr_duplicate,none,
"duplicate '@differentiable' attribute with same parameters", ())
NOTE(differentiable_attr_duplicate_note,none,
"other attribute declared here", ())
ERROR(differentiable_attr_function_not_same_type_context,none,
"%0 is not defined in the current type context", (DeclNameRef))
ERROR(differentiable_attr_derivative_not_function,none,
"registered derivative %0 must be a 'func' declaration", (DeclNameRef))
ERROR(differentiable_attr_class_derivative_not_final,none,
"class member derivative must be final", ())
ERROR(differentiable_attr_invalid_access,none,
"derivative function %0 is required to either be public or "
"'@usableFromInline' because the original function %1 is public or "
"'@usableFromInline'", (DeclNameRef, DeclName))
ERROR(differentiable_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
ERROR(differentiable_attr_protocol_req_where_clause,none,
"'@differentiable' attribute on protocol requirement cannot specify "
"'where' clause", ())
ERROR(differentiable_attr_protocol_req_assoc_func,none,
"'@differentiable' attribute on protocol requirement cannot specify "
"'jvp:' or 'vjp:'", ())
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
"'@differentiable' attribute on stored property cannot specify "
"'jvp:' or 'vjp:'", ())
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
"'@differentiable' attribute cannot be declared on class methods "
"returning 'Self'", ())
// TODO(TF-654): Remove when differentiation supports class initializers.
ERROR(differentiable_attr_class_init_not_yet_supported,none,
"'@differentiable' attribute does not yet support class initializers",
())
ERROR(differentiable_attr_empty_where_clause,none,
"empty 'where' clause in '@differentiable' attribute", ())
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
"'where' clause is valid only when original function is generic %0",
(DeclName))
ERROR(differentiable_attr_layout_req_unsupported,none,
"'@differentiable' attribute does not yet support layout requirements",
())
ERROR(overriding_decl_missing_differentiable_attr,none,
"overriding declaration is missing attribute '%0'", (StringRef))
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

// @derivative
ERROR(derivative_attr_expected_result_tuple,none,
"'@derivative(of:)' attribute requires function to return a two-element "
Expand Down
36 changes: 35 additions & 1 deletion include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,9 @@ struct WhereClauseOwner {

/// The source of the where clause, which can be a generic parameter list
/// or a declaration that can have a where clause.
llvm::PointerUnion<GenericParamList *, TrailingWhereClause *, SpecializeAttr *> source;
llvm::PointerUnion<GenericParamList *, TrailingWhereClause *,
SpecializeAttr *, DifferentiableAttr *>
source;

WhereClauseOwner(GenericContext *genCtx);
WhereClauseOwner(AssociatedTypeDecl *atd);
Expand All @@ -386,6 +388,9 @@ struct WhereClauseOwner {
WhereClauseOwner(DeclContext *dc, SpecializeAttr *attr)
: dc(dc), source(attr) {}

WhereClauseOwner(DeclContext *dc, DifferentiableAttr *attr)
: dc(dc), source(attr) {}

SourceLoc getLoc() const;

friend hash_code hash_value(const WhereClauseOwner &owner) {
Expand Down Expand Up @@ -2020,6 +2025,35 @@ class PatternTypeRequest
}
};

/// Type-checks a `@differentiable` attribute and returns the resolved parameter
/// indices on success. On failure, emits diagnostics and returns `nullptr`.
///
/// Currently, this request resolves other `@differentiable` attribute
/// components but mutates them in place:
/// - `JVPFunction`
/// - `VJPFunction`
/// - `DerivativeGenericSignature`
class DifferentiableAttributeTypeCheckRequest
: public SimpleRequest<DifferentiableAttributeTypeCheckRequest,
IndexSubset *(DifferentiableAttr *),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, DifferentiableAttributeTypeCheckRequest takes a DifferentiableAttr *and returns an IndexSubset * representing parameter indices. It also resolves two FuncDecl * (JVPFunction and VJPFunction) and a GenericSignature (DerivativeGenericSignature), mutating them in-place in the DifferentiableAttr *.

This works fine for now. I thought of two ways to refactor the request to avoid mutation:

  • Make DifferentiableAttributeTypeCheckRequest return a tuple/struct of the resolved components.
    • This caused many request cycles, since DifferentiableAttr::get{JVPFunction,VJPFunction,DerivativeGenericSignature} now all trigger the request. I haven't debugged further.
  • Make an individual request for resolving each @differentiable attribute component: original AbstractFuncDecl(s), parameter indices, JVP/VJP FuncDecl (if specified), derivative GenericSignature (if specified).
    • I haven't tried this. It seems like a fair amount of work, and the upside is unclear.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

History: #28017 moved @differentiable attribute type-checking logic from AttributeChecker::visitDifferentiableAttr to DifferentiableAttributeTypeCheckRequest::evaluate.

The current request approach (return IndexSubset * but mutate other components in-place) is the first one that worked.

CacheKind::SeparatelyCached> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using CacheKind::Cached seems preferable so that DifferentiableAttr is truly stateless, adhering to the request evaluator caching vision.

I tried changing CacheKind::SeparatelyCached to CacheKind::Cached (some time ago on tensorflow branch), but it did not seem to work: @differentiable attributes in non-primary-files were left unchecked.

I'm not sure why CacheKind::Cached did not work. I did notice that quite a few other requests (e.g. InterfaceTypeRequest) used CacheKind::SeparatelyCached for some reason - perhaps that's just for request cycle breaking?

public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
llvm::Expected<IndexSubset *> evaluate(Evaluator &evaluator,
DifferentiableAttr *attr) const;

public:
// Separate caching.
bool isCached() const { return true; }
Optional<IndexSubset *> getCachedResult() const;
void cacheResult(IndexSubset *value) const;
};

// Allow AnyValue to compare two Type values, even though Type doesn't
// support ==.
template<>
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,
SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
Type(KnownProtocolKind, const DeclContext *), SeparatelyCached,
NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeTypeCheckRequest,
IndexSubset *(DifferentiableAttr *),
SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DynamicallyReplacedDeclRequest,
ValueDecl *(ValueDecl *),
Cached, NoLocationInfo)
Expand Down
25 changes: 25 additions & 0 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "swift/AST/IndexSubset.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/TypeRepr.h"
#include "swift/AST/Types.h"
#include "swift/Basic/Defer.h"
Expand Down Expand Up @@ -1505,6 +1506,30 @@ void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
OriginalDeclaration = originalDeclaration;
}

bool DifferentiableAttr::hasBeenTypeChecked() const {
return ParameterIndicesAndBit.getInt();
}

IndexSubset *DifferentiableAttr::getParameterIndices() const {
assert(getOriginalDeclaration() &&
"Original declaration must have been resolved");
auto &ctx = getOriginalDeclaration()->getASTContext();
return evaluateOrDefault(ctx.evaluator,
DifferentiableAttributeTypeCheckRequest{
const_cast<DifferentiableAttr *>(this)},
nullptr);
}

void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
assert(getOriginalDeclaration() &&
"Original declaration must have been resolved");
auto &ctx = getOriginalDeclaration()->getASTContext();
ctx.evaluator.cacheOutput(
DifferentiableAttributeTypeCheckRequest{
const_cast<DifferentiableAttr *>(this)},
std::move(paramIndices));
}

void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
JVPFunction = decl;
if (decl && !JVP)
Expand Down
39 changes: 39 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6982,6 +6982,45 @@ StringRef AbstractFunctionDecl::getInlinableBodyText(
return extractInlinableText(getASTContext().SourceMgr, body, scratch);
}

/// A uniqued list of derivative function configurations.
struct AbstractFunctionDecl::DerivativeFunctionConfigurationList
: public llvm::SetVector<AutoDiffConfig> {
// Necessary for `ASTContext` allocation.
void *operator new(
size_t bytes, ASTContext &ctx,
unsigned alignment = alignof(DerivativeFunctionConfigurationList)) {
return ctx.Allocate(bytes, alignment);
}
};

void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
if (DerivativeFunctionConfigs)
return;
auto &ctx = getASTContext();
DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList();
// Register an `ASTContext` cleanup calling the list destructor.
ctx.addCleanup([this]() {
this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList();
});
}

ArrayRef<AutoDiffConfig>
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
prepareDerivativeFunctionConfigurations();
auto &ctx = getASTContext();
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
// TODO(TF-1100): Upstream derivative function configuration serialization
// logic.
}
return DerivativeFunctionConfigs->getArrayRef();
}

void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
AutoDiffConfig config) {
prepareDerivativeFunctionConfigurations();
DerivativeFunctionConfigs->insert(config);
}

FuncDecl *FuncDecl::createImpl(ASTContext &Context,
SourceLoc StaticLoc,
StaticSpellingKind StaticSpelling,
Expand Down
21 changes: 21 additions & 0 deletions lib/AST/TypeCheckRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ MutableArrayRef<RequirementRepr> WhereClauseOwner::getRequirements() const {
} else if (const auto attr = source.dyn_cast<SpecializeAttr *>()) {
if (auto whereClause = attr->getTrailingWhereClause())
return whereClause->getRequirements();
} else if (const auto attr = source.dyn_cast<DifferentiableAttr *>()) {
if (auto whereClause = attr->getWhereClause())
return whereClause->getRequirements();
} else if (const auto whereClause = source.get<TrailingWhereClause *>()) {
return whereClause->getRequirements();
}
Expand Down Expand Up @@ -1235,6 +1238,24 @@ void CallerSideDefaultArgExprRequest::cacheResult(Expr *expr) const {
defaultExpr->ContextOrCallerSideExpr = expr;
}

//----------------------------------------------------------------------------//
// DifferentiableAttributeTypeCheckRequest computation.
//----------------------------------------------------------------------------//

Optional<IndexSubset *>
DifferentiableAttributeTypeCheckRequest::getCachedResult() const {
auto *attr = std::get<0>(getStorage());
if (attr->hasBeenTypeChecked())
return attr->ParameterIndicesAndBit.getPointer();
return None;
}

void DifferentiableAttributeTypeCheckRequest::cacheResult(
IndexSubset *parameterIndices) const {
auto *attr = std::get<0>(getStorage());
attr->ParameterIndicesAndBit.setPointerAndInt(parameterIndices, true);
}

//----------------------------------------------------------------------------//
// TypeCheckSourceFileRequest computation.
//----------------------------------------------------------------------------//
Expand Down
Loading