Skip to content

Commit 30fccb4

Browse files
authored
Merge pull request #31781 from dan-zheng/array-append-derivative
[AutoDiff] Fix `Array.append(_:)` pullback.
2 parents 8f39567 + e3938b5 commit 30fccb4

File tree

2 files changed

+62
-51
lines changed

2 files changed

+62
-51
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -184,36 +184,31 @@ extension Array where Element: Differentiable {
184184
func _vjpSubscript(index: Int) -> (
185185
value: Element, pullback: (Element.TangentVector) -> TangentVector
186186
) {
187-
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
188-
var gradientOut = [Element.TangentVector](
187+
func pullback(_ v: Element.TangentVector) -> TangentVector {
188+
var dSelf = [Element.TangentVector](
189189
repeating: .zero,
190190
count: count)
191-
gradientOut[index] = gradientIn
192-
return TangentVector(gradientOut)
191+
dSelf[index] = v
192+
return TangentVector(dSelf)
193193
}
194194
return (self[index], pullback)
195195
}
196196

197197
@usableFromInline
198198
@derivative(of: +)
199-
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> (
200-
value: [Element],
199+
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
200+
value: Self,
201201
pullback: (TangentVector) -> (TangentVector, TangentVector)
202202
) {
203-
func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector)
204-
{
203+
func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
205204
precondition(
206-
gradientIn.base.count == lhs.count + rhs.count,
205+
v.base.count == lhs.count + rhs.count,
207206
"+ should receive gradient with count equal to sum of operand "
208-
+ "counts, but counts are: gradient \(gradientIn.base.count), "
207+
+ "counts, but counts are: gradient \(v.base.count), "
209208
+ "lhs \(lhs.count), rhs \(rhs.count)")
210209
return (
211-
TangentVector(
212-
[Element.TangentVector](
213-
gradientIn.base[0..<lhs.count])),
214-
TangentVector(
215-
[Element.TangentVector](
216-
gradientIn.base[lhs.count...]))
210+
TangentVector([Element.TangentVector](v.base[0..<lhs.count])),
211+
TangentVector([Element.TangentVector](v.base[lhs.count...]))
217212
)
218213
}
219214
return (lhs + rhs, pullback)
@@ -227,8 +222,11 @@ extension Array where Element: Differentiable {
227222
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
228223
) {
229224
let appendedElementIndex = count
230-
defer { append(element) }
231-
return ((), { dself in dself.base[appendedElementIndex] })
225+
append(element)
226+
return ((), { v in
227+
defer { v.base.removeLast() }
228+
return v.base[appendedElementIndex]
229+
})
232230
}
233231

234232
@usableFromInline
@@ -288,6 +286,37 @@ extension Array where Element: Differentiable {
288286
// Differentiable higher order functions for collections
289287
//===----------------------------------------------------------------------===//
290288

289+
extension Array where Element: Differentiable {
290+
@inlinable
291+
@differentiable(wrt: self)
292+
public func differentiableMap<Result: Differentiable>(
293+
_ body: @differentiable (Element) -> Result
294+
) -> [Result] {
295+
map(body)
296+
}
297+
298+
@inlinable
299+
@derivative(of: differentiableMap)
300+
internal func _vjpDifferentiableMap<Result: Differentiable>(
301+
_ body: @differentiable (Element) -> Result
302+
) -> (
303+
value: [Result],
304+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
305+
) {
306+
var values: [Result] = []
307+
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
308+
for x in self {
309+
let (y, pb) = valueWithPullback(at: x, in: body)
310+
values.append(y)
311+
pullbacks.append(pb)
312+
}
313+
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
314+
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
315+
}
316+
return (value: values, pullback: pullback)
317+
}
318+
}
319+
291320
extension Array where Element: Differentiable {
292321
@inlinable
293322
@differentiable(wrt: (self, initialResult))
@@ -336,34 +365,3 @@ extension Array where Element: Differentiable {
336365
)
337366
}
338367
}
339-
340-
extension Array where Element: Differentiable {
341-
@inlinable
342-
@differentiable(wrt: self)
343-
public func differentiableMap<Result: Differentiable>(
344-
_ body: @differentiable (Element) -> Result
345-
) -> [Result] {
346-
map(body)
347-
}
348-
349-
@inlinable
350-
@derivative(of: differentiableMap)
351-
internal func _vjpDifferentiableMap<Result: Differentiable>(
352-
_ body: @differentiable (Element) -> Result
353-
) -> (
354-
value: [Result],
355-
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
356-
) {
357-
var values: [Result] = []
358-
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
359-
for x in self {
360-
let (y, pb) = valueWithPullback(at: x, in: body)
361-
values.append(y)
362-
pullbacks.append(pb)
363-
}
364-
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
365-
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
366-
}
367-
return (value: values, pullback: pullback)
368-
}
369-
}

test/AutoDiff/stdlib/array.swift renamed to test/AutoDiff/validation-test/array.swift

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,15 +372,28 @@ ArrayAutoDiffTests.test("Array.+=") {
372372
}
373373

374374
ArrayAutoDiffTests.test("Array.append") {
375+
func appending(_ array: [Float], _ element: Float) -> [Float] {
376+
var result = array
377+
result.append(element)
378+
return result
379+
}
380+
do {
381+
let v = FloatArrayTan([1, 2, 3, 4])
382+
expectEqual((.init([1, 2, 3]), 4),
383+
pullback(at: [0, 0, 0], 0, in: appending)(v))
384+
}
385+
375386
func identity(_ array: [Float]) -> [Float] {
376387
var results: [Float] = []
377388
for i in withoutDerivative(at: array.indices) {
378389
results.append(array[i])
379390
}
380391
return results
381392
}
382-
let v = FloatArrayTan([4, -5, 6])
383-
expectEqual(v, pullback(at: [1, 2, 3], in: identity)(v))
393+
do {
394+
let v = FloatArrayTan([4, -5, 6])
395+
expectEqual(v, pullback(at: [1, 2, 3], in: identity)(v))
396+
}
384397
}
385398

386399
ArrayAutoDiffTests.test("Array.init(repeating:count:)") {

0 commit comments

Comments
 (0)