-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Fix SILGen JVP/VJP thunking bug. #26448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3364,7 +3364,6 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType( | |
/// tuple. Otherwise, add this value directly to `result`. | ||
static void extractAllElements(SILValue val, SILBuilder &builder, | ||
SmallVectorImpl<SILValue> &result) { | ||
// auto &fn = builder.getFunction(); | ||
if (auto tupleType = val->getType().getAs<TupleType>()) | ||
for (auto i : range(tupleType->getNumElements())) | ||
result.push_back(builder.createTupleExtract(val.getLoc(), val, i)); | ||
|
@@ -3385,10 +3384,11 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder, | |
|
||
// SWIFT_ENABLE_TENSORFLOW | ||
/// Adapted from `SILGenModule::getOrCreateReabstractionThunk`. | ||
SILFunction * | ||
SILGenFunction::getOrCreateAutoDiffLinearMapThunk( | ||
AutoDiffAssociatedFunctionKind assocFnKind, CanSILFunctionType fromType, | ||
CanSILFunctionType toType, bool reorderSelf) { | ||
ManagedValue | ||
SILGenFunction::getThunkedAutoDiffLinearMap( | ||
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind, | ||
CanSILFunctionType fromType, CanSILFunctionType toType, | ||
bool reorderSelf) { | ||
// Compute the thunk type. | ||
SubstitutionMap interfaceSubs; | ||
GenericEnvironment *genericEnv = nullptr; | ||
|
@@ -3409,7 +3409,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk( | |
std::string name = mangler.mangleReabstractionThunkHelper( | ||
thunkType, fromInterfaceType, toInterfaceType, | ||
Type(), getModule().getSwiftModule()); | ||
// TODO: Use principled mangling. | ||
// TODO(TF-685): Use principled thunk mangling. | ||
if (reorderSelf) { | ||
switch (assocFnKind) { | ||
case AutoDiffAssociatedFunctionKind::JVP: | ||
|
@@ -3428,8 +3428,20 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk( | |
auto *thunk = fb.getOrCreateSharedFunction( | ||
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized, | ||
ProfileCounter(), IsReabstractionThunk, IsNotDynamic); | ||
|
||
// Partially-apply the thunk to `linearMap` and return the thunked value. | ||
auto getThunkedResult = [&]() { | ||
auto thunkedFn = createPartialApplyOfThunk( | ||
*this, loc, thunk, interfaceSubs, dynamicSelfType, toType, linearMap); | ||
if (!toType->isNoEscape()) | ||
return thunkedFn; | ||
// Handle escaping to noescape conversion. | ||
return B.createConvertEscapeToNoEscape( | ||
loc, thunkedFn, SILType::getPrimitiveObjectType(toType)); | ||
}; | ||
|
||
if (!thunk->empty()) | ||
return thunk; | ||
return getThunkedResult(); | ||
thunk->setGenericEnvironment(genericEnv); | ||
thunk->setOwnershipEliminated(); | ||
|
||
|
@@ -3560,9 +3572,9 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk( | |
arguments.push_back(load); | ||
} | ||
|
||
auto linearMap = thunk->getArgumentsWithoutIndirectResults().back(); | ||
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back(); | ||
auto *apply = thunkSGF.B.createApply( | ||
loc, linearMap, SubstitutionMap(), arguments, /*isNonThrowing*/ false); | ||
loc, linearMapArg, SubstitutionMap(), arguments, /*isNonThrowing*/ false); | ||
|
||
// Get return elements. | ||
SmallVector<SILValue, 4> results; | ||
|
@@ -3630,7 +3642,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk( | |
|
||
// Create return. | ||
thunkSGF.B.createReturn(loc, retVal); | ||
return thunk; | ||
return getThunkedResult(); | ||
} | ||
|
||
/// Forward function arguments, converting ownership. | ||
|
@@ -3687,9 +3699,12 @@ static void forwardFunctionArgumentsConvertingOwnership( | |
SILFunction * | ||
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( | ||
SILFunction *original, SILAutoDiffIndices &indices, | ||
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind) { | ||
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind, | ||
bool reorderSelf) { | ||
auto assocFnType = assocFn->getLoweredFunctionType(); | ||
|
||
// TODO(TF-685): Use principled thunk mangling. | ||
// Do not simply reuse reabstraction thunk mangling. | ||
Mangle::ASTMangler mangler; | ||
auto name = getASTContext().getIdentifier( | ||
mangler.mangleAutoDiffAssociatedFunctionHelper( | ||
|
@@ -3746,8 +3761,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( | |
|
||
SmallVector<SILValue, 8> directResults; | ||
extractAllElements(apply, thunkSGF.B, directResults); | ||
auto linearMap = directResults.back(); | ||
auto linearMapFnType = linearMap->getType().castTo<SILFunctionType>(); | ||
auto linearMap = ManagedValue::forBorrowedObjectRValue(directResults.back()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You aren’t using any SILGen memory management in this thunk. This should be ‘ManagedValue::forUnmanaged’ or something close (I forgot the exact name). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will address shortly in a follow-up. |
||
auto linearMapFnType = linearMap.getType().castTo<SILFunctionType>(); | ||
auto targetLinearMapFnType = thunk->mapTypeIntoContext( | ||
origAssocFnType->getResults().back().getSILStorageType()) | ||
.castTo<SILFunctionType>(); | ||
dan-zheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -3769,33 +3784,17 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( | |
thunkSGF.B.createReturn(loc, retValue); | ||
}; | ||
|
||
// If linear map types are unchanged, return the `apply` instruction. | ||
if (linearMapFnType == targetLinearMapFnType) { | ||
// If self ordering is not necessary and linear map types are unchanged, | ||
// return the `apply` instruction. | ||
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) { | ||
createReturn(apply); | ||
return thunk; | ||
} | ||
|
||
// Generate linear map thunk for reabstraction/self reordering. | ||
auto shouldReorderSelf = [&]() { | ||
if (!original->hasSelfParam()) | ||
return false; | ||
auto selfParamIndex = | ||
original->getArgumentsWithoutIndirectResults().size() - 1; | ||
if (!indices.isWrtParameter(selfParamIndex)) | ||
return false; | ||
return indices.parameters->getNumIndices() > 1; | ||
}; | ||
bool reorderSelf = shouldReorderSelf(); | ||
auto *linearMapThunk = thunkSGF.getOrCreateAutoDiffLinearMapThunk( | ||
assocFnKind, linearMapFnType, targetLinearMapFnType, reorderSelf); | ||
auto linearMapThunkValue = | ||
thunkSGF.B.createFunctionRefFor(loc, linearMapThunk); | ||
SubstitutionMap linearMapSubs; | ||
if (linearMapThunk->getLoweredFunctionType()->isPolymorphic()) | ||
linearMapSubs = thunk->getForwardingSubstitutionMap(); | ||
linearMap = thunkSGF.B.createPartialApply( | ||
loc, linearMapThunkValue, linearMapSubs, {linearMap}, | ||
linearMapFnType->getCalleeConvention()); | ||
// Otherwise, apply reabstraction/self reordering thunk to linear map. | ||
linearMap = thunkSGF.getThunkedAutoDiffLinearMap( | ||
linearMap, assocFnKind, linearMapFnType, targetLinearMapFnType, | ||
reorderSelf); | ||
|
||
// Return original results and thunked differential/pullback. | ||
if (directResults.size() > 1) { | ||
|
@@ -3804,10 +3803,10 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk( | |
auto originalDirectResult = | ||
joinElements(originalDirectResults, thunkSGF.B, apply.getLoc()); | ||
auto thunkResult = joinElements( | ||
{originalDirectResult, linearMap}, thunkSGF.B, loc); | ||
{originalDirectResult, linearMap.getValue()}, thunkSGF.B, loc); | ||
createReturn(thunkResult); | ||
} else { | ||
createReturn(linearMap); | ||
createReturn(linearMap.getValue()); | ||
} | ||
return thunk; | ||
} | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the recently added
AutoDiffLinearMapEnum
:Will address in a follow-up to unblock progress.
This PR fixes latent bug in nightly toolchains.