Skip to content

Commit 2ee26da

Browse files
committed
Change getAdjointBuffer to return SILValue instead of SILValue &.
Fixes a use-after-free crash.
1 parent 596787b commit 2ee26da

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ class PullbackCloner::Implementation final
562562
///
563563
/// This method first tries to find an existing entry in the adjoint buffer
564564
/// mapping. If no entry exists, creates a zero adjoint buffer.
565-
SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) {
565+
SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) {
566566
assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
567567
assert(originalValue->getFunction() == &getOriginal());
568568
auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue());
@@ -1470,7 +1470,7 @@ class PullbackCloner::Implementation final
14701470
/// Adjoint: adj[x] += load adj[y]; adj[y] = 0
14711471
void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc,
14721472
SILValue origDest) {
1473-
auto &adjBuf = getAdjointBuffer(bb, origDest);
1473+
auto adjBuf = getAdjointBuffer(bb, origDest);
14741474
switch (getTangentValueCategory(origSrc)) {
14751475
case SILValueCategory::Object: {
14761476
auto adjVal = builder.emitLoadValueOperation(
@@ -1502,7 +1502,7 @@ class PullbackCloner::Implementation final
15021502
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
15031503
void visitCopyAddrInst(CopyAddrInst *cai) {
15041504
auto *bb = cai->getParent();
1505-
auto &adjDest = getAdjointBuffer(bb, cai->getDest());
1505+
auto adjDest = getAdjointBuffer(bb, cai->getDest());
15061506
auto destType = remapType(adjDest->getType());
15071507
addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc());
15081508
builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest);
@@ -1521,7 +1521,7 @@ class PullbackCloner::Implementation final
15211521
break;
15221522
}
15231523
case SILValueCategory::Address: {
1524-
auto &adjDest = getAdjointBuffer(bb, cvi);
1524+
auto adjDest = getAdjointBuffer(bb, cvi);
15251525
auto destType = remapType(adjDest->getType());
15261526
addToAdjointBuffer(bb, cvi->getOperand(), adjDest, cvi->getLoc());
15271527
builder.emitDestroyAddrAndFold(cvi->getLoc(), adjDest);
@@ -1543,7 +1543,7 @@ class PullbackCloner::Implementation final
15431543
break;
15441544
}
15451545
case SILValueCategory::Address: {
1546-
auto &adjDest = getAdjointBuffer(bb, bbi);
1546+
auto adjDest = getAdjointBuffer(bb, bbi);
15471547
auto destType = remapType(adjDest->getType());
15481548
addToAdjointBuffer(bb, bbi->getOperand(), adjDest, bbi->getLoc());
15491549
builder.emitDestroyAddrAndFold(bbi->getLoc(), adjDest);
@@ -1582,8 +1582,8 @@ class PullbackCloner::Implementation final
15821582
void visitUnconditionalCheckedCastAddrInst(
15831583
UnconditionalCheckedCastAddrInst *uccai) {
15841584
auto *bb = uccai->getParent();
1585-
auto &adjDest = getAdjointBuffer(bb, uccai->getDest());
1586-
auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc());
1585+
auto adjDest = getAdjointBuffer(bb, uccai->getDest());
1586+
auto adjSrc = getAdjointBuffer(bb, uccai->getSrc());
15871587
auto destType = remapType(adjDest->getType());
15881588
auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType());
15891589
builder.createUnconditionalCheckedCastAddr(
@@ -1612,7 +1612,7 @@ class PullbackCloner::Implementation final
16121612
break;
16131613
}
16141614
case SILValueCategory::Address: {
1615-
auto &adjDest = getAdjointBuffer(bb, urci);
1615+
auto adjDest = getAdjointBuffer(bb, urci);
16161616
auto destType = remapType(adjDest->getType());
16171617
addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc());
16181618
builder.emitDestroyAddrAndFold(urci->getLoc(), adjDest);
@@ -1639,7 +1639,7 @@ class PullbackCloner::Implementation final
16391639
break;
16401640
}
16411641
case SILValueCategory::Address: {
1642-
auto &adjDest = getAdjointBuffer(bb, ui);
1642+
auto adjDest = getAdjointBuffer(bb, ui);
16431643
auto destType = remapType(adjDest->getType());
16441644
addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc());
16451645
builder.emitDestroyAddrAndFold(ui->getLoc(), adjDest);

0 commit comments

Comments
 (0)