Skip to content

[AutoDiff] Fix SILGen JVP/VJP thunking bug. #26448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
CanGenericSignature assocFnGenSig,
ArrayRef<SILParameterInfo> originalParameters,
AutoDiffIndexSubset *parameterIndices, SILModule &module) {
// If associated function has no
if (!assocFnGenSig)
return nullptr;
auto &ctx = module.getASTContext();
Expand Down
22 changes: 18 additions & 4 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,13 +810,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
indices.parameters, indices.source, /*differentiationOrder*/ 1,
AutoDiffAssociatedFunctionKind::VJP, M, lookUpConformance);

// Self reordering is necessary if wrt at least two parameters, including
// self.
auto shouldReorderSelf = [&]() {
if (!F->hasSelfParam())
return false;
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
if (!indices.isWrtParameter(selfParamIndex))
return false;
return indices.parameters->getNumIndices() > 1;
};
bool reorderSelf = shouldReorderSelf();

// Thunk JVP method, if it is defined.
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
SILFunction *jvpThunk;
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
if (jvpFn->getLoweredFunctionType() != expectedJVPType) {
if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) {
jvpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP);
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP,
reorderSelf);
} else {
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
Expand All @@ -831,9 +844,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
SILFunction *vjpThunk;
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
if (vjpFn->getLoweredFunctionType() != expectedVJPType) {
if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) {
vjpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP);
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP,
reorderSelf);
} else {
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,
Expand Down
3 changes: 2 additions & 1 deletion lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// - The last result in the returned pullback.
SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind);
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
bool reorderSelf);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
Expand Down
9 changes: 5 additions & 4 deletions lib/SILGen/SILGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1780,14 +1780,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
// Differentiation thunks
//===--------------------------------------------------------------------===//

/// Get or create a thunk for reabstracting differentials/pullbacks returned
/// by user-defined JVP/VJP functions.
/// Get or create a thunk for reabstracting and self-reordering
/// differentials/pullbacks returned by user-defined JVP/VJP functions, and
/// apply it to the given differential/pullback.
///
/// If `reorderSelf` is true, reorder self so that it appears as:
/// - The last parameter, for differentials.
/// - The last result, for pullbacks.
SILFunction *getOrCreateAutoDiffLinearMapThunk(
AutoDiffAssociatedFunctionKind assocFnKind,
ManagedValue getThunkedAutoDiffLinearMap(
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the recently added AutoDiffLinearMapEnum:

Suggested change
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
ManagedValue linearMap, AutoDiffLinearMapEnum linearMapKind,

Will address in a follow-up to unblock progress.
This PR fixes latent bug in nightly toolchains.

CanSILFunctionType fromType, CanSILFunctionType toType,
bool reorderSelf);

Expand Down
75 changes: 37 additions & 38 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3364,7 +3364,6 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType(
/// tuple. Otherwise, add this value directly to `result`.
static void extractAllElements(SILValue val, SILBuilder &builder,
SmallVectorImpl<SILValue> &result) {
// auto &fn = builder.getFunction();
if (auto tupleType = val->getType().getAs<TupleType>())
for (auto i : range(tupleType->getNumElements()))
result.push_back(builder.createTupleExtract(val.getLoc(), val, i));
Expand All @@ -3385,10 +3384,11 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,

// SWIFT_ENABLE_TENSORFLOW
/// Adapted from `SILGenModule::getOrCreateReabstractionThunk`.
SILFunction *
SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
AutoDiffAssociatedFunctionKind assocFnKind, CanSILFunctionType fromType,
CanSILFunctionType toType, bool reorderSelf) {
ManagedValue
SILGenFunction::getThunkedAutoDiffLinearMap(
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
CanSILFunctionType fromType, CanSILFunctionType toType,
bool reorderSelf) {
// Compute the thunk type.
SubstitutionMap interfaceSubs;
GenericEnvironment *genericEnv = nullptr;
Expand All @@ -3409,7 +3409,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
std::string name = mangler.mangleReabstractionThunkHelper(
thunkType, fromInterfaceType, toInterfaceType,
Type(), getModule().getSwiftModule());
// TODO: Use principled mangling.
// TODO(TF-685): Use principled thunk mangling.
if (reorderSelf) {
switch (assocFnKind) {
case AutoDiffAssociatedFunctionKind::JVP:
Expand All @@ -3428,8 +3428,20 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
auto *thunk = fb.getOrCreateSharedFunction(
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);

// Partially-apply the thunk to `linearMap` and return the thunked value.
auto getThunkedResult = [&]() {
auto thunkedFn = createPartialApplyOfThunk(
*this, loc, thunk, interfaceSubs, dynamicSelfType, toType, linearMap);
if (!toType->isNoEscape())
return thunkedFn;
// Handle escaping to noescape conversion.
return B.createConvertEscapeToNoEscape(
loc, thunkedFn, SILType::getPrimitiveObjectType(toType));
};

if (!thunk->empty())
return thunk;
return getThunkedResult();
thunk->setGenericEnvironment(genericEnv);
thunk->setOwnershipEliminated();

Expand Down Expand Up @@ -3560,9 +3572,9 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
arguments.push_back(load);
}

auto linearMap = thunk->getArgumentsWithoutIndirectResults().back();
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back();
auto *apply = thunkSGF.B.createApply(
loc, linearMap, SubstitutionMap(), arguments, /*isNonThrowing*/ false);
loc, linearMapArg, SubstitutionMap(), arguments, /*isNonThrowing*/ false);

// Get return elements.
SmallVector<SILValue, 4> results;
Expand Down Expand Up @@ -3630,7 +3642,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(

// Create return.
thunkSGF.B.createReturn(loc, retVal);
return thunk;
return getThunkedResult();
}

/// Forward function arguments, converting ownership.
Expand Down Expand Up @@ -3687,9 +3699,12 @@ static void forwardFunctionArgumentsConvertingOwnership(
SILFunction *
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind) {
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
bool reorderSelf) {
auto assocFnType = assocFn->getLoweredFunctionType();

// TODO(TF-685): Use principled thunk mangling.
// Do not simply reuse reabstraction thunk mangling.
Mangle::ASTMangler mangler;
auto name = getASTContext().getIdentifier(
mangler.mangleAutoDiffAssociatedFunctionHelper(
Expand Down Expand Up @@ -3746,8 +3761,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(

SmallVector<SILValue, 8> directResults;
extractAllElements(apply, thunkSGF.B, directResults);
auto linearMap = directResults.back();
auto linearMapFnType = linearMap->getType().castTo<SILFunctionType>();
auto linearMap = ManagedValue::forBorrowedObjectRValue(directResults.back());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You aren’t using any SILGen memory management in this thunk. This should be ‘ManagedValue::forUnmanaged’ or something close (I forgot the exact name).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will address shortly in a follow-up.
(I plan to enable ownership for AD associated function thunks soon.)

auto linearMapFnType = linearMap.getType().castTo<SILFunctionType>();
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
origAssocFnType->getResults().back().getSILStorageType())
.castTo<SILFunctionType>();
Expand All @@ -3769,33 +3784,17 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
thunkSGF.B.createReturn(loc, retValue);
};

// If linear map types are unchanged, return the `apply` instruction.
if (linearMapFnType == targetLinearMapFnType) {
// If self ordering is not necessary and linear map types are unchanged,
// return the `apply` instruction.
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
createReturn(apply);
return thunk;
}

// Generate linear map thunk for reabstraction/self reordering.
auto shouldReorderSelf = [&]() {
if (!original->hasSelfParam())
return false;
auto selfParamIndex =
original->getArgumentsWithoutIndirectResults().size() - 1;
if (!indices.isWrtParameter(selfParamIndex))
return false;
return indices.parameters->getNumIndices() > 1;
};
bool reorderSelf = shouldReorderSelf();
auto *linearMapThunk = thunkSGF.getOrCreateAutoDiffLinearMapThunk(
assocFnKind, linearMapFnType, targetLinearMapFnType, reorderSelf);
auto linearMapThunkValue =
thunkSGF.B.createFunctionRefFor(loc, linearMapThunk);
SubstitutionMap linearMapSubs;
if (linearMapThunk->getLoweredFunctionType()->isPolymorphic())
linearMapSubs = thunk->getForwardingSubstitutionMap();
linearMap = thunkSGF.B.createPartialApply(
loc, linearMapThunkValue, linearMapSubs, {linearMap},
linearMapFnType->getCalleeConvention());
// Otherwise, apply reabstraction/self reordering thunk to linear map.
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
linearMap, assocFnKind, linearMapFnType, targetLinearMapFnType,
reorderSelf);

// Return original results and thunked differential/pullback.
if (directResults.size() > 1) {
Expand All @@ -3804,10 +3803,10 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
auto originalDirectResult =
joinElements(originalDirectResults, thunkSGF.B, apply.getLoc());
auto thunkResult = joinElements(
{originalDirectResult, linearMap}, thunkSGF.B, loc);
{originalDirectResult, linearMap.getValue()}, thunkSGF.B, loc);
createReturn(thunkResult);
} else {
createReturn(linearMap);
createReturn(linearMap.getValue());
}
return thunk;
}
Expand Down
98 changes: 49 additions & 49 deletions lib/SILGen/SILGenThunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,28 +84,27 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef assocFnDeclRef,
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
originalLinkage, /*isAssocFnExported*/ true);
auto name = assocFnDeclRef.mangle();
auto *F = builder.getOrCreateFunction(
auto *thunk = builder.getOrCreateFunction(
assocFnDecl, name, linkage, assocFnTy, IsBare, IsTransparent,
assocFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(), IsThunk);

if (F->empty()) {
if (auto genSig = assocFnTy->getGenericSignature())
F->setGenericEnvironment(genSig->createGenericEnvironment());
SILGenFunction SGF(*this, *F, SwiftModule);
SmallVector<ManagedValue, 4> params;
auto loc = assocFnDeclRef.getAsRegularLocation();
SGF.collectThunkParams(loc, params);
auto assocFnRef = SGF.B.createFunctionRef(loc, assocFn);
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(assocFnTy);
SmallVector<SILValue, 4> args(F->getArguments().begin(),
F->getArguments().end());
auto apply = SGF.emitApplyWithRethrow(
loc, assocFnRef, autoDiffAssocFnSILTy,
SGF.getForwardingSubstitutionMap(), args);
SGF.B.createReturn(loc, apply);
}

return F;
if (!thunk->empty())
return thunk;

if (auto genSig = assocFnTy->getGenericSignature())
thunk->setGenericEnvironment(genSig->createGenericEnvironment());
SILGenFunction SGF(*this, *thunk, SwiftModule);
SmallVector<ManagedValue, 4> params;
auto loc = assocFnDeclRef.getAsRegularLocation();
SGF.collectThunkParams(loc, params);
auto assocFnRef = SGF.B.createFunctionRef(loc, assocFn);
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(assocFnTy);
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
thunk->getArguments().end());
auto apply = SGF.emitApplyWithRethrow(
loc, assocFnRef, autoDiffAssocFnSILTy,
SGF.getForwardingSubstitutionMap(), args);
SGF.B.createReturn(loc, apply);
return thunk;
}

// SWIFT_ENABLE_TENSORFLOW
Expand All @@ -120,37 +119,38 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
auto originalLinkage = originalFn.getLinkage(ForDefinition);
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
originalLinkage, /*isAssocFnExported*/ true);
auto name = assocFnDeclRef.mangle() + "_thunk";
auto *F = builder.getOrCreateFunction(
// TODO(TF-685): Use principled thunk mangling.
// Do not simply reuse reabstraction thunk mangling.
auto name = assocFnDeclRef.mangle() + "_vtable_entry_thunk";
auto *thunk = builder.getOrCreateFunction(
assocFnDecl, name, linkage, constantTy, IsBare, IsTransparent,
assocFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(), IsThunk);

if (F->empty()) {
if (auto genSig = constantTy->getGenericSignature())
F->setGenericEnvironment(genSig->createGenericEnvironment());
SILGenFunction SGF(*this, *F, SwiftModule);
SmallVector<ManagedValue, 4> params;
auto loc = assocFnDeclRef.getAsRegularLocation();
SGF.collectThunkParams(loc, params);
auto originalFnRef = SGF.emitGlobalFunctionRef(loc, originalFn);
auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered(
SGF.getASTContext(),
assocFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
auto autoDiffFn = SGF.B.createAutoDiffFunction(
loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef);
auto autoDiffAssocFn = SGF.B.createAutoDiffFunctionExtract(
loc, AutoDiffFunctionExtractInst::Extractee(autoDiffFuncId->getKind()),
/*differentiationOrder*/ 1, autoDiffFn);
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy);
SmallVector<SILValue, 4> args(F->getArguments().begin(),
F->getArguments().end());
auto apply = SGF.emitApplyWithRethrow(
loc, autoDiffAssocFn, autoDiffAssocFnSILTy,
SGF.getForwardingSubstitutionMap(), args);
SGF.B.createReturn(loc, apply);
}

return F;
if (!thunk->empty())
return thunk;

if (auto genSig = constantTy->getGenericSignature())
thunk->setGenericEnvironment(genSig->createGenericEnvironment());
SILGenFunction SGF(*this, *thunk, SwiftModule);
SmallVector<ManagedValue, 4> params;
auto loc = assocFnDeclRef.getAsRegularLocation();
SGF.collectThunkParams(loc, params);
auto originalFnRef = SGF.emitGlobalFunctionRef(loc, originalFn);
auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered(
SGF.getASTContext(),
assocFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
auto autoDiffFn = SGF.B.createAutoDiffFunction(
loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef);
auto autoDiffAssocFn = SGF.B.createAutoDiffFunctionExtract(
loc, AutoDiffFunctionExtractInst::Extractee(autoDiffFuncId->getKind()),
/*differentiationOrder*/ 1, autoDiffFn);
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy);
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
thunk->getArguments().end());
auto apply = SGF.emitApplyWithRethrow(
loc, autoDiffAssocFn, autoDiffAssocFnSILTy,
SGF.getForwardingSubstitutionMap(), args);
SGF.B.createReturn(loc, apply);
return thunk;
}

ManagedValue
Expand Down
18 changes: 0 additions & 18 deletions test/AutoDiff/method_self_reordering_thunk/main.swift

This file was deleted.

Loading