Skip to content

Commit 7dd1ea2

Browse files
committed
Merge branch 'tensorflow-merge' of github.com:apple/swift into tensorflow-merge
2 parents 738f232 + 8a17205 commit 7dd1ea2

12 files changed

+208
-133
lines changed

include/swift/AST/Attr.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,7 @@ class DifferentiableAttr final
15401540
private llvm::TrailingObjects<DifferentiableAttr,
15411541
ParsedAutoDiffParameter> {
15421542
friend TrailingObjects;
1543+
friend class DifferentiableAttributeParameterIndicesRequest;
15431544

15441545
/// The declaration on which the `@differentiable` attribute is declared.
15451546
Decl *OriginalDeclaration = nullptr;
@@ -1558,7 +1559,8 @@ class DifferentiableAttr final
15581559
/// specified.
15591560
FuncDecl *VJPFunction = nullptr;
15601561
/// The differentiation parameters' indices, resolved by the type checker.
1561-
IndexSubset *ParameterIndices = nullptr;
1562+
/// The bit stores whether the parameter indices have been computed.
1563+
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
15621564
/// The trailing where clause (optional).
15631565
TrailingWhereClause *WhereClause = nullptr;
15641566
/// The generic signature for autodiff derivative functions. Resolved by the
@@ -1575,9 +1577,9 @@ class DifferentiableAttr final
15751577
Optional<DeclNameWithLoc> vjp,
15761578
TrailingWhereClause *clause);
15771579

1578-
explicit DifferentiableAttr(Decl *original, bool implicit,
1579-
SourceLoc atLoc, SourceRange baseRange,
1580-
bool linear, IndexSubset *indices,
1580+
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
1581+
SourceRange baseRange, bool linear,
1582+
IndexSubset *indices,
15811583
Optional<DeclNameWithLoc> jvp,
15821584
Optional<DeclNameWithLoc> vjp,
15831585
GenericSignature derivativeGenericSignature);
@@ -1611,12 +1613,9 @@ class DifferentiableAttr final
16111613
/// registered VJP.
16121614
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
16131615

1614-
IndexSubset *getParameterIndices() const {
1615-
return ParameterIndices;
1616-
}
1617-
void setParameterIndices(IndexSubset *pi) {
1618-
ParameterIndices = pi;
1619-
}
1616+
bool hasComputedParameterIndices() const;
1617+
IndexSubset *getParameterIndices() const;
1618+
void setParameterIndices(IndexSubset *paramIndices);
16201619

16211620
/// The parsed differentiation parameters, i.e. the list of parameters
16221621
/// specified in 'wrt:'.
@@ -1647,8 +1646,7 @@ class DifferentiableAttr final
16471646
void setVJPFunction(FuncDecl *decl);
16481647

16491648
bool parametersMatch(const DifferentiableAttr &other) const {
1650-
assert(ParameterIndices && other.ParameterIndices);
1651-
return ParameterIndices == other.ParameterIndices;
1649+
return getParameterIndices() == other.getParameterIndices();
16521650
}
16531651

16541652
/// Get the derivative generic environment for the given `@differentiable`

include/swift/AST/TypeCheckRequests.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,29 @@ class CompareDeclSpecializationRequest
16541654
bool isCached() const { return true; }
16551655
};
16561656

1657+
// SWIFT_ENABLE_TENSORFLOW
1658+
class DifferentiableAttributeParameterIndicesRequest :
1659+
public SimpleRequest<DifferentiableAttributeParameterIndicesRequest,
1660+
IndexSubset *(DifferentiableAttr *, Decl *),
1661+
CacheKind::SeparatelyCached> {
1662+
public:
1663+
using SimpleRequest::SimpleRequest;
1664+
1665+
private:
1666+
friend SimpleRequest;
1667+
1668+
// Evaluation.
1669+
llvm::Expected<IndexSubset *>
1670+
evaluate(Evaluator &evaluator, DifferentiableAttr *attr, Decl *decl) const;
1671+
1672+
public:
1673+
// Separate caching.
1674+
bool isCached() const { return true; }
1675+
Optional<IndexSubset *> getCachedResult() const;
1676+
void cacheResult(IndexSubset *value) const;
1677+
};
1678+
// SWIFT_ENABLE_TENSORFLOW END
1679+
16571680
// Allow AnyValue to compare two Type values, even though Type doesn't
16581681
// support ==.
16591682
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ SWIFT_REQUEST(TypeChecker, CompareDeclSpecializationRequest,
3434
NoLocationInfo)
3535
SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,
3636
Type(AssociatedTypeDecl *), Cached, NoLocationInfo)
37+
// SWIFT_ENABLE_TENSORFLOW
38+
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeParameterIndicesRequest,
39+
IndexSubset *(DifferentiableAttr *, Decl *),
40+
SeparatelyCached, NoLocationInfo)
41+
// SWIFT_ENABLE_TENSORFLOW END
3742
SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
3843
Type(KnownProtocolKind, const DeclContext *), SeparatelyCached,
3944
NoLocationInfo)

lib/AST/Attr.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "swift/AST/GenericSignatureBuilder.h"
2525
#include "swift/AST/Module.h"
2626
#include "swift/AST/TypeRepr.h"
27+
// SWIFT_ENABLE_TENSORFLOW
28+
#include "swift/AST/TypeCheckRequests.h"
2729
#include "swift/AST/Types.h"
2830
// SWIFT_ENABLE_TENSORFLOW
2931
#include "swift/AST/ParameterList.h"
@@ -1461,9 +1463,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
14611463
Optional<DeclNameWithLoc> vjp,
14621464
GenericSignature derivativeGenSig)
14631465
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1464-
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
1465-
ParameterIndices(indices) {
1466+
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
14661467
setOriginalDeclaration(original);
1468+
setParameterIndices(indices);
14671469
setDerivativeGenericSignature(derivativeGenSig);
14681470
}
14691471

@@ -1503,6 +1505,31 @@ void DifferentiableAttr::setOriginalDeclaration(Decl *decl) {
15031505
OriginalDeclaration = decl;
15041506
}
15051507

1508+
bool DifferentiableAttr::hasComputedParameterIndices() const {
1509+
return ParameterIndicesAndBit.getInt();
1510+
}
1511+
1512+
IndexSubset *DifferentiableAttr::getParameterIndices() const {
1513+
assert(getOriginalDeclaration() &&
1514+
"Original declaration must have been resolved");
1515+
auto &ctx = getOriginalDeclaration()->getASTContext();
1516+
return evaluateOrDefault(
1517+
ctx.evaluator,
1518+
DifferentiableAttributeParameterIndicesRequest{
1519+
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1520+
nullptr);
1521+
}
1522+
1523+
void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
1524+
assert(getOriginalDeclaration() &&
1525+
"Original declaration must have been resolved");
1526+
auto &ctx = getOriginalDeclaration()->getASTContext();
1527+
ctx.evaluator.cacheOutput(
1528+
DifferentiableAttributeParameterIndicesRequest{
1529+
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1530+
std::move(paramIndices));
1531+
}
1532+
15061533
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
15071534
JVPFunction = decl;
15081535
if (decl && !JVP)

lib/AST/TypeCheckRequests.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,3 +1069,21 @@ void swift::simple_display(llvm::raw_ostream &out,
10691069
out << "precedence group " << desc.ident << " at ";
10701070
desc.nameLoc.print(out, desc.dc->getASTContext().SourceMgr);
10711071
}
1072+
1073+
//----------------------------------------------------------------------------//
1074+
// DifferentiableAttributeParameterIndicesRequest computation.
1075+
//----------------------------------------------------------------------------//
1076+
1077+
Optional<IndexSubset *>
1078+
DifferentiableAttributeParameterIndicesRequest::getCachedResult() const {
1079+
auto *attr = std::get<0>(getStorage());
1080+
if (attr->hasComputedParameterIndices())
1081+
return attr->ParameterIndicesAndBit.getPointer();
1082+
return None;
1083+
}
1084+
1085+
void DifferentiableAttributeParameterIndicesRequest::cacheResult(
1086+
IndexSubset *parameterIndices) const {
1087+
auto *attr = std::get<0>(getStorage());
1088+
attr->ParameterIndicesAndBit.setPointerAndInt(parameterIndices, true);
1089+
}

lib/Parse/ParseDecl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,6 +3390,7 @@ static void setOriginalFunctionInDifferentiableAttributes(
33903390
for (auto *attr : Attributes.getAttributes<DifferentiableAttr>())
33913391
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
33923392
}
3393+
// SWIFT_ENABLE_TENSORFLOW END
33933394

33943395
/// Parse a single syntactic declaration and return a list of decl
33953396
/// ASTs. This can return multiple results for var decls that bind to multiple
@@ -3836,6 +3837,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
38363837
// SWIFT_ENABLE_TENSORFLOW END
38373838
}
38383839
// SWIFT_ENABLE_TENSORFLOW
3840+
// Set original declaration in `@differentiable` attributes.
38393841
setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D);
38403842
// SWIFT_ENABLE_TENSORFLOW END
38413843
}
@@ -5592,6 +5594,7 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,
55925594
accessors.record(*this, PrimaryVar, Invalid);
55935595

55945596
// SWIFT_ENABLE_TENSORFLOW
5597+
// Set original declaration in `@differentiable` attributes.
55955598
for (auto *accessor : accessors.Accessors)
55965599
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
55975600
accessor);
@@ -5852,6 +5855,7 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
58525855
VD->setStatic(StaticLoc.isValid());
58535856
VD->getAttrs() = Attributes;
58545857
// SWIFT_ENABLE_TENSORFLOW
5858+
// Set original declaration in `@differentiable` attributes.
58555859
setOriginalFunctionInDifferentiableAttributes(Attributes, VD);
58565860
// SWIFT_ENABLE_TENSORFLOW END
58575861
setLocalDiscriminator(VD);
@@ -7109,6 +7113,7 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
71097113
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));
71107114

71117115
// SWIFT_ENABLE_TENSORFLOW
7116+
// Set original declaration in `@differentiable` attributes.
71127117
for (auto *accessor : accessors.Accessors)
71137118
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
71147119
accessor);

lib/SIL/SILFunctionBuilder.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
8383
!constant.autoDiffDerivativeFunctionIdentifier &&
8484
!constant.isStoredPropertyInitializer() &&
8585
!constant.isThunk()) {
86+
// NOTE: Validate `@differentiable` attributes on `AccessorDecl`s by calling
87+
// `getParameterIndices`. This is significant to prevent duplicate SIL
88+
// `[differentiable]` attribute generation: `getParameterIndices` deletes
89+
// `@differentiable` attributes whose original declaration is an
90+
// `AbstractStorageDecl`.
91+
if (isa<AccessorDecl>(decl))
92+
for (auto *A : Attrs.getAttributes<DifferentiableAttr>())
93+
(void)A->getParameterIndices();
8694
for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) {
8795
// Get lowered argument indices.
8896
auto *paramIndices = A->getParameterIndices();

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -612,13 +612,12 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
612612
// call to the getter.
613613
if (member->getEffectiveAccess() > AccessLevel::Internal &&
614614
!member->getAttrs().hasAttribute<DifferentiableAttr>()) {
615-
(void)member->getAccessor(AccessorKind::Get)->getInterfaceType();
615+
auto *getter = member->getSynthesizedAccessor(AccessorKind::Get);
616+
(void)getter->getInterfaceType();
616617
// If member or its getter already has a `@differentiable` attribute,
617618
// continue.
618619
if (member->getAttrs().hasAttribute<DifferentiableAttr>() ||
619-
member->getAccessor(AccessorKind::Get)
620-
->getAttrs()
621-
.hasAttribute<DifferentiableAttr>())
620+
getter->getAttrs().hasAttribute<DifferentiableAttr>())
622621
continue;
623622
GenericSignature derivativeGenSig = GenericSignature();
624623
// If the parent declaration context is an extension, the nominal type may
@@ -627,9 +626,8 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
627626
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
628627
derivativeGenSig = extDecl->getGenericSignature();
629628
auto *diffableAttr = DifferentiableAttr::create(
630-
member->getAccessor(AccessorKind::Get), /*implicit*/ true,
631-
SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None,
632-
derivativeGenSig);
629+
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
630+
/*linear*/ false, {}, None, None, derivativeGenSig);
633631
member->getAttrs().add(diffableAttr);
634632
// Set getter `@differentiable` attribute parameter indices.
635633
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));

0 commit comments

Comments
 (0)