Skip to content

Commit 95411ac

Browse files
committed
Merge branch 'tensorflow' of github.com:apple/swift into tensorflow-merge
2 parents 3cbc148 + b88a119 commit 95411ac

25 files changed

+652
-220
lines changed

include/swift/AST/ASTContext.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ namespace swift {
112112
class IndexSubset;
113113
// SWIFT_ENABLE_TENSORFLOW
114114
struct AutoDiffConfig;
115-
class VectorSpace;
115+
struct AutoDiffDerivativeFunctionKind;
116+
class DerivativeAttr;
116117
class DifferentiableAttr;
118+
class VectorSpace;
117119
// SWIFT_ENABLE_TENSORFLOW END
118120

119121
enum class KnownProtocolKind : uint8_t;
@@ -290,11 +292,26 @@ class ASTContext final {
290292
/// Cache of autodiff-associated vector spaces.
291293
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;
292294

293-
/// Cache of `@differentiable` attributes keyed by parameter indices. This
294-
/// helps us diagnose multiple `@differentiable`s that are with respect to the
295-
/// same set of parameters.
295+
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
296+
/// diagnose duplicate `@differentiable` attributes for the same key.
297+
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
298+
// signature as a key is possible. It requires derivative generic signature
299+
// mangling to avoid name collisions for SIL derivative functions with the
300+
// same parameter indices but different derivative generic signatures.
296301
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
297302
DifferentiableAttrs;
303+
304+
/// Cache of `@derivative` attributes keyed by parameter indices and
305+
/// derivative function kind. Used to diagnose duplicate `@derivative`
306+
/// attributes for the same key.
307+
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
308+
// signature as a key is possible. It requires derivative generic signature
309+
// mangling to avoid name collisions for SIL derivative functions with the
310+
// same parameter indices but different derivative generic signatures.
311+
llvm::DenseMap<
312+
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
313+
DerivativeAttr *>
314+
DerivativeAttrs;
298315
// SWIFT_ENABLE_TENSORFLOW END
299316

300317
private:

include/swift/AST/Attr.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,12 +541,12 @@ DECL_ATTR(quoted, Quoted,
541541
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
542542
DECL_ATTR(differentiating, Differentiating,
543543
OnFunc | LongAttribute | AllowMultipleAttributes |
544-
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
545-
NotSerialized, 98)
544+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
545+
98)
546546
DECL_ATTR(derivative, Derivative,
547547
OnFunc | LongAttribute | AllowMultipleAttributes |
548-
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
549-
NotSerialized, 99)
548+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
549+
99)
550550
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
551551
OnAccessor | OnFunc | OnConstructor | OnSubscript |
552552
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |

include/swift/AST/Attr.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,7 +1873,7 @@ class DifferentiableAttr final
18731873

18741874
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
18751875
SourceRange baseRange, bool linear,
1876-
IndexSubset *indices,
1876+
IndexSubset *parameterIndices,
18771877
Optional<DeclNameWithLoc> jvp,
18781878
Optional<DeclNameWithLoc> vjp,
18791879
GenericSignature derivativeGenericSignature);
@@ -1887,9 +1887,10 @@ class DifferentiableAttr final
18871887
Optional<DeclNameWithLoc> vjp,
18881888
TrailingWhereClause *clause);
18891889

1890-
static DifferentiableAttr *create(Decl *original, bool implicit,
1891-
SourceLoc atLoc, SourceRange baseRange,
1892-
bool linear, IndexSubset *indices,
1890+
static DifferentiableAttr *create(AbstractFunctionDecl *original,
1891+
bool implicit, SourceLoc atLoc,
1892+
SourceRange baseRange, bool linear,
1893+
IndexSubset *parameterIndices,
18931894
Optional<DeclNameWithLoc> jvp,
18941895
Optional<DeclNameWithLoc> vjp,
18951896
GenericSignature derivativeGenSig);
@@ -1979,6 +1980,8 @@ class DerivativeAttr final
19791980
unsigned NumParsedParameters = 0;
19801981
/// The differentiation parameters' indices, resolved by the type checker.
19811982
IndexSubset *ParameterIndices = nullptr;
1983+
/// The derivative function kind (JVP or VJP), resolved by the type checker.
1984+
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
19821985

19831986
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
19841987
DeclNameWithLoc original,
@@ -2007,6 +2010,12 @@ class DerivativeAttr final
20072010
OriginalFunction = decl;
20082011
}
20092012

2013+
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
2014+
assert(Kind && "Derivative function kind has not yet been resolved");
2015+
return *Kind;
2016+
}
2017+
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
2018+
20102019
/// The parsed differentiation parameters, i.e. the list of parameters
20112020
/// specified in 'wrt:'.
20122021
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {

include/swift/AST/AutoDiff.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ struct AutoDiffConfig {
306306
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
307307
const AutoDiffDerivativeFunctionKind kind;
308308
IndexSubset *const parameterIndices;
309+
// TODO(TF-680): Mangle derivative generic signature requirements as well.
309310

310311
AutoDiffDerivativeFunctionIdentifier(
311312
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
@@ -508,6 +509,27 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
508509
}
509510
};
510511

512+
template<> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
513+
static AutoDiffDerivativeFunctionKind getEmptyKey() {
514+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
515+
DenseMapInfo<unsigned>::getEmptyKey());
516+
}
517+
518+
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
519+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
520+
DenseMapInfo<unsigned>::getTombstoneKey());
521+
}
522+
523+
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
524+
return DenseMapInfo<unsigned>::getHashValue(Val);
525+
}
526+
527+
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
528+
const AutoDiffDerivativeFunctionKind &RHS) {
529+
return LHS == RHS;
530+
}
531+
};
532+
511533
template<> struct DenseMapInfo<SILAutoDiffIndices> {
512534
static SILAutoDiffIndices getEmptyKey() {
513535
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

lib/AST/Attr.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -958,8 +958,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
958958
Printer.printAttrName("@derivative");
959959
Printer << "(of: ";
960960
auto *attr = cast<DerivativeAttr>(this);
961-
auto *derivative = cast<AbstractFunctionDecl>(D);
962961
Printer << attr->getOriginalFunctionName().Name;
962+
auto *derivative = cast<AbstractFunctionDecl>(D);
963963
auto diffParamsString = getDifferentiationParametersClauseString(
964964
derivative, attr->getParameterIndices(), attr->getParsedParameters());
965965
if (!diffParamsString.empty())
@@ -973,8 +973,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
973973
Printer.printAttrName("@transpose");
974974
Printer << '(';
975975
auto *attr = cast<TransposeAttr>(this);
976-
auto *transpose = cast<AbstractFunctionDecl>(D);
977976
Printer << attr->getOriginalFunctionName().Name;
977+
auto *transpose = cast<AbstractFunctionDecl>(D);
978978
auto transParamsString = getTransposedParametersClauseString(
979979
transpose, attr->getParameterIndices(), attr->getParsedParameters());
980980
if (!transParamsString.empty())
@@ -1504,16 +1504,24 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
15041504
}
15051505

15061506
DifferentiableAttr *
1507-
DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc,
1508-
SourceRange baseRange, bool linear,
1509-
IndexSubset *indices, Optional<DeclNameWithLoc> jvp,
1507+
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
1508+
SourceLoc atLoc, SourceRange baseRange, bool linear,
1509+
IndexSubset *parameterIndices,
1510+
Optional<DeclNameWithLoc> jvp,
15101511
Optional<DeclNameWithLoc> vjp,
15111512
GenericSignature derivativeGenSig) {
15121513
auto &ctx = original->getASTContext();
15131514
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
15141515
alignof(DifferentiableAttr));
1516+
// Register derivative function configuration for the given original
1517+
// declaration.
1518+
// NOTE(TF-1038): `@differentiable` attributes currently always have
1519+
// effective result indices `{0}` (the first and only result index).
1520+
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
1521+
original->addDerivativeFunctionConfiguration(
1522+
{parameterIndices, resultIndices, derivativeGenSig});
15151523
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
1516-
linear, indices, std::move(jvp),
1524+
linear, parameterIndices, std::move(jvp),
15171525
std::move(vjp), derivativeGenSig);
15181526
}
15191527

lib/SILGen/SILGen.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
778778
diffAttr->getDerivativeGenericSignature());
779779
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
780780
}
781+
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
782+
SILFunction *jvp = nullptr;
783+
SILFunction *vjp = nullptr;
784+
switch (derivAttr->getDerivativeKind()) {
785+
case AutoDiffDerivativeFunctionKind::JVP:
786+
jvp = F;
787+
break;
788+
case AutoDiffDerivativeFunctionKind::VJP:
789+
vjp = F;
790+
break;
791+
}
792+
auto *origAFD = derivAttr->getOriginalFunction();
793+
auto *origFn = getFunction(SILDeclRef(origAFD), NotForDefinition);
794+
auto derivativeGenSig = AFD->getGenericSignature();
795+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
796+
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
797+
derivativeGenSig);
798+
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
799+
derivAttr);
800+
}
781801
};
782802
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
783803
if (accessor->isGetter())
@@ -790,21 +810,22 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
790810
void SILGenModule::emitDifferentiabilityWitness(
791811
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
792812
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
793-
const DeclAttribute *diffAttr) {
813+
const DeclAttribute *attr) {
814+
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
794815
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
795816
auto origSilFnType = originalFunction->getLoweredFunctionType();
796-
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
797-
config.parameterIndices, origFnType);
817+
auto *silParamIndices =
818+
autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType);
798819
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
799820
// parameters corresponding to captured variables. These parameters do not
800821
// appear in the type of `origFnType`.
801822
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
802823
// take `CaptureInfo` into account.
803-
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
804-
loweredParamIndices = loweredParamIndices->extendingCapacity(
824+
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
825+
silParamIndices = silParamIndices->extendingCapacity(
805826
getASTContext(), origSilFnType->getNumParameters());
806827
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
807-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
828+
SILAutoDiffIndices indices(/*source*/ 0, silParamIndices);
808829

809830
// Self reordering thunk is necessary if wrt at least two parameters,
810831
// including self.
@@ -818,14 +839,22 @@ void SILGenModule::emitDifferentiabilityWitness(
818839
};
819840
bool reorderSelf = shouldReorderSelf();
820841

821-
// Create new SIL differentiability witness.
842+
// Get or create new SIL differentiability witness.
843+
// Witness already exists when there are two `@derivative` attributes (JVP and
844+
// VJP) for the same derivative function configuration.
822845
// Witness JVP and VJP are set below.
823-
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
824-
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
825-
config.resultIndices, config.derivativeGenericSignature,
826-
/*jvp*/ nullptr, /*vjp*/ nullptr,
827-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
828-
diffAttr);
846+
AutoDiffConfig silConfig(silParamIndices, config.resultIndices,
847+
config.derivativeGenericSignature);
848+
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
849+
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
850+
if (!diffWitness) {
851+
diffWitness = SILDifferentiabilityWitness::createDefinition(
852+
M, originalFunction->getLinkage(), originalFunction,
853+
silConfig.parameterIndices, silConfig.resultIndices,
854+
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
855+
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
856+
attr);
857+
}
829858

830859
// Set derivative function in differentiability witness.
831860
auto setDerivativeInDifferentiabilityWitness =

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,12 +1392,14 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
13921392

13931393
auto thunkTy = thunk->getLoweredFunctionType();
13941394
auto thunkResult = thunkTy->getSingleResult();
1395-
if (auto resultFnTy = thunkResult.getInterfaceType()->getAs<SILFunctionType>()) {
1396-
// Construct new curry thunk type with `@differentiable` result.
1397-
auto diffableResultFnTy = resultFnTy->getWithExtInfo(
1398-
resultFnTy->getExtInfo()
1399-
.withDifferentiabilityKind(DifferentiabilityKind::Normal));
1400-
auto newThunkResult = thunkResult.getWithInterfaceType(diffableResultFnTy);
1395+
if (auto resultFnTy =
1396+
thunkResult.getInterfaceType()->getAs<SILFunctionType>()) {
1397+
// Construct new curry thunk type with `@differentiable` function
1398+
// result.
1399+
auto diffResultFnTy = resultFnTy->getWithExtInfo(
1400+
resultFnTy->getExtInfo().withDifferentiabilityKind(
1401+
DifferentiabilityKind::Normal));
1402+
auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
14011403
auto thunkType = SILFunctionType::get(
14021404
thunkTy->getSubstGenericSignature(), thunkTy->getExtInfo(),
14031405
thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(),
@@ -1425,12 +1427,18 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
14251427
cloner.run();
14261428
auto *retInst =
14271429
cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
1428-
SILBuilder thunkBuilder(retInst);
1429-
auto *dfi = context.createDifferentiableFunction(thunkBuilder, loc,
1430-
parameterIndices,
1431-
retInst->getOperand());
1430+
auto returnValue = retInst->getOperand();
1431+
// Create `differentiable_function` instruction directly after the
1432+
// defining instruction (e.g. `partial_apply`) of the returned value.
1433+
// Note: `differentiable_function` is not created at the end of the
1434+
// new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues.
1435+
SILBuilder dfiBuilder(
1436+
std::next(returnValue->getDefiningInstruction()->getIterator()));
1437+
auto *dfi = context.createDifferentiableFunction(
1438+
dfiBuilder, loc, parameterIndices, returnValue);
14321439
context.setResultIndex(dfi, resultIndex);
1433-
thunkBuilder.createReturn(loc, dfi);
1440+
dfiBuilder.setInsertionPoint(newThunk->findReturnBB());
1441+
dfiBuilder.createReturn(loc, dfi);
14341442
retInst->eraseFromParent();
14351443

14361444
context.recordGeneratedFunction(newThunk);
@@ -1450,12 +1458,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
14501458
auto *newApply = builder.createApply(
14511459
ai->getLoc(), newThunkRef, ai->getSubstitutionMap(), newArgs,
14521460
ai->isNonThrowing());
1453-
for (auto arg : newArgsToDestroy) {
1454-
if (arg->getType().isObject())
1455-
builder.emitDestroyValueOperation(loc, arg);
1456-
else
1457-
builder.emitDestroyAddr(loc, arg);
1458-
}
1461+
for (auto arg : newArgsToDestroy)
1462+
builder.emitDestroyOperation(loc, arg);
14591463
for (auto *alloc : newBuffersToDealloc)
14601464
builder.createDeallocStack(loc, alloc);
14611465
return newApply;

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,6 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
586586
auto memberAssocContextualType =
587587
parentDC->mapTypeIntoContext(memberAssocInterfaceType);
588588
newMember->setInterfaceType(memberAssocInterfaceType);
589-
// newMember->setType(memberAssocContextualType);
590589
Pattern *memberPattern =
591590
new (C) NamedPattern(newMember, /*implicit*/ true);
592591
memberPattern->setType(memberAssocContextualType);
@@ -623,10 +622,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
623622
derivativeGenSig = extDecl->getGenericSignature();
624623
auto *diffableAttr = DifferentiableAttr::create(
625624
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
626-
/*linear*/ false, {}, None, None, derivativeGenSig);
625+
/*linear*/ false, /*parameterIndices*/ IndexSubset::get(C, 1, {0}),
626+
/*jvp*/ None, /*vjp*/ None, derivativeGenSig);
627627
member->getAttrs().add(diffableAttr);
628-
// Set getter `@differentiable` attribute parameter indices.
629-
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));
630628
}
631629
}
632630

0 commit comments

Comments
 (0)