@@ -549,7 +549,7 @@ class PullbackCloner::Implementation final
549
549
if (auto adjProj = getAdjointProjection (origBB, originalValue))
550
550
return (bufferMap[{origBB, originalValue}] = adjProj);
551
551
552
- LLVM_DEBUG (getADDebugStream () << " Creating new adjoint buffer for"
552
+ LLVM_DEBUG (getADDebugStream () << " Creating new adjoint buffer for "
553
553
<< originalValue
554
554
<< " in bb" << origBB->getDebugID () << ' \n ' );
555
555
@@ -589,7 +589,8 @@ class PullbackCloner::Implementation final
589
589
auto adjointBuffer = getAdjointBuffer (origBB, originalValue);
590
590
591
591
LLVM_DEBUG (getADDebugStream () << " Adding"
592
- << rhsAddress << " to adjoint of "
592
+ << rhsAddress << " to adjoint ("
593
+ << adjointBuffer << " ) of "
593
594
<< originalValue
594
595
<< " in bb" << origBB->getDebugID () << ' \n ' );
595
596
@@ -811,7 +812,8 @@ class PullbackCloner::Implementation final
811
812
#endif
812
813
SILInstructionVisitor::visit (inst);
813
814
LLVM_DEBUG ({
814
- auto &s = llvm::dbgs () << " [ADJ] Emitted in pullback:\n " ;
815
+ auto &s = llvm::dbgs () << " [ADJ] Emitted in pullback (pb bb" <<
816
+ builder.getInsertionBB ()->getDebugID () << " ):\n " ;
815
817
auto afterInsertion = builder.getInsertionPoint ();
816
818
for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
817
819
s << *it;
@@ -1645,7 +1647,7 @@ class PullbackCloner::Implementation final
1645
1647
void
1646
1648
visitUncheckedTakeEnumDataAddrInst (UncheckedTakeEnumDataAddrInst *utedai) {
1647
1649
auto *bb = utedai->getParent ();
1648
- auto adjBuf = getAdjointBuffer (bb, utedai);
1650
+ auto adjDest = getAdjointBuffer (bb, utedai);
1649
1651
auto enumTy = utedai->getOperand ()->getType ();
1650
1652
auto *optionalEnumDecl = getASTContext ().getOptionalDecl ();
1651
1653
// Only `Optional`-typed operands are supported for now. Diagnose all other
@@ -1659,7 +1661,8 @@ class PullbackCloner::Implementation final
1659
1661
errorOccurred = true ;
1660
1662
return ;
1661
1663
}
1662
- accumulateAdjointForOptional (bb, utedai->getOperand (), adjBuf);
1664
+ accumulateAdjointForOptional (bb, utedai->getOperand (), adjDest);
1665
+ builder.emitZeroIntoBuffer (utedai->getLoc (), adjDest, IsNotInitialization);
1663
1666
}
1664
1667
1665
1668
#define NOT_DIFFERENTIABLE (INST, DIAG ) void visit##INST##Inst(INST##Inst *inst);
@@ -2473,6 +2476,10 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2473
2476
for (auto *bbArg : bb->getArguments ()) {
2474
2477
if (!getActivityInfo ().isActive (bbArg, getConfig ()))
2475
2478
continue ;
2479
+ LLVM_DEBUG (getADDebugStream () << " Propagating adjoint value for active bb"
2480
+ << bb->getDebugID () << " argument: "
2481
+ << *bbArg);
2482
+
2476
2483
// Get predecessor terminator operands.
2477
2484
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4 > incomingValues;
2478
2485
bbArg->getSingleTerminatorOperands (incomingValues);
0 commit comments