Skip to content

Commit 8a17205

Browse files
authored
[AutoDiff] Requestify @differentiable attribute parameter indices. (#28017)
Requestify `@differentiable` attribute parameter indices resolution: `DifferentiableAttributeParameterIndicesRequest`. This is necessary for type-checking `@differentiable` attributes in non-primary files. The previous workaround (`TypeChecker::checkDeclDifferentiableAttributes`) no longer works because `TypeChecker::validateDecl` has been replaced with `InterfaceTypeRequest::evaluate`. Currently, all `@differentiable` attribute type-checking (`AttributeChecker::visitDifferentiableAttr`) has been moved into `DifferentiableAttributeParameterIndicesRequest::evaluate`. In the future, consider splitting `DifferentiableAttributeParameterIndicesRequest` into multiple requests for the following functionality: - `DifferentiableAttr::getJVPFunction` - `DifferentiableAttr::getVJPFunction` - `DifferentiableAttr::getDerivativeGenericSignature`
1 parent 2f1306e commit 8a17205

12 files changed

+208
-133
lines changed

include/swift/AST/Attr.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,7 @@ class DifferentiableAttr final
15401540
private llvm::TrailingObjects<DifferentiableAttr,
15411541
ParsedAutoDiffParameter> {
15421542
friend TrailingObjects;
1543+
friend class DifferentiableAttributeParameterIndicesRequest;
15431544

15441545
/// The declaration on which the `@differentiable` attribute is declared.
15451546
Decl *OriginalDeclaration = nullptr;
@@ -1558,7 +1559,8 @@ class DifferentiableAttr final
15581559
/// specified.
15591560
FuncDecl *VJPFunction = nullptr;
15601561
/// The differentiation parameters' indices, resolved by the type checker.
1561-
IndexSubset *ParameterIndices = nullptr;
1562+
/// The bit stores whether the parameter indices have been computed.
1563+
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
15621564
/// The trailing where clause (optional).
15631565
TrailingWhereClause *WhereClause = nullptr;
15641566
/// The generic signature for autodiff derivative functions. Resolved by the
@@ -1575,9 +1577,9 @@ class DifferentiableAttr final
15751577
Optional<DeclNameWithLoc> vjp,
15761578
TrailingWhereClause *clause);
15771579

1578-
explicit DifferentiableAttr(Decl *original, bool implicit,
1579-
SourceLoc atLoc, SourceRange baseRange,
1580-
bool linear, IndexSubset *indices,
1580+
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
1581+
SourceRange baseRange, bool linear,
1582+
IndexSubset *indices,
15811583
Optional<DeclNameWithLoc> jvp,
15821584
Optional<DeclNameWithLoc> vjp,
15831585
GenericSignature derivativeGenericSignature);
@@ -1611,12 +1613,9 @@ class DifferentiableAttr final
16111613
/// registered VJP.
16121614
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
16131615

1614-
IndexSubset *getParameterIndices() const {
1615-
return ParameterIndices;
1616-
}
1617-
void setParameterIndices(IndexSubset *pi) {
1618-
ParameterIndices = pi;
1619-
}
1616+
bool hasComputedParameterIndices() const;
1617+
IndexSubset *getParameterIndices() const;
1618+
void setParameterIndices(IndexSubset *paramIndices);
16201619

16211620
/// The parsed differentiation parameters, i.e. the list of parameters
16221621
/// specified in 'wrt:'.
@@ -1647,8 +1646,7 @@ class DifferentiableAttr final
16471646
void setVJPFunction(FuncDecl *decl);
16481647

16491648
bool parametersMatch(const DifferentiableAttr &other) const {
1650-
assert(ParameterIndices && other.ParameterIndices);
1651-
return ParameterIndices == other.ParameterIndices;
1649+
return getParameterIndices() == other.getParameterIndices();
16521650
}
16531651

16541652
/// Get the derivative generic environment for the given `@differentiable`

include/swift/AST/TypeCheckRequests.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,29 @@ class SynthesizeDefaultInitRequest
16341634
bool isCached() const { return true; }
16351635
};
16361636

1637+
// SWIFT_ENABLE_TENSORFLOW
1638+
class DifferentiableAttributeParameterIndicesRequest :
1639+
public SimpleRequest<DifferentiableAttributeParameterIndicesRequest,
1640+
IndexSubset *(DifferentiableAttr *, Decl *),
1641+
CacheKind::SeparatelyCached> {
1642+
public:
1643+
using SimpleRequest::SimpleRequest;
1644+
1645+
private:
1646+
friend SimpleRequest;
1647+
1648+
// Evaluation.
1649+
llvm::Expected<IndexSubset *>
1650+
evaluate(Evaluator &evaluator, DifferentiableAttr *attr, Decl *decl) const;
1651+
1652+
public:
1653+
// Separate caching.
1654+
bool isCached() const { return true; }
1655+
Optional<IndexSubset *> getCachedResult() const;
1656+
void cacheResult(IndexSubset *value) const;
1657+
};
1658+
// SWIFT_ENABLE_TENSORFLOW END
1659+
16371660
// Allow AnyValue to compare two Type values, even though Type doesn't
16381661
// support ==.
16391662
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ SWIFT_REQUEST(TypeChecker, ClassAncestryFlagsRequest,
3131
AncestryFlags(ClassDecl *), Cached, NoLocationInfo)
3232
SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,
3333
Type(AssociatedTypeDecl *), Cached, NoLocationInfo)
34+
// SWIFT_ENABLE_TENSORFLOW
35+
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeParameterIndicesRequest,
36+
IndexSubset *(DifferentiableAttr *, Decl *),
37+
SeparatelyCached, NoLocationInfo)
38+
// SWIFT_ENABLE_TENSORFLOW END
3439
SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
3540
Type(KnownProtocolKind, const DeclContext *), SeparatelyCached,
3641
NoLocationInfo)

lib/AST/Attr.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "swift/AST/GenericSignatureBuilder.h"
2525
#include "swift/AST/Module.h"
2626
#include "swift/AST/TypeRepr.h"
27+
// SWIFT_ENABLE_TENSORFLOW
28+
#include "swift/AST/TypeCheckRequests.h"
2729
#include "swift/AST/Types.h"
2830
// SWIFT_ENABLE_TENSORFLOW
2931
#include "swift/AST/ParameterList.h"
@@ -1461,9 +1463,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
14611463
Optional<DeclNameWithLoc> vjp,
14621464
GenericSignature derivativeGenSig)
14631465
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1464-
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
1465-
ParameterIndices(indices) {
1466+
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
14661467
setOriginalDeclaration(original);
1468+
setParameterIndices(indices);
14671469
setDerivativeGenericSignature(derivativeGenSig);
14681470
}
14691471

@@ -1503,6 +1505,31 @@ void DifferentiableAttr::setOriginalDeclaration(Decl *decl) {
15031505
OriginalDeclaration = decl;
15041506
}
15051507

1508+
bool DifferentiableAttr::hasComputedParameterIndices() const {
1509+
return ParameterIndicesAndBit.getInt();
1510+
}
1511+
1512+
IndexSubset *DifferentiableAttr::getParameterIndices() const {
1513+
assert(getOriginalDeclaration() &&
1514+
"Original declaration must have been resolved");
1515+
auto &ctx = getOriginalDeclaration()->getASTContext();
1516+
return evaluateOrDefault(
1517+
ctx.evaluator,
1518+
DifferentiableAttributeParameterIndicesRequest{
1519+
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1520+
nullptr);
1521+
}
1522+
1523+
void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
1524+
assert(getOriginalDeclaration() &&
1525+
"Original declaration must have been resolved");
1526+
auto &ctx = getOriginalDeclaration()->getASTContext();
1527+
ctx.evaluator.cacheOutput(
1528+
DifferentiableAttributeParameterIndicesRequest{
1529+
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1530+
std::move(paramIndices));
1531+
}
1532+
15061533
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
15071534
JVPFunction = decl;
15081535
if (decl && !JVP)

lib/AST/TypeCheckRequests.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,3 +1072,21 @@ void swift::simple_display(llvm::raw_ostream &out,
10721072
out << "precedence group " << desc.ident << " at ";
10731073
desc.nameLoc.print(out, desc.dc->getASTContext().SourceMgr);
10741074
}
1075+
1076+
//----------------------------------------------------------------------------//
1077+
// DifferentiableAttributeParameterIndicesRequest computation.
1078+
//----------------------------------------------------------------------------//
1079+
1080+
Optional<IndexSubset *>
1081+
DifferentiableAttributeParameterIndicesRequest::getCachedResult() const {
1082+
auto *attr = std::get<0>(getStorage());
1083+
if (attr->hasComputedParameterIndices())
1084+
return attr->ParameterIndicesAndBit.getPointer();
1085+
return None;
1086+
}
1087+
1088+
void DifferentiableAttributeParameterIndicesRequest::cacheResult(
1089+
IndexSubset *parameterIndices) const {
1090+
auto *attr = std::get<0>(getStorage());
1091+
attr->ParameterIndicesAndBit.setPointerAndInt(parameterIndices, true);
1092+
}

lib/Parse/ParseDecl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,6 +3390,7 @@ static void setOriginalFunctionInDifferentiableAttributes(
33903390
for (auto *attr : Attributes.getAttributes<DifferentiableAttr>())
33913391
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
33923392
}
3393+
// SWIFT_ENABLE_TENSORFLOW END
33933394

33943395
/// Parse a single syntactic declaration and return a list of decl
33953396
/// ASTs. This can return multiple results for var decls that bind to multiple
@@ -3836,6 +3837,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
38363837
// SWIFT_ENABLE_TENSORFLOW END
38373838
}
38383839
// SWIFT_ENABLE_TENSORFLOW
3840+
// Set original declaration in `@differentiable` attributes.
38393841
setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D);
38403842
// SWIFT_ENABLE_TENSORFLOW END
38413843
}
@@ -5592,6 +5594,7 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,
55925594
accessors.record(*this, PrimaryVar, Invalid);
55935595

55945596
// SWIFT_ENABLE_TENSORFLOW
5597+
// Set original declaration in `@differentiable` attributes.
55955598
for (auto *accessor : accessors.Accessors)
55965599
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
55975600
accessor);
@@ -5852,6 +5855,7 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
58525855
VD->setStatic(StaticLoc.isValid());
58535856
VD->getAttrs() = Attributes;
58545857
// SWIFT_ENABLE_TENSORFLOW
5858+
// Set original declaration in `@differentiable` attributes.
58555859
setOriginalFunctionInDifferentiableAttributes(Attributes, VD);
58565860
// SWIFT_ENABLE_TENSORFLOW END
58575861
setLocalDiscriminator(VD);
@@ -7109,6 +7113,7 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
71097113
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));
71107114

71117115
// SWIFT_ENABLE_TENSORFLOW
7116+
// Set original declaration in `@differentiable` attributes.
71127117
for (auto *accessor : accessors.Accessors)
71137118
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
71147119
accessor);

lib/SIL/SILFunctionBuilder.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
8383
!constant.autoDiffDerivativeFunctionIdentifier &&
8484
!constant.isStoredPropertyInitializer() &&
8585
!constant.isThunk()) {
86+
// NOTE: Validate `@differentiable` attributes on `AccessorDecl`s by calling
87+
// `getParameterIndices`. This is significant to prevent duplicate SIL
88+
// `[differentiable]` attribute generation: `getParameterIndices` deletes
89+
// `@differentiable` attributes whose original declaration is an
90+
// `AbstractStorageDecl`.
91+
if (isa<AccessorDecl>(decl))
92+
for (auto *A : Attrs.getAttributes<DifferentiableAttr>())
93+
(void)A->getParameterIndices();
8694
for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) {
8795
// Get lowered argument indices.
8896
auto *paramIndices = A->getParameterIndices();

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -613,13 +613,12 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
613613
// call to the getter.
614614
if (member->getEffectiveAccess() > AccessLevel::Internal &&
615615
!member->getAttrs().hasAttribute<DifferentiableAttr>()) {
616-
(void)member->getAccessor(AccessorKind::Get)->getInterfaceType();
616+
auto *getter = member->getSynthesizedAccessor(AccessorKind::Get);
617+
(void)getter->getInterfaceType();
617618
// If member or its getter already has a `@differentiable` attribute,
618619
// continue.
619620
if (member->getAttrs().hasAttribute<DifferentiableAttr>() ||
620-
member->getAccessor(AccessorKind::Get)
621-
->getAttrs()
622-
.hasAttribute<DifferentiableAttr>())
621+
getter->getAttrs().hasAttribute<DifferentiableAttr>())
623622
continue;
624623
GenericSignature derivativeGenSig = GenericSignature();
625624
// If the parent declaration context is an extension, the nominal type may
@@ -628,9 +627,8 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
628627
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
629628
derivativeGenSig = extDecl->getGenericSignature();
630629
auto *diffableAttr = DifferentiableAttr::create(
631-
member->getAccessor(AccessorKind::Get), /*implicit*/ true,
632-
SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None,
633-
derivativeGenSig);
630+
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
631+
/*linear*/ false, {}, None, None, derivativeGenSig);
634632
member->getAttrs().add(diffableAttr);
635633
// Set getter `@differentiable` attribute parameter indices.
636634
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));

0 commit comments

Comments
 (0)