Skip to content

Commit cfd75a1

Browse files
authored
[AutoDiff] Fix forward-mode differentiation ownership verification failure. (#33898)
Previously, `LinearMapInfo::shouldDifferentiateInstruction` had a special case for `copy_value`, returning true for `copy_value` instructions with an active operand. This is unexpected and led to "leaked owned value" ownership verification failures due to unnecessarily cloned `copy_value` instructions during differential generation. Now, the special case is removed, fixing the failures. `shouldDifferentiateInstruction` returns true for `copy_value` instructions whose operand and result are both active. Resolves SR-13530.
1 parent b2fa269 commit cfd75a1

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,6 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
537537
isa<EndBorrowInst>(inst) || isa<DeallocationInst>(inst) ||
538538
isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst))
539539
return true;
540-
// Should differentiate any instruction that creates an SSA copy of an
541-
// active operand.
542-
if (isa<CopyValueInst>(inst))
543-
return true;
544540
}
545541
return false;
546542
}

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,4 +1389,15 @@ ForwardModeTests.test("ApplyNonActiveIndirectResult") {
13891389
expectEqual(1.0, derivative(at: 2, in: applyNonactiveArgumentActiveIndirectResult))
13901390
}
13911391

1392+
ForwardModeTests.test("SR-13530") {
1393+
// SR-13530: Test "leaked owned value" ownership verification failure related
1394+
// to differential generation for `copy_value` instruction.
1395+
@differentiable
1396+
func SR_13530(_ x: NonresilientTracked<Float>) -> NonresilientTracked<Float> {
1397+
precondition(x >= 0)
1398+
return x
1399+
}
1400+
expectEqual(1, derivative(at: 2, in: SR_13530))
1401+
}
1402+
13921403
runAllTests()

0 commit comments

Comments
 (0)