Skip to content

Commit 1cf3b64

Browse files
Merge pull request #67258 from nate-chandler/pullback_cloner/20230712/1
[PullbackCloner] Handled move_value instructions.
2 parents aa026f4 + 19f8260 commit 1cf3b64

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,46 +1535,45 @@ class PullbackCloner::Implementation final
15351535
builder.emitZeroIntoBuffer(cai->getLoc(), adjDest, IsNotInitialization);
15361536
}
15371537

1538-
/// Handle `copy_value` instruction.
1538+
/// Handle any ownership instruction that deals with values: copy_value,
1539+
/// move_value, begin_borrow.
15391540
/// Original: y = copy_value x
15401541
/// Adjoint: adj[x] += adj[y]
1541-
void visitCopyValueInst(CopyValueInst *cvi) {
1542-
auto *bb = cvi->getParent();
1543-
switch (getTangentValueCategory(cvi)) {
1542+
void visitValueOwnershipInst(SingleValueInstruction *svi) {
1543+
assert(svi->getNumOperands() == 1);
1544+
auto *bb = svi->getParent();
1545+
switch (getTangentValueCategory(svi)) {
15441546
case SILValueCategory::Object: {
1545-
auto adj = getAdjointValue(bb, cvi);
1546-
addAdjointValue(bb, cvi->getOperand(), adj, cvi->getLoc());
1547+
auto adj = getAdjointValue(bb, svi);
1548+
addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc());
15471549
break;
15481550
}
15491551
case SILValueCategory::Address: {
1550-
auto adjDest = getAdjointBuffer(bb, cvi);
1551-
addToAdjointBuffer(bb, cvi->getOperand(), adjDest, cvi->getLoc());
1552-
builder.emitZeroIntoBuffer(cvi->getLoc(), adjDest, IsNotInitialization);
1552+
auto adjDest = getAdjointBuffer(bb, svi);
1553+
addToAdjointBuffer(bb, svi->getOperand(0), adjDest, svi->getLoc());
1554+
builder.emitZeroIntoBuffer(svi->getLoc(), adjDest, IsNotInitialization);
15531555
break;
15541556
}
15551557
}
15561558
}
15571559

1560+
/// Handle `copy_value` instruction.
1561+
/// Original: y = copy_value x
1562+
/// Adjoint: adj[x] += adj[y]
1563+
void visitCopyValueInst(CopyValueInst *cvi) { visitValueOwnershipInst(cvi); }
1564+
15581565
/// Handle `begin_borrow` instruction.
15591566
/// Original: y = begin_borrow x
15601567
/// Adjoint: adj[x] += adj[y]
15611568
void visitBeginBorrowInst(BeginBorrowInst *bbi) {
1562-
auto *bb = bbi->getParent();
1563-
switch (getTangentValueCategory(bbi)) {
1564-
case SILValueCategory::Object: {
1565-
auto adj = getAdjointValue(bb, bbi);
1566-
addAdjointValue(bb, bbi->getOperand(), adj, bbi->getLoc());
1567-
break;
1568-
}
1569-
case SILValueCategory::Address: {
1570-
auto adjDest = getAdjointBuffer(bb, bbi);
1571-
addToAdjointBuffer(bb, bbi->getOperand(), adjDest, bbi->getLoc());
1572-
builder.emitZeroIntoBuffer(bbi->getLoc(), adjDest, IsNotInitialization);
1573-
break;
1574-
}
1575-
}
1569+
visitValueOwnershipInst(bbi);
15761570
}
15771571

1572+
/// Handle `move_value` instruction.
1573+
/// Original: y = move_value x
1574+
/// Adjoint: adj[x] += adj[y]
1575+
void visitMoveValueInst(MoveValueInst *mvi) { visitValueOwnershipInst(mvi); }
1576+
15781577
/// Handle `begin_access` instruction.
15791578
/// Original: y = begin_access x
15801579
/// Adjoint: nothing

0 commit comments

Comments
 (0)