Skip to content

Commit ce56277

Browse files
dan-zheng and efremale authored
[AutoDiff] Fix forward-mode crashes related to tangent buffers (#33868)
Fixes forward-mode crashes related to: - Missing tangent buffers for non-wrt `inout` parameters. - Tangent buffers not being initialized due to the corresponding original buffer initialization instructions being non-active. - Non-varied indirect results not being initialized. - `emitDestroyValue` crashes due to `TangentVector` value category mismatch. Resolves TF-984 and SR-13447. Co-authored-by: Alex Efremov <[email protected]>
1 parent 39c6ee1 commit ce56277

File tree

6 files changed

+228
-64
lines changed

6 files changed

+228
-64
lines changed

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 141 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -134,19 +134,6 @@ class JVPCloner::Implementation final
134134
// General utilities
135135
//--------------------------------------------------------------------------//
136136

137-
SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint() {
138-
// If there are no local allocations, insert at the beginning of the tangent
139-
// entry.
140-
if (differentialLocalAllocations.empty())
141-
return getDifferential().getEntryBlock()->begin();
142-
// Otherwise, insert before the last local allocation. Inserting before
143-
// rather than after ensures that allocation and zero initialization
144-
// instructions are grouped together.
145-
auto lastLocalAlloc = differentialLocalAllocations.back();
146-
auto it = lastLocalAlloc->getDefiningInstruction()->getIterator();
147-
return it;
148-
}
149-
150137
/// Get the lowered SIL type of the given AST type.
151138
SILType getLoweredType(Type type) {
152139
auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature();
@@ -309,6 +296,8 @@ class JVPCloner::Implementation final
309296
// Tangent buffer mapping
310297
//--------------------------------------------------------------------------//
311298

299+
/// Sets the tangent buffer for the original buffer. Asserts that the
300+
/// original buffer does not already have a tangent buffer.
312301
void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
313302
SILValue tangentBuffer) {
314303
assert(originalBuffer->getType().isAddress());
@@ -318,13 +307,14 @@ class JVPCloner::Implementation final
318307
(void)insertion;
319308
}
320309

310+
/// Returns the tangent buffer for the original buffer. Asserts that the
311+
/// original buffer has a tangent buffer.
321312
SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
322313
assert(originalBuffer->getType().isAddress());
323314
assert(originalBuffer->getFunction() == original);
324-
auto insertion =
325-
bufferMap.try_emplace({origBB, originalBuffer}, SILValue());
326-
assert(!insertion.second && "Tangent buffer should already exist");
327-
return insertion.first->getSecond();
315+
auto it = bufferMap.find({origBB, originalBuffer});
316+
assert(it != bufferMap.end() && "Tangent buffer should already exist");
317+
return it->getSecond();
328318
}
329319

330320
//--------------------------------------------------------------------------//
@@ -446,9 +436,21 @@ class JVPCloner::Implementation final
446436
// If an `apply` has active results or active inout parameters, replace it
447437
// with an `apply` of its JVP.
448438
void visitApplyInst(ApplyInst *ai) {
439+
bool shouldDifferentiate =
440+
differentialInfo.shouldDifferentiateApplySite(ai);
441+
// If the function has no active arguments or results, zero-initialize the
442+
// tangent buffers of the active indirect results.
443+
if (!shouldDifferentiate) {
444+
for (auto indResult : ai->getIndirectSILResults())
445+
if (activityInfo.isActive(indResult, getIndices())) {
446+
auto &tanBuf = getTangentBuffer(ai->getParent(), indResult);
447+
emitZeroIndirect(tanBuf->getType().getASTType(), tanBuf,
448+
tanBuf.getLoc());
449+
}
450+
}
449451
// If the function should not be differentiated or its the array literal
450452
// initialization intrinsic, just do standard cloning.
451-
if (!differentialInfo.shouldDifferentiateApplySite(ai) ||
453+
if (!shouldDifferentiate ||
452454
ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
453455
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
454456
TypeSubstCloner::visitApplyInst(ai);
@@ -789,7 +791,7 @@ class JVPCloner::Implementation final
789791
auto &diffBuilder = getDifferentialBuilder();
790792
auto loc = dvi->getLoc();
791793
auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc);
792-
diffBuilder.emitDestroyValue(loc, tanVal);
794+
diffBuilder.emitDestroyValueOperation(loc, tanVal);
793795
}
794796

795797
CLONE_AND_EMIT_TANGENT(CopyValue, cvi) {
@@ -804,7 +806,20 @@ class JVPCloner::Implementation final
804806
/// Handle `load` instruction.
805807
/// Original: y = load x
806808
/// Tangent: tan[y] = load tan[x]
807-
CLONE_AND_EMIT_TANGENT(Load, li) {
809+
void visitLoadInst(LoadInst *li) {
810+
TypeSubstCloner::visitLoadInst(li);
811+
// If an active buffer is loaded with take to a non-active value, destroy
812+
// the active buffer's tangent buffer.
813+
if (!differentialInfo.shouldDifferentiateInstruction(li)) {
814+
auto isTake =
815+
(li->getOwnershipQualifier() == LoadOwnershipQualifier::Take);
816+
if (isTake && activityInfo.isActive(li->getOperand(), getIndices())) {
817+
auto &tanBuf = getTangentBuffer(li->getParent(), li->getOperand());
818+
getDifferentialBuilder().emitDestroyOperation(tanBuf.getLoc(), tanBuf);
819+
}
820+
return;
821+
}
822+
// Otherwise, do standard differential cloning.
808823
auto &diffBuilder = getDifferentialBuilder();
809824
auto *bb = li->getParent();
810825
auto loc = li->getLoc();
@@ -829,7 +844,19 @@ class JVPCloner::Implementation final
829844
/// Handle `store` instruction in the differential.
830845
/// Original: store x to y
831846
/// Tangent: store tan[x] to tan[y]
832-
CLONE_AND_EMIT_TANGENT(Store, si) {
847+
void visitStoreInst(StoreInst *si) {
848+
TypeSubstCloner::visitStoreInst(si);
849+
// If a non-active value is stored into an active buffer, zero-initialize
850+
// the active buffer's tangent buffer.
851+
if (!differentialInfo.shouldDifferentiateInstruction(si)) {
852+
if (activityInfo.isActive(si->getDest(), getIndices())) {
853+
auto &tanBufDest = getTangentBuffer(si->getParent(), si->getDest());
854+
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
855+
tanBufDest.getLoc());
856+
}
857+
return;
858+
}
859+
// Otherwise, do standard differential cloning.
833860
auto &diffBuilder = getDifferentialBuilder();
834861
auto loc = si->getLoc();
835862
auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc);
@@ -841,7 +868,19 @@ class JVPCloner::Implementation final
841868
/// Handle `store_borrow` instruction in the differential.
842869
/// Original: store_borrow x to y
843870
/// Tangent: store_borrow tan[x] to tan[y]
844-
CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) {
871+
void visitStoreBorrowInst(StoreBorrowInst *sbi) {
872+
TypeSubstCloner::visitStoreBorrowInst(sbi);
873+
// If a non-active value is stored into an active buffer, zero-initialize
874+
// the active buffer's tangent buffer.
875+
if (!differentialInfo.shouldDifferentiateInstruction(sbi)) {
876+
if (activityInfo.isActive(sbi->getDest(), getIndices())) {
877+
auto &tanBufDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
878+
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
879+
tanBufDest.getLoc());
880+
}
881+
return;
882+
}
883+
// Otherwise, do standard differential cloning.
845884
auto &diffBuilder = getDifferentialBuilder();
846885
auto loc = sbi->getLoc();
847886
auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc);
@@ -852,13 +891,32 @@ class JVPCloner::Implementation final
852891
/// Handle `copy_addr` instruction.
853892
/// Original: copy_addr x to y
854893
/// Tangent: copy_addr tan[x] to tan[y]
855-
CLONE_AND_EMIT_TANGENT(CopyAddr, cai) {
894+
void visitCopyAddrInst(CopyAddrInst *cai) {
895+
TypeSubstCloner::visitCopyAddrInst(cai);
896+
// If a non-active buffer is copied into an active buffer, zero-initialize
897+
// the destination buffer's tangent buffer.
898+
// If an active buffer is copied with take into a non-active buffer, destroy
899+
// the source buffer's tangent buffer.
900+
if (!differentialInfo.shouldDifferentiateInstruction(cai)) {
901+
if (activityInfo.isActive(cai->getDest(), getIndices())) {
902+
auto &tanBufDest = getTangentBuffer(cai->getParent(), cai->getDest());
903+
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
904+
tanBufDest.getLoc());
905+
}
906+
if (cai->isTakeOfSrc() &&
907+
activityInfo.isActive(cai->getSrc(), getIndices())) {
908+
auto &tanBufSrc = getTangentBuffer(cai->getParent(), cai->getSrc());
909+
getDifferentialBuilder().emitDestroyOperation(tanBufSrc.getLoc(),
910+
tanBufSrc);
911+
}
912+
return;
913+
}
914+
// Otherwise, do standard differential cloning.
856915
auto diffBuilder = getDifferentialBuilder();
857916
auto loc = cai->getLoc();
858917
auto *bb = cai->getParent();
859918
auto &tanSrc = getTangentBuffer(bb, cai->getSrc());
860919
auto tanDest = getTangentBuffer(bb, cai->getDest());
861-
862920
diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(),
863921
cai->isInitializationOfDest());
864922
}
@@ -918,8 +976,8 @@ class JVPCloner::Implementation final
918976
auto &diffBuilder = getDifferentialBuilder();
919977
auto *bb = eai->getParent();
920978
auto loc = eai->getLoc();
921-
auto tanSrc = getTangentBuffer(bb, eai->getOperand());
922-
diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting());
979+
auto tanOperand = getTangentBuffer(bb, eai->getOperand());
980+
diffBuilder.createEndAccess(loc, tanOperand, eai->isAborting());
923981
}
924982

925983
/// Handle `alloc_stack` instruction.
@@ -930,7 +988,7 @@ class JVPCloner::Implementation final
930988
auto *mappedAllocStackInst = diffBuilder.createAllocStack(
931989
asi->getLoc(), getRemappedTangentType(asi->getElementType()),
932990
asi->getVarInfo());
933-
bufferMap.try_emplace({asi->getParent(), asi}, mappedAllocStackInst);
991+
setTangentBuffer(asi->getParent(), asi, mappedAllocStackInst);
934992
}
935993

936994
/// Handle `dealloc_stack` instruction.
@@ -1062,16 +1120,15 @@ class JVPCloner::Implementation final
10621120
auto tanType = getRemappedTangentType(tei->getType());
10631121
auto tanSource =
10641122
materializeTangent(getTangentValue(tei->getOperand()), loc);
1065-
SILValue tanBuf;
1066-
// If the tangent buffer of the source does not have a tuple type, then
1123+
// If the tangent value of the source does not have a tuple type, then
10671124
// it must represent a "single element tuple type". Use it directly.
10681125
if (!tanSource->getType().is<TupleType>()) {
10691126
setTangentValue(tei->getParent(), tei,
10701127
makeConcreteTangentValue(tanSource));
10711128
} else {
1072-
tanBuf =
1129+
auto tanElt =
10731130
diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType);
1074-
bufferMap.try_emplace({tei->getParent(), tei}, tanBuf);
1131+
setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanElt));
10751132
}
10761133
}
10771134

@@ -1100,7 +1157,7 @@ class JVPCloner::Implementation final
11001157
tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource,
11011158
tanIndex, tanType);
11021159
}
1103-
bufferMap.try_emplace({teai->getParent(), teai}, tanBuf);
1160+
setTangentBuffer(teai->getParent(), teai, tanBuf);
11041161
}
11051162

11061163
/// Handle `destructure_tuple` instruction.
@@ -1282,9 +1339,8 @@ class JVPCloner::Implementation final
12821339
// Collect original results.
12831340
SmallVector<SILValue, 2> originalResults;
12841341
collectAllDirectResultsInTypeOrder(*original, originalResults);
1285-
// Collect differential return elements.
1342+
// Collect differential direct results.
12861343
SmallVector<SILValue, 8> retElts;
1287-
// for (auto origResult : originalResults) {
12881344
for (auto i : range(originalResults.size())) {
12891345
auto origResult = originalResults[i];
12901346
if (!getIndices().results->contains(i))
@@ -1401,7 +1457,10 @@ JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
14011457
void JVPCloner::Implementation::prepareForDifferentialGeneration() {
14021458
// Create differential blocks and arguments.
14031459
auto &differential = getDifferential();
1460+
auto diffLoc = differential.getLocation();
14041461
auto *origEntry = original->getEntryBlock();
1462+
auto origFnTy = original->getLoweredFunctionType();
1463+
14051464
for (auto &origBB : *original) {
14061465
auto *diffBB = differential.createBasicBlock();
14071466
diffBBMap.insert({&origBB, diffBB});
@@ -1482,21 +1541,51 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
14821541
<< " as the tangent of original result " << *origArg);
14831542
}
14841543

1485-
// Initialize tangent mapping for indirect results.
1486-
auto origIndResults = original->getIndirectResults();
1544+
// Initialize tangent mapping for original indirect results and non-wrt
1545+
// `inout` parameters. The tangent buffers of these address values are
1546+
// differential indirect results.
1547+
1548+
// Collect original results.
1549+
SmallVector<SILValue, 2> originalResults;
1550+
collectAllFormalResultsInTypeOrder(*original, originalResults);
1551+
1552+
// Iterate over differentiability results.
1553+
differentialBuilder.setInsertionPoint(differential.getEntryBlock());
14871554
auto diffIndResults = differential.getIndirectResults();
1488-
#ifndef NDEBUG
1489-
unsigned numNonWrtInoutParameters = llvm::count_if(
1490-
range(original->getLoweredFunctionType()->getNumParameters()),
1491-
[&] (unsigned i) {
1492-
auto &paramInfo = original->getLoweredFunctionType()->getParameters()[i];
1493-
return paramInfo.isIndirectInOut() && !getIndices().parameters->contains(i);
1494-
});
1495-
#endif
1496-
assert(origIndResults.size() + numNonWrtInoutParameters == diffIndResults.size());
1497-
for (auto &origBB : *original)
1498-
for (auto i : indices(origIndResults))
1499-
setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]);
1555+
unsigned differentialIndirectResultIndex = 0;
1556+
for (auto resultIndex : getIndices().results->getIndices()) {
1557+
auto origResult = originalResults[resultIndex];
1558+
// Handle original formal indirect result.
1559+
if (resultIndex < origFnTy->getNumResults()) {
1560+
// Skip original direct results.
1561+
if (origResult->getType().isObject())
1562+
continue;
1563+
auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1564+
setTangentBuffer(origEntry, origResult, diffIndResult);
1565+
// If original indirect result is non-varied, zero-initialize its tangent
1566+
// buffer.
1567+
if (!activityInfo.isVaried(origResult, getIndices().parameters))
1568+
emitZeroIndirect(diffIndResult->getType().getASTType(),
1569+
diffIndResult, diffLoc);
1570+
continue;
1571+
}
1572+
// Handle original non-wrt `inout` parameter.
1573+
// Only original *non-wrt* `inout` parameters have corresponding
1574+
// differential indirect results.
1575+
auto inoutParamIndex = resultIndex - origFnTy->getNumResults();
1576+
auto inoutParamIt = std::next(
1577+
origFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
1578+
auto paramIndex =
1579+
std::distance(origFnTy->getParameters().begin(), &*inoutParamIt);
1580+
if (getIndices().parameters->contains(paramIndex))
1581+
continue;
1582+
auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1583+
setTangentBuffer(origEntry, origResult, diffIndResult);
1584+
// Original `inout` parameters are initialized, so their tangent buffers
1585+
// must also be initialized.
1586+
emitZeroIndirect(diffIndResult->getType().getASTType(),
1587+
diffIndResult, diffLoc);
1588+
}
15001589
}
15011590

15021591
/*static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential(
@@ -1526,7 +1615,6 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
15261615
auto origParams = origTy->getParameters();
15271616
auto indices = witness->getSILAutoDiffIndices();
15281617

1529-
15301618
for (auto resultIndex : indices.results->getIndices()) {
15311619
if (resultIndex < origTy->getNumResults()) {
15321620
// Handle formal original result.
@@ -1539,17 +1627,16 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
15391627
->getType()
15401628
->getCanonicalType(witnessCanGenSig),
15411629
origResult.getConvention()));
1542-
}
1543-
else {
1630+
} else {
15441631
// Handle original `inout` parameter.
15451632
auto inoutParamIndex = resultIndex - origTy->getNumResults();
15461633
auto inoutParamIt = std::next(
15471634
origTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
15481635
auto paramIndex =
15491636
std::distance(origTy->getParameters().begin(), &*inoutParamIt);
1550-
// If the original `inout` parameter is a differentiability parameter, then
1551-
// it already has a corresponding differential parameter. Skip adding a
1552-
// corresponding differential result.
1637+
// If the original `inout` parameter is a differentiability parameter,
1638+
// then it already has a corresponding differential parameter. Do not add
1639+
// a corresponding differential result.
15531640
if (indices.parameters->contains(paramIndex))
15541641
continue;
15551642
auto inoutParam = origTy->getParameters()[paramIndex];

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -460,7 +460,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
460460
/// 3. The instruction has both an active result (direct or indirect) and an
461461
/// active argument.
462462
bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
463-
// Function applications with an inout argument should be differentiated.
463+
// Function applications with an active inout argument should be
464+
// differentiated.
464465
for (auto inoutArg : applySite.getInoutArguments())
465466
if (activityInfo.isActive(inoutArg, indices))
466467
return true;

0 commit comments

Comments
 (0)