Skip to content

[AutoDiff] Fix forward-mode crashes related to tangent buffers #33633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 9, 2020
Merged
195 changes: 141 additions & 54 deletions lib/SILOptimizer/Differentiation/JVPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,6 @@ class JVPCloner::Implementation final
// General utilities
//--------------------------------------------------------------------------//

/// Returns the position in the differential at which the next local
/// allocation should be inserted.
SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint() {
  // When at least one local allocation has been recorded, the insertion point
  // is the instruction defining the most recent one. Inserting before it
  // (rather than after it) keeps allocation and zero-initialization
  // instructions grouped together.
  if (!differentialLocalAllocations.empty()) {
    auto definingInst =
        differentialLocalAllocations.back()->getDefiningInstruction();
    return definingInst->getIterator();
  }
  // No local allocations yet: insert at the very start of the tangent entry
  // block.
  return getDifferential().getEntryBlock()->begin();
}

/// Get the lowered SIL type of the given AST type.
SILType getLoweredType(Type type) {
auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature();
Expand Down Expand Up @@ -309,6 +296,8 @@ class JVPCloner::Implementation final
// Tangent buffer mapping
//--------------------------------------------------------------------------//

/// Sets the tangent buffer for the original buffer. Asserts that the
/// original buffer does not already have a tangent buffer.
void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
SILValue tangentBuffer) {
assert(originalBuffer->getType().isAddress());
Expand All @@ -318,13 +307,14 @@ class JVPCloner::Implementation final
(void)insertion;
}

/// Returns the tangent buffer for the original buffer. Asserts that the
/// original buffer has a tangent buffer.
///
/// \param origBB the original basic block used as the first half of the
///        `bufferMap` key.
/// \param originalBuffer the original address-typed value, belonging to the
///        `original` function, used as the second half of the key.
/// \returns a mutable reference to the tangent buffer stored in `bufferMap`.
SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
  assert(originalBuffer->getType().isAddress());
  assert(originalBuffer->getFunction() == original);
  // Use a pure lookup. `try_emplace` would insert a null `SILValue` entry for
  // a missing key, which a getter must never do; the missing-entry case is a
  // programmer error and is reported by the assertion below.
  auto it = bufferMap.find({origBB, originalBuffer});
  assert(it != bufferMap.end() && "Tangent buffer should already exist");
  return it->getSecond();
}

//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -446,9 +436,21 @@ class JVPCloner::Implementation final
// If an `apply` has active results or active inout parameters, replace it
// with an `apply` of its JVP.
void visitApplyInst(ApplyInst *ai) {
bool shouldDifferentiate =
differentialInfo.shouldDifferentiateApplySite(ai);
// If the function has no active arguments or results, zero-initialize the
// tangent buffers of the active indirect results.
if (!shouldDifferentiate) {
for (auto indResult : ai->getIndirectSILResults())
if (activityInfo.isActive(indResult, getIndices())) {
auto &tanBuf = getTangentBuffer(ai->getParent(), indResult);
emitZeroIndirect(tanBuf->getType().getASTType(), tanBuf,
tanBuf.getLoc());
}
}
// If the function should not be differentiated or it's the array literal
// initialization intrinsic, just do standard cloning.
if (!differentialInfo.shouldDifferentiateApplySite(ai) ||
if (!shouldDifferentiate ||
ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
TypeSubstCloner::visitApplyInst(ai);
Expand Down Expand Up @@ -779,7 +781,7 @@ class JVPCloner::Implementation final
auto &diffBuilder = getDifferentialBuilder();
auto loc = dvi->getLoc();
auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc);
diffBuilder.emitDestroyValue(loc, tanVal);
diffBuilder.emitDestroyValueOperation(loc, tanVal);
}

CLONE_AND_EMIT_TANGENT(CopyValue, cvi) {
Expand All @@ -794,7 +796,20 @@ class JVPCloner::Implementation final
/// Handle `load` instruction.
/// Original: y = load x
/// Tangent: tan[y] = load tan[x]
CLONE_AND_EMIT_TANGENT(Load, li) {
void visitLoadInst(LoadInst *li) {
TypeSubstCloner::visitLoadInst(li);
// If an active buffer is loaded with take to a non-active value, destroy
// the active buffer's tangent buffer.
if (!differentialInfo.shouldDifferentiateInstruction(li)) {
auto isTake =
(li->getOwnershipQualifier() == LoadOwnershipQualifier::Take);
if (isTake && activityInfo.isActive(li->getOperand(), getIndices())) {
auto &tanBuf = getTangentBuffer(li->getParent(), li->getOperand());
getDifferentialBuilder().emitDestroyOperation(tanBuf.getLoc(), tanBuf);
}
return;
}
// Otherwise, do standard differential cloning.
auto &diffBuilder = getDifferentialBuilder();
auto *bb = li->getParent();
auto loc = li->getLoc();
Expand All @@ -819,7 +834,19 @@ class JVPCloner::Implementation final
/// Handle `store` instruction in the differential.
/// Original: store x to y
/// Tangent: store tan[x] to tan[y]
CLONE_AND_EMIT_TANGENT(Store, si) {
void visitStoreInst(StoreInst *si) {
TypeSubstCloner::visitStoreInst(si);
// If a non-active value is stored into an active buffer, zero-initialize
// the active buffer's tangent buffer.
if (!differentialInfo.shouldDifferentiateInstruction(si)) {
if (activityInfo.isActive(si->getDest(), getIndices())) {
auto &tanBufDest = getTangentBuffer(si->getParent(), si->getDest());
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
tanBufDest.getLoc());
}
return;
}
// Otherwise, do standard differential cloning.
auto &diffBuilder = getDifferentialBuilder();
auto loc = si->getLoc();
auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc);
Expand All @@ -831,7 +858,19 @@ class JVPCloner::Implementation final
/// Handle `store_borrow` instruction in the differential.
/// Original: store_borrow x to y
/// Tangent: store_borrow tan[x] to tan[y]
CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) {
void visitStoreBorrowInst(StoreBorrowInst *sbi) {
TypeSubstCloner::visitStoreBorrowInst(sbi);
// If a non-active value is stored into an active buffer, zero-initialize
// the active buffer's tangent buffer.
if (!differentialInfo.shouldDifferentiateInstruction(sbi)) {
if (activityInfo.isActive(sbi->getDest(), getIndices())) {
auto &tanBufDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
tanBufDest.getLoc());
}
return;
}
// Otherwise, do standard differential cloning.
auto &diffBuilder = getDifferentialBuilder();
auto loc = sbi->getLoc();
auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc);
Expand All @@ -842,13 +881,32 @@ class JVPCloner::Implementation final
/// Handle `copy_addr` instruction.
/// Original: copy_addr x to y
/// Tangent: copy_addr tan[x] to tan[y]
CLONE_AND_EMIT_TANGENT(CopyAddr, cai) {
/// Clones the original `copy_addr` and emits the matching tangent code in the
/// differential. Three cases are handled:
///  - the instruction should be differentiated: emit
///    `copy_addr tan[src] to tan[dest]` with the same take/initialization
///    flags as the original;
///  - it should not be differentiated but the destination is active:
///    zero-initialize the destination's tangent buffer;
///  - it should not be differentiated but an active source is taken: destroy
///    the source's tangent buffer.
void visitCopyAddrInst(CopyAddrInst *cai) {
  TypeSubstCloner::visitCopyAddrInst(cai);
  auto *bb = cai->getParent();
  // If a non-active buffer is copied into an active buffer, zero-initialize
  // the destination buffer's tangent buffer.
  // If an active buffer is copied with take into a non-active buffer, destroy
  // the source buffer's tangent buffer.
  if (!differentialInfo.shouldDifferentiateInstruction(cai)) {
    if (activityInfo.isActive(cai->getDest(), getIndices())) {
      auto &tanBufDest = getTangentBuffer(bb, cai->getDest());
      emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
                       tanBufDest.getLoc());
    }
    if (cai->isTakeOfSrc() &&
        activityInfo.isActive(cai->getSrc(), getIndices())) {
      auto &tanBufSrc = getTangentBuffer(bb, cai->getSrc());
      getDifferentialBuilder().emitDestroyOperation(tanBufSrc.getLoc(),
                                                    tanBufSrc);
    }
    return;
  }
  // Otherwise, do standard differential cloning.
  // Bind the differential builder by reference, as the other instruction
  // handlers do, instead of working on a by-value copy of it.
  auto &diffBuilder = getDifferentialBuilder();
  auto loc = cai->getLoc();
  // Both tangent buffers are bound by reference; no copy of either is needed.
  auto &tanSrc = getTangentBuffer(bb, cai->getSrc());
  auto &tanDest = getTangentBuffer(bb, cai->getDest());
  diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(),
                             cai->isInitializationOfDest());
}
Expand Down Expand Up @@ -908,8 +966,8 @@ class JVPCloner::Implementation final
auto &diffBuilder = getDifferentialBuilder();
auto *bb = eai->getParent();
auto loc = eai->getLoc();
auto tanSrc = getTangentBuffer(bb, eai->getOperand());
diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting());
auto tanOperand = getTangentBuffer(bb, eai->getOperand());
diffBuilder.createEndAccess(loc, tanOperand, eai->isAborting());
}

/// Handle `alloc_stack` instruction.
Expand All @@ -920,7 +978,7 @@ class JVPCloner::Implementation final
auto *mappedAllocStackInst = diffBuilder.createAllocStack(
asi->getLoc(), getRemappedTangentType(asi->getElementType()),
asi->getVarInfo());
bufferMap.try_emplace({asi->getParent(), asi}, mappedAllocStackInst);
setTangentBuffer(asi->getParent(), asi, mappedAllocStackInst);
}

/// Handle `dealloc_stack` instruction.
Expand Down Expand Up @@ -1052,16 +1110,15 @@ class JVPCloner::Implementation final
auto tanType = getRemappedTangentType(tei->getType());
auto tanSource =
materializeTangent(getTangentValue(tei->getOperand()), loc);
SILValue tanBuf;
// If the tangent buffer of the source does not have a tuple type, then
// If the tangent value of the source does not have a tuple type, then
// it must represent a "single element tuple type". Use it directly.
if (!tanSource->getType().is<TupleType>()) {
setTangentValue(tei->getParent(), tei,
makeConcreteTangentValue(tanSource));
} else {
tanBuf =
auto tanElt =
diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType);
bufferMap.try_emplace({tei->getParent(), tei}, tanBuf);
setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanElt));
}
}

Expand Down Expand Up @@ -1090,7 +1147,7 @@ class JVPCloner::Implementation final
tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource,
tanIndex, tanType);
}
bufferMap.try_emplace({teai->getParent(), teai}, tanBuf);
setTangentBuffer(teai->getParent(), teai, tanBuf);
}

/// Handle `destructure_tuple` instruction.
Expand Down Expand Up @@ -1272,9 +1329,8 @@ class JVPCloner::Implementation final
// Collect original results.
SmallVector<SILValue, 2> originalResults;
collectAllDirectResultsInTypeOrder(*original, originalResults);
// Collect differential return elements.
// Collect differential direct results.
SmallVector<SILValue, 8> retElts;
// for (auto origResult : originalResults) {
for (auto i : range(originalResults.size())) {
auto origResult = originalResults[i];
if (!getIndices().results->contains(i))
Expand Down Expand Up @@ -1391,7 +1447,10 @@ JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
void JVPCloner::Implementation::prepareForDifferentialGeneration() {
// Create differential blocks and arguments.
auto &differential = getDifferential();
auto diffLoc = differential.getLocation();
auto *origEntry = original->getEntryBlock();
auto origFnTy = original->getLoweredFunctionType();

for (auto &origBB : *original) {
auto *diffBB = differential.createBasicBlock();
diffBBMap.insert({&origBB, diffBB});
Expand Down Expand Up @@ -1472,21 +1531,51 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
<< " as the tangent of original result " << *origArg);
}

// Initialize tangent mapping for indirect results.
auto origIndResults = original->getIndirectResults();
// Initialize tangent mapping for original indirect results and non-wrt
// `inout` parameters. The tangent buffers of these address values are
// differential indirect results.

// Collect original results.
SmallVector<SILValue, 2> originalResults;
collectAllFormalResultsInTypeOrder(*original, originalResults);

// Iterate over differentiability results.
differentialBuilder.setInsertionPoint(differential.getEntryBlock());
auto diffIndResults = differential.getIndirectResults();
#ifndef NDEBUG
unsigned numNonWrtInoutParameters = llvm::count_if(
range(original->getLoweredFunctionType()->getNumParameters()),
[&] (unsigned i) {
auto &paramInfo = original->getLoweredFunctionType()->getParameters()[i];
return paramInfo.isIndirectInOut() && !getIndices().parameters->contains(i);
});
#endif
assert(origIndResults.size() + numNonWrtInoutParameters == diffIndResults.size());
for (auto &origBB : *original)
for (auto i : indices(origIndResults))
setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]);
unsigned differentialIndirectResultIndex = 0;
for (auto resultIndex : getIndices().results->getIndices()) {
auto origResult = originalResults[resultIndex];
// Handle original formal indirect result.
if (resultIndex < origFnTy->getNumResults()) {
// Skip original direct results.
if (origResult->getType().isObject())
continue;
auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
setTangentBuffer(origEntry, origResult, diffIndResult);
// If original indirect result is non-varied, zero-initialize its tangent
// buffer.
if (!activityInfo.isVaried(origResult, getIndices().parameters))
emitZeroIndirect(diffIndResult->getType().getASTType(),
diffIndResult, diffLoc);
continue;
}
// Handle original non-wrt `inout` parameter.
// Only original *non-wrt* `inout` parameters have corresponding
// differential indirect results.
auto inoutParamIndex = resultIndex - origFnTy->getNumResults();
auto inoutParamIt = std::next(
origFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
auto paramIndex =
std::distance(origFnTy->getParameters().begin(), &*inoutParamIt);
if (getIndices().parameters->contains(paramIndex))
continue;
auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
setTangentBuffer(origEntry, origResult, diffIndResult);
// Original `inout` parameters are initialized, so their tangent buffers
// must also be initialized.
emitZeroIndirect(diffIndResult->getType().getASTType(),
diffIndResult, diffLoc);
}
}

/*static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential(
Expand Down Expand Up @@ -1516,7 +1605,6 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
auto origParams = origTy->getParameters();
auto indices = witness->getSILAutoDiffIndices();


for (auto resultIndex : indices.results->getIndices()) {
if (resultIndex < origTy->getNumResults()) {
// Handle formal original result.
Expand All @@ -1529,17 +1617,16 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
->getType()
->getCanonicalType(witnessCanGenSig),
origResult.getConvention()));
}
else {
} else {
// Handle original `inout` parameter.
auto inoutParamIndex = resultIndex - origTy->getNumResults();
auto inoutParamIt = std::next(
origTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
auto paramIndex =
std::distance(origTy->getParameters().begin(), &*inoutParamIt);
// If the original `inout` parameter is a differentiability parameter, then
// it already has a corresponding differential parameter. Skip adding a
// corresponding differential result.
// If the original `inout` parameter is a differentiability parameter,
// then it already has a corresponding differential parameter. Do not add
// a corresponding differential result.
if (indices.parameters->contains(paramIndex))
continue;
auto inoutParam = origTy->getParameters()[paramIndex];
Expand Down
3 changes: 2 additions & 1 deletion lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
/// 3. The instruction has both an active result (direct or indirect) and an
/// active argument.
bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
// Function applications with an inout argument should be differentiated.
// Function applications with an active inout argument should be
// differentiated.
for (auto inoutArg : applySite.getInoutArguments())
if (activityInfo.isActive(inoutArg, indices))
return true;
Expand Down
Loading