Skip to content

Commit 44d937d

Browse files
authored
[AutoDiff upstream] Add @differentiable declaration attribute type-checking. (swiftlang#29231)
The `@differentiable` attribute marks a function as differentiable. Example: ``` @differentiable(wrt: x, jvp: derivativeFoo where T: Differentiable) func id<T>(_ x: T) -> T { x } ``` The `@differentiable` attribute has an optional `wrt:` clause specifying the parameters that are differentiated "with respect to", i.e. the differentiability parameters. The differentiability parameters must conform to the `Differentiable` protocol. If the `wrt:` clause is unspecified, the differentiability parameters are currently inferred to be all parameters that conform to `Differentiable`. The `@differentiable` attribute also has optional `jvp:` and `vjp:` labels for registering derivative functions. These labels are deprecated in favor of the `@derivative` attribute and will be removed soon. The `@differentiable` attribute also has an optional `where` clause, specifying extra differentiability requirements for generic functions. The `@differentiable` attribute is gated by the `-enable-experimental-differentiable-programming` flag. Code changes: - Add `DifferentiableAttributeTypeCheckRequest`. - Currently, the request returns differentiability parameter indices, while also resolving `JVPFunction`, `VJPFunction`, and `DerivativeGenericSignature` and mutating them in-place in `DifferentiableAttr`. This was the simplest approach that worked without introducing request cycles. - Add "is type-checked" bit to `DifferentiableAttr`. - Alternatively, I tried changing `DifferentiableAttributeTypeCheckRequest` to use `CacheKind::Cache` instead of `CacheKind::SeparatelyCached`, but it did not seem to work: `@differentiable` attributes in non-primary-files were left unchecked. Type-checking rules (summary): - `@differentiable` attribute must be declared on a function-like "original" declaration: `func`, `init`, `subscript`, `var` (computed properties only). - Parsed differentiability parameters must be valid (if they exist). - Parsed `where` clause must be valid (if it exists). - Differentiability parameters must all conform to `Differentiable`. - Original result must all conform to `Differentiable`. - If JVP/VJP functions are specified, they must match the expected type. - `@differentiable(jvp:vjp:)` for derivative registration is deprecated in favor of `@derivative` attribute, and will be removed soon. - Duplicate `@differentiable` attributes with the same differentiability parameters are invalid. - For protocol requirements and class members with `@differentiable` attribute, conforming types and subclasses must have the same `@differentiable` attribute (or one with a superset of differentiability parameter indices) on implementing/overriding declarations.
1 parent 456ecaf commit 44d937d

14 files changed

+2253
-13
lines changed

include/swift/AST/ASTContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ namespace swift {
6464
class DeclContext;
6565
class DefaultArgumentInitializer;
6666
class DerivativeAttr;
67+
class DifferentiableAttr;
6768
class ExtensionDecl;
6869
class ForeignRepresentationInfo;
6970
class FuncDecl;
@@ -287,6 +288,11 @@ class ASTContext final {
287288
/// Cached mapping from types to their associated tangent spaces.
288289
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;
289290

291+
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
292+
/// diagnose duplicate `@differentiable` attributes for the same key.
293+
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
294+
DifferentiableAttrs;
295+
290296
/// Cache of `@derivative` attributes keyed by parameter indices and
291297
/// derivative function kind. Used to diagnose duplicate `@derivative`
292298
/// attributes for the same key.

include/swift/AST/Attr.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,7 @@ class DifferentiableAttr final
16381638
private llvm::TrailingObjects<DifferentiableAttr,
16391639
ParsedAutoDiffParameter> {
16401640
friend TrailingObjects;
1641+
friend class DifferentiableAttributeTypeCheckRequest;
16411642

16421643
/// The declaration on which the `@differentiable` attribute is declared.
16431644
/// May not be a valid declaration for `@differentiable` attributes.
@@ -1658,7 +1659,12 @@ class DifferentiableAttr final
16581659
/// specified.
16591660
FuncDecl *VJPFunction = nullptr;
16601661
/// The differentiability parameter indices, resolved by the type checker.
1661-
IndexSubset *ParameterIndices = nullptr;
1662+
/// The bit stores whether the parameter indices have been computed.
1663+
///
1664+
/// Note: it is necessary to use a bit instead of `nullptr` parameter indices
1665+
/// to represent "parameter indices not yet type-checked" because invalid
1666+
/// attributes have `nullptr` parameter indices but have been type-checked.
1667+
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
16621668
/// The trailing where clause (optional).
16631669
TrailingWhereClause *WhereClause = nullptr;
16641670
/// The generic signature for autodiff associated functions. Resolved by the
@@ -1714,12 +1720,14 @@ class DifferentiableAttr final
17141720
/// registered VJP.
17151721
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }
17161722

1717-
IndexSubset *getParameterIndices() const {
1718-
return ParameterIndices;
1719-
}
1720-
void setParameterIndices(IndexSubset *parameterIndices) {
1721-
ParameterIndices = parameterIndices;
1722-
}
1723+
private:
1724+
/// Returns true if the given `@differentiable` attribute has been
1725+
/// type-checked.
1726+
bool hasBeenTypeChecked() const;
1727+
1728+
public:
1729+
IndexSubset *getParameterIndices() const;
1730+
void setParameterIndices(IndexSubset *parameterIndices);
17231731

17241732
/// The parsed differentiability parameters, i.e. the list of parameters
17251733
/// specified in 'wrt:'.

include/swift/AST/Decl.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5724,6 +5724,20 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
57245724
private:
57255725
ParameterList *Params;
57265726

5727+
/// The generation at which we last loaded derivative function configurations.
5728+
unsigned DerivativeFunctionConfigGeneration = 0;
5729+
/// Prepare to traverse the list of derivative function configurations.
5730+
void prepareDerivativeFunctionConfigurations();
5731+
5732+
/// A uniqued list of derivative function configurations.
5733+
/// - `@differentiable` and `@derivative` attribute type-checking is
5734+
/// responsible for populating derivative function configurations specified
5735+
/// in the current module.
5736+
/// - Module loading is responsible for populating derivative function
5737+
/// configurations from imported modules.
5738+
struct DerivativeFunctionConfigurationList;
5739+
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
5740+
57275741
protected:
57285742
// If a function has a body at all, we have either a parsed body AST node or
57295743
// we have saved the end location of the unparsed body.
@@ -6048,6 +6062,12 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
60486062
/// constructor.
60496063
bool hasDynamicSelfResult() const;
60506064

6065+
/// Get all derivative function configurations.
6066+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
6067+
6068+
/// Add the given derivative function configuration.
6069+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
6070+
60516071
using DeclContext::operator new;
60526072
using Decl::getASTContext;
60536073
};

include/swift/AST/DiagnosticsSema.def

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,6 +2904,62 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
29042904
"containing type %0 does not conform to protocol %1",
29052905
(DeclName, DeclName))
29062906

2907+
// @differentiable
2908+
ERROR(differentiable_attr_void_result,none,
2909+
"cannot differentiate void function %0", (DeclName))
2910+
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
2911+
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
2912+
"attribute for transpose registration instead", ())
2913+
ERROR(differentiable_attr_overload_not_found,none,
2914+
"%0 does not have expected type %1", (DeclNameRef, Type))
2915+
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
2916+
// mention "same generic requirements".
2917+
ERROR(differentiable_attr_duplicate,none,
2918+
"duplicate '@differentiable' attribute with same parameters", ())
2919+
NOTE(differentiable_attr_duplicate_note,none,
2920+
"other attribute declared here", ())
2921+
ERROR(differentiable_attr_function_not_same_type_context,none,
2922+
"%0 is not defined in the current type context", (DeclNameRef))
2923+
ERROR(differentiable_attr_derivative_not_function,none,
2924+
"registered derivative %0 must be a 'func' declaration", (DeclNameRef))
2925+
ERROR(differentiable_attr_class_derivative_not_final,none,
2926+
"class member derivative must be final", ())
2927+
ERROR(differentiable_attr_invalid_access,none,
2928+
"derivative function %0 is required to either be public or "
2929+
"'@usableFromInline' because the original function %1 is public or "
2930+
"'@usableFromInline'", (DeclNameRef, DeclName))
2931+
ERROR(differentiable_attr_result_not_differentiable,none,
2932+
"can only differentiate functions with results that conform to "
2933+
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
2934+
ERROR(differentiable_attr_protocol_req_where_clause,none,
2935+
"'@differentiable' attribute on protocol requirement cannot specify "
2936+
"'where' clause", ())
2937+
ERROR(differentiable_attr_protocol_req_assoc_func,none,
2938+
"'@differentiable' attribute on protocol requirement cannot specify "
2939+
"'jvp:' or 'vjp:'", ())
2940+
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
2941+
"'@differentiable' attribute on stored property cannot specify "
2942+
"'jvp:' or 'vjp:'", ())
2943+
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
2944+
"'@differentiable' attribute cannot be declared on class methods "
2945+
"returning 'Self'", ())
2946+
// TODO(TF-654): Remove when differentiation supports class initializers.
2947+
ERROR(differentiable_attr_class_init_not_yet_supported,none,
2948+
"'@differentiable' attribute does not yet support class initializers",
2949+
())
2950+
ERROR(differentiable_attr_empty_where_clause,none,
2951+
"empty 'where' clause in '@differentiable' attribute", ())
2952+
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
2953+
"'where' clause is valid only when original function is generic %0",
2954+
(DeclName))
2955+
ERROR(differentiable_attr_layout_req_unsupported,none,
2956+
"'@differentiable' attribute does not yet support layout requirements",
2957+
())
2958+
ERROR(overriding_decl_missing_differentiable_attr,none,
2959+
"overriding declaration is missing attribute '%0'", (StringRef))
2960+
NOTE(protocol_witness_missing_differentiable_attr,none,
2961+
"candidate is missing attribute '%0'", (StringRef))
2962+
29072963
// @derivative
29082964
ERROR(derivative_attr_expected_result_tuple,none,
29092965
"'@derivative(of:)' attribute requires function to return a two-element "

include/swift/AST/TypeCheckRequests.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,9 @@ struct WhereClauseOwner {
375375

376376
/// The source of the where clause, which can be a generic parameter list
377377
/// or a declaration that can have a where clause.
378-
llvm::PointerUnion<GenericParamList *, TrailingWhereClause *, SpecializeAttr *> source;
378+
llvm::PointerUnion<GenericParamList *, TrailingWhereClause *,
379+
SpecializeAttr *, DifferentiableAttr *>
380+
source;
379381

380382
WhereClauseOwner(GenericContext *genCtx);
381383
WhereClauseOwner(AssociatedTypeDecl *atd);
@@ -386,6 +388,9 @@ struct WhereClauseOwner {
386388
WhereClauseOwner(DeclContext *dc, SpecializeAttr *attr)
387389
: dc(dc), source(attr) {}
388390

391+
WhereClauseOwner(DeclContext *dc, DifferentiableAttr *attr)
392+
: dc(dc), source(attr) {}
393+
389394
SourceLoc getLoc() const;
390395

391396
friend hash_code hash_value(const WhereClauseOwner &owner) {
@@ -2022,6 +2027,35 @@ class PatternTypeRequest
20222027
}
20232028
};
20242029

2030+
/// Type-checks a `@differentiable` attribute and returns the resolved parameter
2031+
/// indices on success. On failure, emits diagnostics and returns `nullptr`.
2032+
///
2033+
/// Currently, this request resolves other `@differentiable` attribute
2034+
/// components but mutates them in place:
2035+
/// - `JVPFunction`
2036+
/// - `VJPFunction`
2037+
/// - `DerivativeGenericSignature`
2038+
class DifferentiableAttributeTypeCheckRequest
2039+
: public SimpleRequest<DifferentiableAttributeTypeCheckRequest,
2040+
IndexSubset *(DifferentiableAttr *),
2041+
CacheKind::SeparatelyCached> {
2042+
public:
2043+
using SimpleRequest::SimpleRequest;
2044+
2045+
private:
2046+
friend SimpleRequest;
2047+
2048+
// Evaluation.
2049+
llvm::Expected<IndexSubset *> evaluate(Evaluator &evaluator,
2050+
DifferentiableAttr *attr) const;
2051+
2052+
public:
2053+
// Separate caching.
2054+
bool isCached() const { return true; }
2055+
Optional<IndexSubset *> getCachedResult() const;
2056+
void cacheResult(IndexSubset *value) const;
2057+
};
2058+
20252059
// Allow AnyValue to compare two Type values, even though Type doesn't
20262060
// support ==.
20272061
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,
4343
SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
4444
Type(KnownProtocolKind, const DeclContext *), SeparatelyCached,
4545
NoLocationInfo)
46+
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeTypeCheckRequest,
47+
IndexSubset *(DifferentiableAttr *),
48+
SeparatelyCached, NoLocationInfo)
4649
SWIFT_REQUEST(TypeChecker, DynamicallyReplacedDeclRequest,
4750
ValueDecl *(ValueDecl *),
4851
Cached, NoLocationInfo)

lib/AST/Attr.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "swift/AST/IndexSubset.h"
2424
#include "swift/AST/Module.h"
2525
#include "swift/AST/ParameterList.h"
26+
#include "swift/AST/TypeCheckRequests.h"
2627
#include "swift/AST/TypeRepr.h"
2728
#include "swift/AST/Types.h"
2829
#include "swift/Basic/Defer.h"
@@ -1507,6 +1508,30 @@ void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
15071508
OriginalDeclaration = originalDeclaration;
15081509
}
15091510

1511+
bool DifferentiableAttr::hasBeenTypeChecked() const {
1512+
return ParameterIndicesAndBit.getInt();
1513+
}
1514+
1515+
IndexSubset *DifferentiableAttr::getParameterIndices() const {
1516+
assert(getOriginalDeclaration() &&
1517+
"Original declaration must have been resolved");
1518+
auto &ctx = getOriginalDeclaration()->getASTContext();
1519+
return evaluateOrDefault(ctx.evaluator,
1520+
DifferentiableAttributeTypeCheckRequest{
1521+
const_cast<DifferentiableAttr *>(this)},
1522+
nullptr);
1523+
}
1524+
1525+
void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
1526+
assert(getOriginalDeclaration() &&
1527+
"Original declaration must have been resolved");
1528+
auto &ctx = getOriginalDeclaration()->getASTContext();
1529+
ctx.evaluator.cacheOutput(
1530+
DifferentiableAttributeTypeCheckRequest{
1531+
const_cast<DifferentiableAttr *>(this)},
1532+
std::move(paramIndices));
1533+
}
1534+
15101535
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
15111536
JVPFunction = decl;
15121537
if (decl && !JVP)

lib/AST/Decl.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7007,6 +7007,45 @@ StringRef AbstractFunctionDecl::getInlinableBodyText(
70077007
return extractInlinableText(getASTContext().SourceMgr, body, scratch);
70087008
}
70097009

7010+
/// A uniqued list of derivative function configurations.
7011+
struct AbstractFunctionDecl::DerivativeFunctionConfigurationList
7012+
: public llvm::SetVector<AutoDiffConfig> {
7013+
// Necessary for `ASTContext` allocation.
7014+
void *operator new(
7015+
size_t bytes, ASTContext &ctx,
7016+
unsigned alignment = alignof(DerivativeFunctionConfigurationList)) {
7017+
return ctx.Allocate(bytes, alignment);
7018+
}
7019+
};
7020+
7021+
void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
7022+
if (DerivativeFunctionConfigs)
7023+
return;
7024+
auto &ctx = getASTContext();
7025+
DerivativeFunctionConfigs = new (ctx) DerivativeFunctionConfigurationList();
7026+
// Register an `ASTContext` cleanup calling the list destructor.
7027+
ctx.addCleanup([this]() {
7028+
this->DerivativeFunctionConfigs->~DerivativeFunctionConfigurationList();
7029+
});
7030+
}
7031+
7032+
ArrayRef<AutoDiffConfig>
7033+
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
7034+
prepareDerivativeFunctionConfigurations();
7035+
auto &ctx = getASTContext();
7036+
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
7037+
// TODO(TF-1100): Upstream derivative function configuration serialization
7038+
// logic.
7039+
}
7040+
return DerivativeFunctionConfigs->getArrayRef();
7041+
}
7042+
7043+
void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
7044+
AutoDiffConfig config) {
7045+
prepareDerivativeFunctionConfigurations();
7046+
DerivativeFunctionConfigs->insert(config);
7047+
}
7048+
70107049
FuncDecl *FuncDecl::createImpl(ASTContext &Context,
70117050
SourceLoc StaticLoc,
70127051
StaticSpellingKind StaticSpelling,

lib/AST/TypeCheckRequests.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,9 @@ MutableArrayRef<RequirementRepr> WhereClauseOwner::getRequirements() const {
391391
} else if (const auto attr = source.dyn_cast<SpecializeAttr *>()) {
392392
if (auto whereClause = attr->getTrailingWhereClause())
393393
return whereClause->getRequirements();
394+
} else if (const auto attr = source.dyn_cast<DifferentiableAttr *>()) {
395+
if (auto whereClause = attr->getWhereClause())
396+
return whereClause->getRequirements();
394397
} else if (const auto whereClause = source.get<TrailingWhereClause *>()) {
395398
return whereClause->getRequirements();
396399
}
@@ -1241,6 +1244,24 @@ void CallerSideDefaultArgExprRequest::cacheResult(Expr *expr) const {
12411244
defaultExpr->ContextOrCallerSideExpr = expr;
12421245
}
12431246

1247+
//----------------------------------------------------------------------------//
1248+
// DifferentiableAttributeTypeCheckRequest computation.
1249+
//----------------------------------------------------------------------------//
1250+
1251+
Optional<IndexSubset *>
1252+
DifferentiableAttributeTypeCheckRequest::getCachedResult() const {
1253+
auto *attr = std::get<0>(getStorage());
1254+
if (attr->hasBeenTypeChecked())
1255+
return attr->ParameterIndicesAndBit.getPointer();
1256+
return None;
1257+
}
1258+
1259+
void DifferentiableAttributeTypeCheckRequest::cacheResult(
1260+
IndexSubset *parameterIndices) const {
1261+
auto *attr = std::get<0>(getStorage());
1262+
attr->ParameterIndicesAndBit.setPointerAndInt(parameterIndices, true);
1263+
}
1264+
12441265
//----------------------------------------------------------------------------//
12451266
// TypeCheckSourceFileRequest computation.
12461267
//----------------------------------------------------------------------------//

0 commit comments

Comments
 (0)