Skip to content

[AutoDiff] Canonicalize SIL type for JVP/VJP methods. #24775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class AnyFunctionType;
class AutoDiffIndexSubset;
class AutoDiffParameterIndicesBuilder;
class Type;
enum class SILLinkage : uint8_t;

/// Identifies a subset of a function's parameters.
///
Expand Down Expand Up @@ -148,14 +149,17 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
///
/// functionType = (A, B) -> (C, D) -> R
/// if "A", "C", and "D" are in the set,
/// ==> pushes {A, C, D} to `paramTypes`.
/// ==> pushes {A, C, D} to `paramTypes` if `reverseCurryLevels` is false,
/// or pushes {C, D, A} otherwise.
///
/// functionType = (Self) -> (A, B, C) -> R
/// if "Self" and "C" are in the set,
/// ==> pushes {Self, C} to `paramTypes`.
/// ==> pushes {Self, C} to `paramTypes` if `reverseCurryLevels` is false,
/// or pushes {C, Self} otherwise.
///
void getSubsetParameterTypes(AnyFunctionType *functionType,
SmallVectorImpl<Type> &paramTypes) const;
SmallVectorImpl<Type> &paramTypes,
bool reverseCurryLevels = false) const;

/// Returns a bitvector for the SILFunction parameters corresponding to the
/// parameters in this set. In particular, this explodes tuples. For example,
Expand Down Expand Up @@ -465,6 +469,10 @@ struct SILAutoDiffIndices {

bool operator==(const SILAutoDiffIndices &other) const;

bool operator!=(const SILAutoDiffIndices &other) const {
return !(*this == other);
};

/// Queries whether the function's parameter with index `parameterIndex` is
/// one of the parameters to differentiate with respect to.
bool isWrtParameter(unsigned parameterIndex) const {
Expand Down Expand Up @@ -567,13 +575,21 @@ getOffsetForAutoDiffAssociatedFunction(unsigned order,
unsigned
getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder);

// Retrieve config from the function name of a variant of
// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
// Returns true if the function name is parsed successfully.
/// Retrieve config from the function name of a variant of
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
AutoDiffAssociatedFunctionKind &kind,
unsigned &arity, unsigned &order,
bool &rethrows);

/// Computes the correct linkage for associated functions given the linkage of
/// the original function. If the original linkage is not external and
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
/// return hidden linkage.
SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
bool isAssocFnExported);

} // end namespace autodiff

class BuiltinFloatType;
Expand Down
20 changes: 17 additions & 3 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3072,14 +3072,28 @@ class AnyFunctionType : public TypeBase {
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
/// of the corresponding autodiff associated function.
///
/// \note The original function type (`self`) need not be `@differentiable`,
/// and the resulting function will preserve all `ExtInfo` of the original
/// By default, if the original type has a self parameter list and parameter
/// indices include self, the computed associated function type will return a
/// linear map taking/returning self's tangent/cotangent *last* instead of
/// first, for consistency with SIL.
///
/// If `makeSelfParamFirst` is true, self's tangent/cotangent is reordered to
/// appear first. This should be used during type-checking, e.g.
/// type-checking `@differentiable` and `@differentiating` attributes.
///
/// \note The original function type (`self`) need not be `@differentiable`.
/// The resulting function will preserve all `ExtInfo` of the original
/// function, including `@differentiable`.
AnyFunctionType *getAutoDiffAssociatedFunctionType(
AutoDiffParameterIndices *indices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
LookupConformanceFn lookupConformance,
GenericSignature *whereClauseGenericSignature = nullptr);
GenericSignature *whereClauseGenericSignature = nullptr,
bool makeSelfParamFirst = false);

/// Given the type of an autodiff associated function, returns the
/// corresponding original function type.
AnyFunctionType *getAutoDiffOriginalFunctionType();

AnyFunctionType *getWithoutDifferentiability() const;

Expand Down
43 changes: 38 additions & 5 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Module.h"
#include "swift/AST/Types.h"
#include "swift/SIL/SILLinkage.h"
#include "swift/Basic/LLVM.h"
#include "swift/Basic/Range.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -88,6 +89,30 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
return operationName.empty();
}

SILLinkage autodiff::getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
bool isAssocFnExported) {
// If the original is defined externally, then the AD pass is just generating
// associated functions for use in the current module and therefore these
// associated functions should not be visible outside the module.
if (isAvailableExternally(originalLinkage))
return SILLinkage::Hidden;

// If the original is public, then external modules may need to link the
// associated function. Return the linkage of the original function, unless
// the associated function is not exported (i.e. differentiation is not
// explicitly requested via a `[differentiable]` attribute on the original
// function).
if (originalLinkage == SILLinkage::Public ||
originalLinkage == SILLinkage::PublicNonABI ||
originalLinkage == SILLinkage::Shared)
return isAssocFnExported ? originalLinkage : SILLinkage::Hidden;

// Otherwise, the original function is defined and used only in the current
// module, so external modules will never try to access the associated
// function. Make the associated function hidden.
return SILLinkage::Hidden;
}

/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
/// the given `string` generated by `getString()`. If the string is invalid,
/// returns nullptr.
Expand Down Expand Up @@ -140,15 +165,16 @@ static void unwrapCurryLevels(AnyFunctionType *fnTy,
/// ==> pushes {A, C} to `paramTypes`.
///
/// functionType = (A, B) -> (C, D) -> R
/// if "A", "C", and "D" are in the set,
/// ==> pushes {A, C, D} to `paramTypes`.
/// ==> pushes {A, C, D} to `paramTypes` if `reverseCurryLevels` is true,
/// or pushes {C, D, A} otherwise.
///
/// functionType = (Self) -> (A, B, C) -> R
/// if "Self" and "C" are in the set,
/// ==> pushes {Self, C} to `paramTypes`.
/// ==> pushes {Self, C} to `paramTypes` if `reverseCurryLevels` is true,
/// or pushes {C, Self} otherwise.
///
void AutoDiffParameterIndices::getSubsetParameterTypes(
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes) const {
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes,
bool reverseCurryLevels) const {
SmallVector<AnyFunctionType *, 2> curryLevels;
unwrapCurryLevels(functionType, curryLevels);

Expand All @@ -159,6 +185,13 @@ void AutoDiffParameterIndices::getSubsetParameterTypes(
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
}

// If `reverseCurryLevels` is true, reverse the curry levels and offsets.
if (reverseCurryLevels) {
std::reverse(curryLevels.begin(), curryLevels.end());
std::reverse(curryLevelParameterIndexOffsets.begin(),
curryLevelParameterIndexOffsets.end());
}

for (unsigned curryLevelIndex : indices(curryLevels)) {
auto *curryLevel = curryLevels[curryLevelIndex];
unsigned parameterIndexOffset =
Expand Down
54 changes: 49 additions & 5 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4444,7 +4444,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
AutoDiffParameterIndices *indices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
LookupConformanceFn lookupConformance,
GenericSignature *whereClauseGenSig) {
GenericSignature *whereClauseGenSig, bool makeSelfParamFirst) {
// JVP: (T...) -> ((R...),
// (T.TangentVector...) -> (R.TangentVector...))
// VJP: (T...) -> ((R...),
Expand All @@ -4460,12 +4460,17 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
auto &ctx = getASTContext();

SmallVector<Type, 8> wrtParamTypes;
indices->getSubsetParameterTypes(this, wrtParamTypes);
indices->getSubsetParameterTypes(
this, wrtParamTypes, /*reverseCurryLevels*/ !makeSelfParamFirst);

// Unwrap curry levels.
// Unwrap curry levels. At most, two parameter lists are necessary, for
// curried method types with a `(Self)` parameter list.
SmallVector<AnyFunctionType *, 2> curryLevels;
auto *currentLevel = this->eraseDynamicSelfType()->castTo<AnyFunctionType>();
while (currentLevel != nullptr) {
auto *currentLevel = eraseDynamicSelfType()->castTo<AnyFunctionType>();
for (unsigned i : range(2)) {
(void)i;
if (currentLevel == nullptr)
break;
curryLevels.push_back(currentLevel);
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
}
Expand Down Expand Up @@ -4566,6 +4571,45 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
return associatedFunction;
}

// SWIFT_ENABLE_TENSORFLOW
// Compute the original function type corresponding to the given derivative
// function type.
AnyFunctionType *
AnyFunctionType::getAutoDiffOriginalFunctionType() {
// Unwrap curry levels. At most, two parameter lists are necessary, for
// curried method types with a `(Self)` parameter list.
SmallVector<AnyFunctionType *, 2> curryLevels;
auto *currentLevel = this;
for (unsigned i : range(2)) {
(void)i;
if (currentLevel == nullptr)
break;
curryLevels.push_back(currentLevel);
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
}

auto derivativeResult = curryLevels.back()->getResult()->getAs<TupleType>();
assert(derivativeResult && derivativeResult->getNumElements() == 2 &&
"Expected derivative result to be a two-element tuple");
auto originalResult = derivativeResult->getElement(0).getType();
auto *originalType = makeFunctionType(
curryLevels.back(), curryLevels.back()->getParams(), originalResult,
curryLevels.size() == 1 ? getOptGenericSignature() : nullptr);

// Wrap the associated function type in additional curry levels.
auto curryLevelsWithoutLast =
ArrayRef<AnyFunctionType *>(curryLevels).drop_back(1);
for (auto pair : enumerate(reversed(curryLevelsWithoutLast))) {
unsigned i = pair.index();
AnyFunctionType *curryLevel = pair.value();
originalType = makeFunctionType(
curryLevel, curryLevel->getParams(), originalType,
i == curryLevelsWithoutLast.size() - 1 ? getOptGenericSignature()
: nullptr);
}
return originalType;
}

AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const {
SmallVector<Param, 8> newParams;
for (auto &param : getParams()) {
Expand Down
16 changes: 2 additions & 14 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
parameterIndices->contains(index);
};

// Calculate WRT parameter infos, in the order that they should appear in the
// results/parameters of the differential/pullback.
// Calculate differentiation parameter infos.
SmallVector<SILParameterInfo, 4> wrtParams;
// Make the self parameter appear first in the results/parameters of the
// differntial/pullback, even though it's the last parameter of the original
// method.
if (getExtInfo().hasSelfParam() &&
isWrtIndex(getNumParameters() - 1))
wrtParams.push_back(getParameters()[getNumParameters() - 1]);
for (auto valueAndIndex : enumerate(getParameters())) {
// Skip the self parameter because we have already added it.
if (getExtInfo().hasSelfParam() &&
valueAndIndex.index() == getNumParameters() - 1)
continue;
for (auto valueAndIndex : enumerate(getParameters()))
if (isWrtIndex(valueAndIndex.index()))
wrtParams.push_back(valueAndIndex.value());
}

CanSILFunctionType closureType;
switch (kind) {
Expand Down
54 changes: 54 additions & 0 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,60 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
assert(!F->isExternalDeclaration() && "did not emit any function body?!");
LLVM_DEBUG(llvm::dbgs() << "lowered sil:\n";
F->print(llvm::dbgs()));

// Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
if (constant.hasDecl()) {
auto *AFD = constant.getAbstractFunctionDecl();
// Continue only if original function is an instance method.
if (AFD && AFD->isInstanceMember() &&
F->getLoweredFunctionType()->hasSelfParam()) {
// Jointly iterate over AST `@differentiable` attributes and SIL
// `[differentiable]` attributes.
auto diffAttrs = AFD->getAttrs().getAttributes<DifferentiableAttr>();
auto silDiffAttrs = F->getDifferentiableAttrs();
for (auto pair : llvm::zip(diffAttrs, silDiffAttrs)) {
auto *diffAttr = const_cast<DifferentiableAttr *>(std::get<0>(pair));
auto *silDiffAttr = std::get<1>(pair);
// Compute autodiff indices.
auto paramIndices = diffAttr->getParameterIndices();
auto loweredParamIndices = paramIndices->getLowered(
getASTContext(),
AFD->getInterfaceType()->castTo<AnyFunctionType>());
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
assert(silDiffAttr->getIndices() == indices &&
"Expected matching @differentiable and [differentiable]");

// If user-defined JVP/VJP is not differentiable wrt self or is only
// differentiable wrt self, reordering is not necessary. Continue.
auto selfParamIndex =
F->getArgumentsWithoutIndirectResults().size() - 1;
bool isWrtSelf = indices.isWrtParameter(selfParamIndex);
if (!isWrtSelf || indices.parameters->getNumIndices() == 1)
continue;

// Thunk JVP method, if it is defined.
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
auto *thunk = getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP,
jvpFn->isSerialized());
silDiffAttr->setJVPName(thunk->getName());
// Unset JVP so that TBDGen triggers.
diffAttr->setJVPFunction(nullptr);
}
// Thunk VJP method, if it is defined.
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
auto *thunk = getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP,
vjpFn->isSerialized());
silDiffAttr->setVJPName(thunk->getName());
// Unset VJP so that TBDGen triggers.
diffAttr->setVJPFunction(nullptr);
}
}
}
}
F->verify();
}

Expand Down
10 changes: 10 additions & 0 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
CanSILFunctionType toType,
CanType dynamicSelfType);

// SWIFT_ENABLE_TENSORFLOW
/// Get or create a thunk for reordering autodiff associated functions with a
/// self parameter, so that self appears as:
/// - The last parameter in the returned differential.
/// - The last result in the returned pullback.
SILFunction *getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
IsSerialized_t isSerialized);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
bool hasNonTrivialIVars(ClassDecl *cd);
Expand Down
13 changes: 13 additions & 0 deletions lib/SILGen/SILGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,19 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
CanType &dynamicSelfType,
bool withoutActuallyEscaping=false);

// SWIFT_ENABLE_TENSORFLOW
//===--------------------------------------------------------------------===//
// Differentiation thunks
//===--------------------------------------------------------------------===//

/// Get or create a thunk for reordering linear maps that are differentiable
/// wrt self, so that self appears as:
/// - The last parameter in the differential.
/// - The last result in the pullback.
SILFunction *getOrCreateAutoDiffLinearMapReorderingThunk(
AutoDiffAssociatedFunctionKind assocFnKind,
CanSILFunctionType fromType, CanSILFunctionType toType);

//===--------------------------------------------------------------------===//
// NoEscaping to Escaping closure thunk
//===--------------------------------------------------------------------===//
Expand Down
Loading