Skip to content

Commit e3938b5

Browse files
committed
[AutoDiff] Fix Array.append(_:) pullback.
The derivative wrt `self` should drop the last element from the incoming seed. Example: - Incoming seed: [1, 2, 3, 4] - Derivative wrt `self`: [1, 2, 3] - Derivative wrt appended element: 4
1 parent 2e690e2 commit e3938b5

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,11 @@ extension Array where Element: Differentiable {
222222
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
223223
) {
224224
let appendedElementIndex = count
225-
defer { append(element) }
226-
return ((), { dself in dself.base[appendedElementIndex] })
225+
append(element)
226+
return ((), { v in
227+
defer { v.base.removeLast() }
228+
return v.base[appendedElementIndex]
229+
})
227230
}
228231

229232
@usableFromInline

test/AutoDiff/validation-test/array.swift

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,15 +372,28 @@ ArrayAutoDiffTests.test("Array.+=") {
372372
}
373373

374374
ArrayAutoDiffTests.test("Array.append") {
375+
func appending(_ array: [Float], _ element: Float) -> [Float] {
376+
var result = array
377+
result.append(element)
378+
return result
379+
}
380+
do {
381+
let v = FloatArrayTan([1, 2, 3, 4])
382+
expectEqual((.init([1, 2, 3]), 4),
383+
pullback(at: [0, 0, 0], 0, in: appending)(v))
384+
}
385+
375386
func identity(_ array: [Float]) -> [Float] {
376387
var results: [Float] = []
377388
for i in withoutDerivative(at: array.indices) {
378389
results.append(array[i])
379390
}
380391
return results
381392
}
382-
let v = FloatArrayTan([4, -5, 6])
383-
expectEqual(v, pullback(at: [1, 2, 3], in: identity)(v))
393+
do {
394+
let v = FloatArrayTan([4, -5, 6])
395+
expectEqual(v, pullback(at: [1, 2, 3], in: identity)(v))
396+
}
384397
}
385398

386399
ArrayAutoDiffTests.test("Array.init(repeating:count:)") {

0 commit comments

Comments
 (0)