Skip to content

Commit e941adc

Browse files
authored
[AutoDiff] Canonicalize SIL type for JVP/VJP methods. (#24775)
Canonicalize JVP/VJP type calculation for methods so that JVP/VJP methods return a linear map that takes/returns self's tangent/cotangent last, instead of first. This matches the type calculation logic for top-level functions.
Changes:
- Remove method-specific logic from `SILFunctionType::getAutoDiffAssociatedFunctionType`. Handle `self` like a normal differentiation parameter.
- Thunk user-defined JVP/VJP methods, reordering the position of self's tangent/cotangent in returned linear maps.
- Relevant code: `SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk` and `SILGenFunction::getOrCreateAutoDiffLinearMapReorderingThunk`.
- Change AST method JVP/VJP type computation to match SIL.
- Perform self-parameter reordering for protocol witnesses.
- Relevant code: `AutoDiffParameterIndices::getSubsetParameterTypes`, `AnyFunctionType::getAutoDiffAssociatedFunctionType`.
Move some functions to a common location:
- `AnyFunctionType::getAutoDiffOriginalFunctionType`
- `autodiff::getAutoDiffFunctionLinkage`
1 parent 3d3f9e4 commit e941adc

File tree

16 files changed

+574
-263
lines changed

16 files changed

+574
-263
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class AnyFunctionType;
7777
class AutoDiffIndexSubset;
7878
class AutoDiffParameterIndicesBuilder;
7979
class Type;
80+
enum class SILLinkage : uint8_t;
8081

8182
/// Identifies a subset of a function's parameters.
8283
///
@@ -148,14 +149,17 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
148149
///
149150
/// functionType = (A, B) -> (C, D) -> R
150151
/// if "A", "C", and "D" are in the set,
151-
/// ==> pushes {A, C, D} to `paramTypes`.
152+
/// ==> pushes {A, C, D} to `paramTypes` if `reverseCurryLevels` is false,
153+
/// or pushes {C, D, A} otherwise.
152154
///
153155
/// functionType = (Self) -> (A, B, C) -> R
154156
/// if "Self" and "C" are in the set,
155-
/// ==> pushes {Self, C} to `paramTypes`.
157+
/// ==> pushes {Self, C} to `paramTypes` if `reverseCurryLevels` is false,
158+
/// or pushes {C, Self} otherwise.
156159
///
157160
void getSubsetParameterTypes(AnyFunctionType *functionType,
158-
SmallVectorImpl<Type> &paramTypes) const;
161+
SmallVectorImpl<Type> &paramTypes,
162+
bool reverseCurryLevels = false) const;
159163

160164
/// Returns a bitvector for the SILFunction parameters corresponding to the
161165
/// parameters in this set. In particular, this explodes tuples. For example,
@@ -465,6 +469,10 @@ struct SILAutoDiffIndices {
465469

466470
bool operator==(const SILAutoDiffIndices &other) const;
467471

472+
bool operator!=(const SILAutoDiffIndices &other) const {
473+
return !(*this == other);
474+
};
475+
468476
/// Queries whether the function's parameter with index `parameterIndex` is
469477
/// one of the parameters to differentiate with respect to.
470478
bool isWrtParameter(unsigned parameterIndex) const {
@@ -567,13 +575,21 @@ getOffsetForAutoDiffAssociatedFunction(unsigned order,
567575
unsigned
568576
getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder);
569577

570-
// Retrieve config from the function name of a variant of
571-
// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
572-
// Returns true if the function name is parsed successfully.
578+
/// Retrieve config from the function name of a variant of
579+
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
580+
/// Returns true if the function name is parsed successfully.
573581
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
574582
AutoDiffAssociatedFunctionKind &kind,
575583
unsigned &arity, unsigned &order,
576584
bool &rethrows);
585+
586+
/// Computes the correct linkage for associated functions given the linkage of
587+
/// the original function. If the original linkage is public or shared and
588+
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
589+
/// return hidden linkage.
590+
SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
591+
bool isAssocFnExported);
592+
577593
} // end namespace autodiff
578594

579595
class BuiltinFloatType;

include/swift/AST/Types.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3072,14 +3072,28 @@ class AnyFunctionType : public TypeBase {
30723072
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
30733073
/// of the corresponding autodiff associated function.
30743074
///
3075-
/// \note The original function type (`self`) need not be `@differentiable`,
3076-
/// and the resulting function will preserve all `ExtInfo` of the original
3075+
/// By default, if the original type has a self parameter list and parameter
3076+
/// indices include self, the computed associated function type will return a
3077+
/// linear map taking/returning self's tangent/cotangent *last* instead of
3078+
/// first, for consistency with SIL.
3079+
///
3080+
/// If `makeSelfParamFirst` is true, self's tangent/cotangent is reordered to
3081+
/// appear first. This should be used during type-checking, e.g.
3082+
/// type-checking `@differentiable` and `@differentiating` attributes.
3083+
///
3084+
/// \note The original function type (`self`) need not be `@differentiable`.
3085+
/// The resulting function will preserve all `ExtInfo` of the original
30773086
/// function, including `@differentiable`.
30783087
AnyFunctionType *getAutoDiffAssociatedFunctionType(
30793088
AutoDiffParameterIndices *indices, unsigned resultIndex,
30803089
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
30813090
LookupConformanceFn lookupConformance,
3082-
GenericSignature *whereClauseGenericSignature = nullptr);
3091+
GenericSignature *whereClauseGenericSignature = nullptr,
3092+
bool makeSelfParamFirst = false);
3093+
3094+
/// Given the type of an autodiff associated function, returns the
3095+
/// corresponding original function type.
3096+
AnyFunctionType *getAutoDiffOriginalFunctionType();
30833097

30843098
AnyFunctionType *getWithoutDifferentiability() const;
30853099

lib/AST/AutoDiff.cpp

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "swift/AST/AutoDiff.h"
1414
#include "swift/AST/Module.h"
1515
#include "swift/AST/Types.h"
16+
#include "swift/SIL/SILLinkage.h"
1617
#include "swift/Basic/LLVM.h"
1718
#include "swift/Basic/Range.h"
1819
#include "llvm/ADT/STLExtras.h"
@@ -88,6 +89,30 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
8889
return operationName.empty();
8990
}
9091

92+
SILLinkage autodiff::getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
93+
bool isAssocFnExported) {
94+
// If the original is defined externally, then the AD pass is just generating
95+
// associated functions for use in the current module and therefore these
96+
// associated functions should not be visible outside the module.
97+
if (isAvailableExternally(originalLinkage))
98+
return SILLinkage::Hidden;
99+
100+
// If the original is public, then external modules may need to link the
101+
// associated function. Return the linkage of the original function, unless
102+
// the associated function is not exported (i.e. differentiation is not
103+
// explicitly requested via a `[differentiable]` attribute on the original
104+
// function).
105+
if (originalLinkage == SILLinkage::Public ||
106+
originalLinkage == SILLinkage::PublicNonABI ||
107+
originalLinkage == SILLinkage::Shared)
108+
return isAssocFnExported ? originalLinkage : SILLinkage::Hidden;
109+
110+
// Otherwise, the original function is defined and used only in the current
111+
// module, so external modules will never try to access the associated
112+
// function. Make the associated function hidden.
113+
return SILLinkage::Hidden;
114+
}
115+
91116
/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
92117
/// the given `string` generated by `getString()`. If the string is invalid,
93118
/// returns nullptr.
@@ -140,15 +165,16 @@ static void unwrapCurryLevels(AnyFunctionType *fnTy,
140165
/// ==> pushes {A, C} to `paramTypes`.
141166
///
142167
/// functionType = (A, B) -> (C, D) -> R
143-
/// if "A", "C", and "D" are in the set,
144-
/// ==> pushes {A, C, D} to `paramTypes`.
168+
/// ==> pushes {A, C, D} to `paramTypes` if `reverseCurryLevels` is false,
169+
/// or pushes {C, D, A} otherwise.
145170
///
146171
/// functionType = (Self) -> (A, B, C) -> R
147-
/// if "Self" and "C" are in the set,
148-
/// ==> pushes {Self, C} to `paramTypes`.
172+
/// ==> pushes {Self, C} to `paramTypes` if `reverseCurryLevels` is false,
173+
/// or pushes {C, Self} otherwise.
149174
///
150175
void AutoDiffParameterIndices::getSubsetParameterTypes(
151-
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes) const {
176+
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes,
177+
bool reverseCurryLevels) const {
152178
SmallVector<AnyFunctionType *, 2> curryLevels;
153179
unwrapCurryLevels(functionType, curryLevels);
154180

@@ -159,6 +185,13 @@ void AutoDiffParameterIndices::getSubsetParameterTypes(
159185
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
160186
}
161187

188+
// If `reverseCurryLevels` is true, reverse the curry levels and offsets.
189+
if (reverseCurryLevels) {
190+
std::reverse(curryLevels.begin(), curryLevels.end());
191+
std::reverse(curryLevelParameterIndexOffsets.begin(),
192+
curryLevelParameterIndexOffsets.end());
193+
}
194+
162195
for (unsigned curryLevelIndex : indices(curryLevels)) {
163196
auto *curryLevel = curryLevels[curryLevelIndex];
164197
unsigned parameterIndexOffset =

lib/AST/Type.cpp

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4444,7 +4444,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44444444
AutoDiffParameterIndices *indices, unsigned resultIndex,
44454445
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
44464446
LookupConformanceFn lookupConformance,
4447-
GenericSignature *whereClauseGenSig) {
4447+
GenericSignature *whereClauseGenSig, bool makeSelfParamFirst) {
44484448
// JVP: (T...) -> ((R...),
44494449
// (T.TangentVector...) -> (R.TangentVector...))
44504450
// VJP: (T...) -> ((R...),
@@ -4460,12 +4460,17 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44604460
auto &ctx = getASTContext();
44614461

44624462
SmallVector<Type, 8> wrtParamTypes;
4463-
indices->getSubsetParameterTypes(this, wrtParamTypes);
4463+
indices->getSubsetParameterTypes(
4464+
this, wrtParamTypes, /*reverseCurryLevels*/ !makeSelfParamFirst);
44644465

4465-
// Unwrap curry levels.
4466+
// Unwrap curry levels. At most, two parameter lists are necessary, for
4467+
// curried method types with a `(Self)` parameter list.
44664468
SmallVector<AnyFunctionType *, 2> curryLevels;
4467-
auto *currentLevel = this->eraseDynamicSelfType()->castTo<AnyFunctionType>();
4468-
while (currentLevel != nullptr) {
4469+
auto *currentLevel = eraseDynamicSelfType()->castTo<AnyFunctionType>();
4470+
for (unsigned i : range(2)) {
4471+
(void)i;
4472+
if (currentLevel == nullptr)
4473+
break;
44694474
curryLevels.push_back(currentLevel);
44704475
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
44714476
}
@@ -4566,6 +4571,45 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
45664571
return associatedFunction;
45674572
}
45684573

4574+
// SWIFT_ENABLE_TENSORFLOW
4575+
// Compute the original function type corresponding to the given derivative
4576+
// function type.
4577+
AnyFunctionType *
4578+
AnyFunctionType::getAutoDiffOriginalFunctionType() {
4579+
// Unwrap curry levels. At most, two parameter lists are necessary, for
4580+
// curried method types with a `(Self)` parameter list.
4581+
SmallVector<AnyFunctionType *, 2> curryLevels;
4582+
auto *currentLevel = this;
4583+
for (unsigned i : range(2)) {
4584+
(void)i;
4585+
if (currentLevel == nullptr)
4586+
break;
4587+
curryLevels.push_back(currentLevel);
4588+
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
4589+
}
4590+
4591+
auto derivativeResult = curryLevels.back()->getResult()->getAs<TupleType>();
4592+
assert(derivativeResult && derivativeResult->getNumElements() == 2 &&
4593+
"Expected derivative result to be a two-element tuple");
4594+
auto originalResult = derivativeResult->getElement(0).getType();
4595+
auto *originalType = makeFunctionType(
4596+
curryLevels.back(), curryLevels.back()->getParams(), originalResult,
4597+
curryLevels.size() == 1 ? getOptGenericSignature() : nullptr);
4598+
4599+
// Wrap the original function type in additional curry levels.
4600+
auto curryLevelsWithoutLast =
4601+
ArrayRef<AnyFunctionType *>(curryLevels).drop_back(1);
4602+
for (auto pair : enumerate(reversed(curryLevelsWithoutLast))) {
4603+
unsigned i = pair.index();
4604+
AnyFunctionType *curryLevel = pair.value();
4605+
originalType = makeFunctionType(
4606+
curryLevel, curryLevel->getParams(), originalType,
4607+
i == curryLevelsWithoutLast.size() - 1 ? getOptGenericSignature()
4608+
: nullptr);
4609+
}
4610+
return originalType;
4611+
}
4612+
45694613
AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const {
45704614
SmallVector<Param, 8> newParams;
45714615
for (auto &param : getParams()) {

lib/SIL/SILFunctionType.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -216,23 +216,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
216216
parameterIndices->contains(index);
217217
};
218218

219-
// Calculate WRT parameter infos, in the order that they should appear in the
220-
// results/parameters of the differential/pullback.
219+
// Calculate differentiation parameter infos.
221220
SmallVector<SILParameterInfo, 4> wrtParams;
222-
// Make the self parameter appear first in the results/parameters of the
223-
// differntial/pullback, even though it's the last parameter of the original
224-
// method.
225-
if (getExtInfo().hasSelfParam() &&
226-
isWrtIndex(getNumParameters() - 1))
227-
wrtParams.push_back(getParameters()[getNumParameters() - 1]);
228-
for (auto valueAndIndex : enumerate(getParameters())) {
229-
// Skip the self parameter because we have already added it.
230-
if (getExtInfo().hasSelfParam() &&
231-
valueAndIndex.index() == getNumParameters() - 1)
232-
continue;
221+
for (auto valueAndIndex : enumerate(getParameters()))
233222
if (isWrtIndex(valueAndIndex.index()))
234223
wrtParams.push_back(valueAndIndex.value());
235-
}
236224

237225
CanSILFunctionType closureType;
238226
switch (kind) {

lib/SILGen/SILGen.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,60 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
766766
assert(!F->isExternalDeclaration() && "did not emit any function body?!");
767767
LLVM_DEBUG(llvm::dbgs() << "lowered sil:\n";
768768
F->print(llvm::dbgs()));
769+
770+
// Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
771+
if (constant.hasDecl()) {
772+
auto *AFD = constant.getAbstractFunctionDecl();
773+
// Continue only if original function is an instance method.
774+
if (AFD && AFD->isInstanceMember() &&
775+
F->getLoweredFunctionType()->hasSelfParam()) {
776+
// Jointly iterate over AST `@differentiable` attributes and SIL
777+
// `[differentiable]` attributes.
778+
auto diffAttrs = AFD->getAttrs().getAttributes<DifferentiableAttr>();
779+
auto silDiffAttrs = F->getDifferentiableAttrs();
780+
for (auto pair : llvm::zip(diffAttrs, silDiffAttrs)) {
781+
auto *diffAttr = const_cast<DifferentiableAttr *>(std::get<0>(pair));
782+
auto *silDiffAttr = std::get<1>(pair);
783+
// Compute autodiff indices.
784+
auto paramIndices = diffAttr->getParameterIndices();
785+
auto loweredParamIndices = paramIndices->getLowered(
786+
getASTContext(),
787+
AFD->getInterfaceType()->castTo<AnyFunctionType>());
788+
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
789+
assert(silDiffAttr->getIndices() == indices &&
790+
"Expected matching @differentiable and [differentiable]");
791+
792+
// If user-defined JVP/VJP is not differentiable wrt self or is only
793+
// differentiable wrt self, reordering is not necessary. Continue.
794+
auto selfParamIndex =
795+
F->getArgumentsWithoutIndirectResults().size() - 1;
796+
bool isWrtSelf = indices.isWrtParameter(selfParamIndex);
797+
if (!isWrtSelf || indices.parameters->getNumIndices() == 1)
798+
continue;
799+
800+
// Thunk JVP method, if it is defined.
801+
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
802+
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
803+
auto *thunk = getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
804+
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP,
805+
jvpFn->isSerialized());
806+
silDiffAttr->setJVPName(thunk->getName());
807+
// Unset JVP so that TBDGen triggers.
808+
diffAttr->setJVPFunction(nullptr);
809+
}
810+
// Thunk VJP method, if it is defined.
811+
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
812+
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
813+
auto *thunk = getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
814+
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP,
815+
vjpFn->isSerialized());
816+
silDiffAttr->setVJPName(thunk->getName());
817+
// Unset VJP so that TBDGen triggers.
818+
diffAttr->setVJPFunction(nullptr);
819+
}
820+
}
821+
}
822+
}
769823
F->verify();
770824
}
771825

lib/SILGen/SILGen.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
182182
CanSILFunctionType toType,
183183
CanType dynamicSelfType);
184184

185+
// SWIFT_ENABLE_TENSORFLOW
186+
/// Get or create a thunk for reordering autodiff associated functions with a
187+
/// self parameter, so that self appears as:
188+
/// - The last parameter in the returned differential.
189+
/// - The last result in the returned pullback.
190+
SILFunction *getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
191+
SILFunction *original, SILAutoDiffIndices &indices,
192+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
193+
IsSerialized_t isSerialized);
194+
185195
/// Determine whether the given class has any instance variables that
186196
/// need to be destroyed.
187197
bool hasNonTrivialIVars(ClassDecl *cd);

lib/SILGen/SILGenFunction.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,19 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
17621762
CanType &dynamicSelfType,
17631763
bool withoutActuallyEscaping=false);
17641764

1765+
// SWIFT_ENABLE_TENSORFLOW
1766+
//===--------------------------------------------------------------------===//
1767+
// Differentiation thunks
1768+
//===--------------------------------------------------------------------===//
1769+
1770+
/// Get or create a thunk for reordering linear maps that are differentiable
1771+
/// wrt self, so that self appears as:
1772+
/// - The last parameter in the differential.
1773+
/// - The last result in the pullback.
1774+
SILFunction *getOrCreateAutoDiffLinearMapReorderingThunk(
1775+
AutoDiffAssociatedFunctionKind assocFnKind,
1776+
CanSILFunctionType fromType, CanSILFunctionType toType);
1777+
17651778
//===--------------------------------------------------------------------===//
17661779
// NoEscaping to Escaping closure thunk
17671780
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)