Skip to content

Commit 8829a14

Browse files
authored
[AutoDiff] Register derivatives for Array.+=. (#31782)
Register JVP and VJP for `Array.+=`. Simplify and use same array concatenation derivative tests as `Array.+`.
1 parent 7f3381e commit 8829a14

File tree

2 files changed

+62
-24
lines changed

2 files changed

+62
-24
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,33 @@ extension Array where Element: Differentiable {
242242
}
243243
}
244244

245+
extension Array where Element: Differentiable {
246+
@usableFromInline
247+
@derivative(of: +=)
248+
static func _vjpAppend(_ lhs: inout Self, _ rhs: Self) -> (
249+
value: Void, pullback: (inout TangentVector) -> TangentVector
250+
) {
251+
let lhsCount = lhs.count
252+
lhs += rhs
253+
return ((), { v in
254+
let drhs =
255+
TangentVector(.init(v.base.dropFirst(lhsCount)))
256+
let rhsCount = drhs.base.count
257+
v.base.removeLast(rhsCount)
258+
return drhs
259+
})
260+
}
261+
262+
@usableFromInline
263+
@derivative(of: +=)
264+
static func _jvpAppend(_ lhs: inout Self, _ rhs: Self) -> (
265+
value: Void, differential: (inout TangentVector, TangentVector) -> Void
266+
) {
267+
lhs += rhs
268+
return ((), { $0.base += $1.base })
269+
}
270+
}
271+
245272
extension Array where Element: Differentiable {
246273
@usableFromInline
247274
@derivative(of: init(repeating:count:))

test/AutoDiff/stdlib/array.swift

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -317,42 +317,53 @@ ArrayAutoDiffTests.test("ExpressibleByArrayLiteralIndirect") {
317317
}
318318

319319
ArrayAutoDiffTests.test("Array.+") {
320-
struct TwoArrays : Differentiable {
321-
var a: [Float]
322-
var b: [Float]
320+
func sumFirstThreeConcatenating(_ a: [Float], _ b: [Float]) -> Float {
321+
let c = a + b
322+
return c[0] + c[1] + c[2]
323323
}
324324

325-
func sumFirstThreeConcatenated(_ arrs: TwoArrays) -> Float {
326-
let c = arrs.a + arrs.b
325+
expectEqual(
326+
(.init([1, 1]), .init([1, 0])),
327+
gradient(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating))
328+
expectEqual(
329+
(.init([1, 1, 1, 0]), .init([0, 0])),
330+
gradient(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating))
331+
expectEqual(
332+
(.init([]), .init([1, 1, 1, 0])),
333+
gradient(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating))
334+
335+
func identity(_ array: [Float]) -> [Float] {
336+
var results: [Float] = []
337+
for i in withoutDerivative(at: array.indices) {
338+
results = results + [array[i]]
339+
}
340+
return results
341+
}
342+
let v = FloatArrayTan([4, -5, 6])
343+
expectEqual(v, pullback(at: [1, 2, 3], in: identity)(v))
344+
}
345+
346+
ArrayAutoDiffTests.test("Array.+=") {
347+
func sumFirstThreeConcatenating(_ a: [Float], _ b: [Float]) -> Float {
348+
var c = a
349+
c += b
327350
return c[0] + c[1] + c[2]
328351
}
329352

330353
expectEqual(
331-
TwoArrays.TangentVector(
332-
a: FloatArrayTan([1, 1]),
333-
b: FloatArrayTan([1, 0])),
334-
gradient(
335-
at: TwoArrays(a: [0, 0], b: [0, 0]),
336-
in: sumFirstThreeConcatenated))
354+
(.init([1, 1]), .init([1, 0])),
355+
gradient(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating))
337356
expectEqual(
338-
TwoArrays.TangentVector(
339-
a: FloatArrayTan([1, 1, 1, 0]),
340-
b: FloatArrayTan([0, 0])),
341-
gradient(
342-
at: TwoArrays(a: [0, 0, 0, 0], b: [0, 0]),
343-
in: sumFirstThreeConcatenated))
357+
(.init([1, 1, 1, 0]), .init([0, 0])),
358+
gradient(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating))
344359
expectEqual(
345-
TwoArrays.TangentVector(
346-
a: FloatArrayTan([]),
347-
b: FloatArrayTan([1, 1, 1, 0])),
348-
gradient(
349-
at: TwoArrays(a: [], b: [0, 0, 0, 0]),
350-
in: sumFirstThreeConcatenated))
360+
(.init([]), .init([1, 1, 1, 0])),
361+
gradient(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating))
351362

352363
func identity(_ array: [Float]) -> [Float] {
353364
var results: [Float] = []
354365
for i in withoutDerivative(at: array.indices) {
355-
results = results + [array[i]]
366+
results += [array[i]]
356367
}
357368
return results
358369
}

0 commit comments

Comments
 (0)