Skip to content

Commit 0b2e3ef

Browse files
committed
Merge branch 'tensorflow' of github.com:apple/swift into tensorflow-merge
``` Failing Tests (3): Swift(macosx-x86_64) :: AutoDiff/class_method_thunk/main.swift Swift(macosx-x86_64) :: AutoDiff/differentiable_attr_cross_module/main.swift Swift(macosx-x86_64) :: AutoDiff/silgen_thunking/main.swift ```
2 parents 2adc320 + 38bc954 commit 0b2e3ef

13 files changed

+258
-221
lines changed

include/swift/AST/Attr.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,6 +1541,8 @@ class DifferentiableAttr final
15411541
ParsedAutoDiffParameter> {
15421542
friend TrailingObjects;
15431543

1544+
/// The declaration on which the `@differentiable` attribute is declared.
1545+
Decl *OriginalDeclaration = nullptr;
15441546
/// Whether this function is linear.
15451547
bool Linear;
15461548
/// The number of parsed parameters specified in 'wrt:'.
@@ -1573,7 +1575,7 @@ class DifferentiableAttr final
15731575
Optional<DeclNameWithLoc> vjp,
15741576
TrailingWhereClause *clause);
15751577

1576-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1578+
explicit DifferentiableAttr(Decl *original, bool implicit,
15771579
SourceLoc atLoc, SourceRange baseRange,
15781580
bool linear, IndexSubset *indices,
15791581
Optional<DeclNameWithLoc> jvp,
@@ -1589,13 +1591,16 @@ class DifferentiableAttr final
15891591
Optional<DeclNameWithLoc> vjp,
15901592
TrailingWhereClause *clause);
15911593

1592-
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1594+
static DifferentiableAttr *create(Decl *original, bool implicit,
15931595
SourceLoc atLoc, SourceRange baseRange,
15941596
bool linear, IndexSubset *indices,
15951597
Optional<DeclNameWithLoc> jvp,
15961598
Optional<DeclNameWithLoc> vjp,
15971599
GenericSignature derivativeGenSig);
15981600

1601+
Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
1602+
void setOriginalDeclaration(Decl *decl);
1603+
15991604
/// Get the optional 'jvp:' function name and location.
16001605
/// Use this instead of `getJVPFunction` to check whether the attribute has a
16011606
/// registered JVP.

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2974,8 +2974,8 @@ ERROR(transpose_params_clause_param_not_differentiable,none,
29742974
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
29752975
ERROR(transposing_attr_overload_not_found,none,
29762976
"could not find function %0 with expected type %1", (DeclName, Type))
2977-
ERROR(transposing_attr_cant_use_named_wrt_params,none,
2978-
"cannot use named wrt parameters in '@transposing' attribute, found %0",
2977+
ERROR(transposing_attr_cannot_use_named_wrt_params,none,
2978+
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
29792979
(Identifier))
29802980
ERROR(transposing_attr_result_value_not_differentiable,none,
29812981
"'@transposing' attribute requires original function result to "

include/swift/AST/Types.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3139,12 +3139,11 @@ class AnyFunctionType : public TypeBase {
31393139
/// Given the type of an autodiff derivative function, returns the
31403140
/// corresponding original function type.
31413141
AnyFunctionType *getAutoDiffOriginalFunctionType();
3142-
3142+
31433143
/// Given the type of a transposing derivative function, returns the
31443144
/// corresponding original function type.
31453145
AnyFunctionType *
3146-
getTransposeOriginalFunctionType(TransposingAttr *attr,
3147-
IndexSubset *wrtParamIndices,
3146+
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices,
31483147
bool wrtSelf);
31493148

31503149
AnyFunctionType *getWithoutDifferentiability() const;

lib/AST/Attr.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,16 +1454,16 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
14541454
getTrailingObjects<ParsedAutoDiffParameter>());
14551455
}
14561456

1457-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1457+
DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
14581458
SourceLoc atLoc, SourceRange baseRange,
1459-
bool linear,
1460-
IndexSubset *indices,
1459+
bool linear, IndexSubset *indices,
14611460
Optional<DeclNameWithLoc> jvp,
14621461
Optional<DeclNameWithLoc> vjp,
14631462
GenericSignature derivativeGenSig)
14641463
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
14651464
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
14661465
ParameterIndices(indices) {
1466+
setOriginalDeclaration(original);
14671467
setDerivativeGenericSignature(derivativeGenSig);
14681468
}
14691469

@@ -1483,19 +1483,26 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
14831483
}
14841484

14851485
DifferentiableAttr *
1486-
DifferentiableAttr::create(ASTContext &context, bool implicit,
1487-
SourceLoc atLoc, SourceRange baseRange,
1488-
bool linear, IndexSubset *indices,
1489-
Optional<DeclNameWithLoc> jvp,
1486+
DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc,
1487+
SourceRange baseRange, bool linear,
1488+
IndexSubset *indices, Optional<DeclNameWithLoc> jvp,
14901489
Optional<DeclNameWithLoc> vjp,
14911490
GenericSignature derivativeGenSig) {
1492-
void *mem = context.Allocate(sizeof(DifferentiableAttr),
1493-
alignof(DifferentiableAttr));
1494-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1491+
auto &ctx = original->getASTContext();
1492+
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
1493+
alignof(DifferentiableAttr));
1494+
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
14951495
linear, indices, std::move(jvp),
14961496
std::move(vjp), derivativeGenSig);
14971497
}
14981498

1499+
void DifferentiableAttr::setOriginalDeclaration(Decl *decl) {
1500+
assert(decl && "Original declaration must be non-null");
1501+
assert(!OriginalDeclaration &&
1502+
"Original declaration cannot have already been set");
1503+
OriginalDeclaration = decl;
1504+
}
1505+
14991506
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
15001507
JVPFunction = decl;
15011508
if (decl && !JVP)

lib/AST/Type.cpp

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4867,25 +4867,24 @@ makeFunctionType(ArrayRef<AnyFunctionType::Param> params, Type retTy,
48674867
// Compute the original function type corresponding to the given transpose
48684868
// function type.
48694869
AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
4870-
TransposingAttr *attr, IndexSubset *wrtParamIndices, bool wrtSelf) {
4870+
IndexSubset *wrtParamIndices, bool wrtSelf) {
48714871
unsigned transposeParamsIndex = 0;
48724872
bool isCurried = getResult()->is<AnyFunctionType>();
4873-
4873+
48744874
// Get the original function's result.
48754875
auto transposeParams = getParams();
48764876
auto transposeResult = getResult();
48774877
if (isCurried) {
4878-
auto method =
4879-
getAs<AnyFunctionType>()->getResult()->getAs<AnyFunctionType>();
4880-
transposeParams = method->getParams();
4881-
transposeResult = method->getResult();
4878+
auto methodType = getResult()->castTo<AnyFunctionType>();
4879+
transposeParams = methodType->getParams();
4880+
transposeResult = methodType->getResult();
48824881
}
4883-
4882+
48844883
Type originalResult;
48854884
if (isCurried) {
48864885
// If it's curried, then the first parameter in the curried type, which is
48874886
// the 'Self' type, is the original result (no matter if we are
4888-
// differentiating WRT self or aren't).
4887+
// transposing wrt self or not).
48894888
originalResult = getParams().front().getPlainType();
48904889
} else {
48914890
// If it's not curried, the last parameter, the tangent, is always the
@@ -4895,22 +4894,21 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
48954894
}
48964895
assert(originalResult);
48974896

4898-
auto wrtParams = attr->getParsedParameters();
48994897
SmallVector<TupleTypeElt, 4> transposeResultTypes;
49004898
// Return type of '@transposing' function can have single type or tuples
49014899
// of types.
4902-
if (auto t = transposeResult->getAs<TupleType>()) {
4903-
transposeResultTypes.append(t->getElements().begin(),
4904-
t->getElements().end());
4900+
if (auto transposeResultTupleType = transposeResult->getAs<TupleType>()) {
4901+
transposeResultTypes.append(transposeResultTupleType->getElements().begin(),
4902+
transposeResultTupleType->getElements().end());
49054903
} else {
49064904
transposeResultTypes.push_back(transposeResult);
49074905
}
49084906
assert(!transposeResultTypes.empty());
49094907

4910-
// If the function is curried and is transposing WRT 'self', then grab
4908+
// If the function is curried and is transposing wrt 'self', then grab
49114909
// the type from the result list (guaranteed to be the first since 'self'
4912-
// is first in WRT list) and remove it. If it's still curried but not
4913-
// transposing WRT 'self', then the 'Self' type is the first parameter
4910+
// is first in wrt list) and remove it. If it is still curried but not
4911+
// transposing wrt 'self', then the 'Self' type is the first parameter
49144912
// in the method.
49154913
unsigned transposeResultTypesIndex = 0;
49164914
Type selfType;
@@ -4923,21 +4921,21 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
49234921
}
49244922

49254923
SmallVector<AnyFunctionType::Param, 8> originalParams;
4926-
unsigned numberOriginalParameters =
4927-
transposeParams.size() + wrtParams.size() - 1;
4928-
for (auto i : range(numberOriginalParameters)) {
4924+
unsigned originalParameterCount =
4925+
transposeParams.size() + wrtParamIndices->getNumIndices() - 1;
4926+
for (auto i : range(originalParameterCount)) {
49294927
// Need to check if it is the 'self' param since we handle it differently
49304928
// above.
4931-
bool lookingAtSelf = (i == (wrtParamIndices->getCapacity() - 1)) && wrtSelf;
4932-
bool isWrt = wrtParamIndices->contains(i);
4933-
if (isWrt && !lookingAtSelf) {
4934-
// If in WRT list, the item in the result tuple must be a parameter in the
4929+
bool lookingAtSelf = (i == wrtParamIndices->getCapacity() - 1) && wrtSelf;
4930+
if (wrtParamIndices->contains(i) && !lookingAtSelf) {
4931+
// If in wrt list, the item in the result tuple must be a parameter in the
49354932
// original function.
4936-
auto resultType = transposeResultTypes[transposeResultTypesIndex].getType();
4933+
auto resultType =
4934+
transposeResultTypes[transposeResultTypesIndex].getType();
49374935
originalParams.push_back(AnyFunctionType::Param(resultType));
49384936
transposeResultTypesIndex++;
49394937
} else {
4940-
// Else if not in the WRT list, the parameter in the transposing function
4938+
// Else if not in the wrt list, the parameter in the transposing function
49414939
// is a parameter in the original function.
49424940
originalParams.push_back(transposeParams[transposeParamsIndex]);
49434941
transposeParamsIndex++;

lib/Parse/ParseDecl.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3384,6 +3384,13 @@ void Parser::delayParseFromBeginningToHere(ParserPosition BeginParserPosition,
33843384
consumeToken();
33853385
}
33863386

3387+
// SWIFT_ENABLE_TENSORFLOW
3388+
static void setOriginalFunctionInDifferentiableAttributes(
3389+
DeclAttributes Attributes, Decl *D) {
3390+
for (auto *attr : Attributes.getAttributes<DifferentiableAttr>())
3391+
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
3392+
}
3393+
33873394
/// Parse a single syntactic declaration and return a list of decl
33883395
/// ASTs. This can return multiple results for var decls that bind to multiple
33893396
/// values, structs that define a struct decl and a constructor, etc.
@@ -3788,6 +3795,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
37883795
Decl *D = DeclResult.get();
37893796
if (!declWasHandledAlready(D)) {
37903797
Handler(D);
3798+
// SWIFT_ENABLE_TENSORFLOW
37913799
if (auto FD = dyn_cast<FuncDecl>(D)) {
37923800
if (auto attr = D->getAttrs().getAttribute<QuotedAttr>()) {
37933801
// TODO(TF-718): Properly mangle names for quote decls.
@@ -3825,7 +3833,11 @@ Parser::parseDecl(ParseDeclOptions Flags,
38253833
Handler(quoteDecl);
38263834
}
38273835
}
3836+
// SWIFT_ENABLE_TENSORFLOW END
38283837
}
3838+
// SWIFT_ENABLE_TENSORFLOW
3839+
setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D);
3840+
// SWIFT_ENABLE_TENSORFLOW END
38293841
}
38303842

38313843
if (!DeclResult.isParseError()) {
@@ -5579,6 +5591,12 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,
55795591

55805592
accessors.record(*this, PrimaryVar, Invalid);
55815593

5594+
// SWIFT_ENABLE_TENSORFLOW
5595+
for (auto *accessor : accessors.Accessors)
5596+
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
5597+
accessor);
5598+
// SWIFT_ENABLE_TENSORFLOW END
5599+
55825600
return makeParserResult(PrimaryVar);
55835601
}
55845602

@@ -5833,6 +5851,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
58335851
pattern->forEachVariable([&](VarDecl *VD) {
58345852
VD->setStatic(StaticLoc.isValid());
58355853
VD->getAttrs() = Attributes;
5854+
// SWIFT_ENABLE_TENSORFLOW
5855+
setOriginalFunctionInDifferentiableAttributes(Attributes, VD);
5856+
// SWIFT_ENABLE_TENSORFLOW END
58365857
setLocalDiscriminator(VD);
58375858
Decls.push_back(VD);
58385859
if (hasOpaqueReturnTy && sf && !InInactiveClauseEnvironment) {
@@ -7087,6 +7108,12 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
70877108

70887109
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));
70897110

7111+
// SWIFT_ENABLE_TENSORFLOW
7112+
for (auto *accessor : accessors.Accessors)
7113+
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
7114+
accessor);
7115+
// SWIFT_ENABLE_TENSORFLOW END
7116+
70907117
// No need to setLocalDiscriminator because subscripts cannot
70917118
// validly appear outside of type decls.
70927119
return makeParserResult(Status, Subscript);

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
628628
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
629629
derivativeGenSig = extDecl->getGenericSignature();
630630
auto *diffableAttr = DifferentiableAttr::create(
631-
C, /*implicit*/ true, SourceLoc(), SourceLoc(),
632-
/*linear*/ false, {}, None, None, derivativeGenSig);
631+
member->getAccessor(AccessorKind::Get), /*implicit*/ true,
632+
SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None,
633+
derivativeGenSig);
633634
member->getAttrs().add(diffableAttr);
634635
// Set getter `@differentiable` attribute parameter indices.
635636
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));

0 commit comments

Comments
 (0)