Skip to content

Commit 11dd84e

Browse files
authored
[AutoDiff] Ensure adjoint buffer of unchecked_take_enum_data_addr is zeroed after value is accumulated (#58457)
This is yet another optional-related bug. It could be triggered if `unchecked_take_enum_data_addr` is inside a loop. In such case the accumulated value of adjoint buffer would be re-used in the subsequent loop iterations producing wrong results. Fixes #58353 (SR16094)
1 parent 699e464 commit 11dd84e

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ class PullbackCloner::Implementation final
549549
if (auto adjProj = getAdjointProjection(origBB, originalValue))
550550
return (bufferMap[{origBB, originalValue}] = adjProj);
551551

552-
LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for"
552+
LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for "
553553
<< originalValue
554554
<< "in bb" << origBB->getDebugID() << '\n');
555555

@@ -589,7 +589,8 @@ class PullbackCloner::Implementation final
589589
auto adjointBuffer = getAdjointBuffer(origBB, originalValue);
590590

591591
LLVM_DEBUG(getADDebugStream() << "Adding"
592-
<< rhsAddress << "to adjoint of "
592+
<< rhsAddress << "to adjoint ("
593+
<< adjointBuffer << ") of "
593594
<< originalValue
594595
<< "in bb" << origBB->getDebugID() << '\n');
595596

@@ -811,7 +812,8 @@ class PullbackCloner::Implementation final
811812
#endif
812813
SILInstructionVisitor::visit(inst);
813814
LLVM_DEBUG({
814-
auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback:\n";
815+
auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback (pb bb" <<
816+
builder.getInsertionBB()->getDebugID() << "):\n";
815817
auto afterInsertion = builder.getInsertionPoint();
816818
for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
817819
s << *it;
@@ -1645,7 +1647,7 @@ class PullbackCloner::Implementation final
16451647
void
16461648
visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) {
16471649
auto *bb = utedai->getParent();
1648-
auto adjBuf = getAdjointBuffer(bb, utedai);
1650+
auto adjDest = getAdjointBuffer(bb, utedai);
16491651
auto enumTy = utedai->getOperand()->getType();
16501652
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
16511653
// Only `Optional`-typed operands are supported for now. Diagnose all other
@@ -1659,7 +1661,8 @@ class PullbackCloner::Implementation final
16591661
errorOccurred = true;
16601662
return;
16611663
}
1662-
accumulateAdjointForOptional(bb, utedai->getOperand(), adjBuf);
1664+
accumulateAdjointForOptional(bb, utedai->getOperand(), adjDest);
1665+
builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization);
16631666
}
16641667

16651668
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
@@ -2473,6 +2476,10 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
24732476
for (auto *bbArg : bb->getArguments()) {
24742477
if (!getActivityInfo().isActive(bbArg, getConfig()))
24752478
continue;
2479+
LLVM_DEBUG(getADDebugStream() << "Propagating adjoint value for active bb"
2480+
<< bb->getDebugID() << " argument: "
2481+
<< *bbArg);
2482+
24762483
// Get predecessor terminator operands.
24772484
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
24782485
bbArg->getSingleTerminatorOperands(incomingValues);

0 commit comments

Comments
 (0)