Skip to content

Commit d31f3ac

Browse files
committed
[AutoDiff][stdlib] Add JVPs to ArrayDifferentiation.swift
`_jvpDifferentiableReduce` is taken from swiftlang#29324
1 parent 45201e5 commit d31f3ac

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ where Element: Differentiable {
4242
return (base, { $0 })
4343
}
4444

45+
@usableFromInline
46+
@derivative(of: base)
47+
func _jvpBase() -> (
48+
value: [Element], differential: (Array<Element>.TangentVector) -> TangentVector
49+
) {
50+
return (base, { $0 })
51+
}
52+
4553
/// Creates a differentiable view of the given array.
4654
public init(_ base: [Element]) { self._base = base }
4755

@@ -53,6 +61,14 @@ where Element: Differentiable {
5361
return (Array.DifferentiableView(base), { $0 })
5462
}
5563

64+
@usableFromInline
65+
@derivative(of: init(_:))
66+
static func _jvpInit(_ base: [Element]) -> (
67+
value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector
68+
) {
69+
return (Array.DifferentiableView(base), { $0 })
70+
}
71+
5672
public typealias TangentVector =
5773
Array<Element.TangentVector>.DifferentiableView
5874

@@ -191,6 +207,17 @@ extension Array where Element: Differentiable {
191207
return (self[index], pullback)
192208
}
193209

210+
@usableFromInline
211+
@derivative(of: subscript)
212+
func _jvpSubscript(index: Int) -> (
213+
value: Element, differential: (TangentVector) -> Element.TangentVector
214+
) {
215+
func differential(_ v: TangentVector) -> Element.TangentVector {
216+
return v[index]
217+
}
218+
return (self[index], differential)
219+
}
220+
194221
@usableFromInline
195222
@derivative(of: +)
196223
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
@@ -210,8 +237,26 @@ extension Array where Element: Differentiable {
210237
}
211238
return (lhs + rhs, pullback)
212239
}
240+
241+
@usableFromInline
242+
@derivative(of: +)
243+
static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> (
244+
value: Self,
245+
differential: (TangentVector, TangentVector) -> TangentVector
246+
) {
247+
func differential(_ l: TangentVector, _ r: TangentVector) -> TangentVector {
248+
precondition(
249+
l.base.count == lhs.count && r.base.count == rhs.count, """
250+
Tangent vectors with invalid count; expected to equal the \
251+
operand counts \(lhs.count) and \(rhs.count)
252+
""")
253+
return .init(l.base + r.base)
254+
}
255+
return (lhs + rhs, differential)
256+
}
213257
}
214258

259+
215260
extension Array where Element: Differentiable {
216261
@usableFromInline
217262
@derivative(of: append)
@@ -277,6 +322,17 @@ extension Array where Element: Differentiable {
277322
}
278323
)
279324
}
325+
326+
@usableFromInline
327+
@derivative(of: init(repeating:count:))
328+
static func _jvpInit(repeating repeatedValue: Element, count: Int) -> (
329+
value: Self, differential: (Element.TangentVector) -> TangentVector
330+
) {
331+
(
332+
value: Self(repeating: repeatedValue, count: count),
333+
differential: { v in TangentVector(.init(repeating: v, count: count)) }
334+
)
335+
}
280336
}
281337

282338
//===----------------------------------------------------------------------===//
@@ -382,4 +438,33 @@ extension Array where Element: Differentiable {
382438
}
383439
)
384440
}
441+
442+
@inlinable
443+
@derivative(of: differentiableReduce, wrt: (self, initialResult))
444+
func _jvpDifferentiableReduce<Result: Differentiable>(
445+
_ initialResult: Result,
446+
_ nextPartialResult: @differentiable (Result, Element) -> Result
447+
) -> (value: Result,
448+
differential: (Array.TangentVector, Result.TangentVector)
449+
-> Result.TangentVector) {
450+
var differentials:
451+
[(Result.TangentVector, Element.TangentVector) -> Result.TangentVector]
452+
= []
453+
let count = self.count
454+
differentials.reserveCapacity(count)
455+
var result = initialResult
456+
for element in self {
457+
let (y, df) =
458+
valueWithDifferential(at: result, element, in: nextPartialResult)
459+
result = y
460+
differentials.append(df)
461+
}
462+
return (value: result, differential: { dSelf, dInitial in
463+
var dResult = dInitial
464+
for (dElement, df) in zip(dSelf.base, differentials) {
465+
dResult = df(dResult, dElement)
466+
}
467+
return dResult
468+
})
469+
}
385470
}

test/AutoDiff/validation-test/forward_mode.swift

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,75 @@ ForwardModeTests.test("ForceUnwrapping") {
13231323
// Array methods from ArrayDifferentiation.swift
13241324
//===----------------------------------------------------------------------===//
13251325

1326+
typealias FloatArrayTan = Array<Float>.TangentVector
1327+
1328+
ForwardModeTests.test("Array.+") {
1329+
func sumFirstThreeConcatenating(_ a: [Float], _ b: [Float]) -> Float {
1330+
let c = a + b
1331+
return c[0] + c[1] + c[2]
1332+
}
1333+
1334+
expectEqual(3, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1]), .init([1, 1])))
1335+
expectEqual(0, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([0, 0]), .init([0, 1])))
1336+
expectEqual(1, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([0, 1]), .init([0, 1])))
1337+
expectEqual(1, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 0]), .init([0, 1])))
1338+
expectEqual(1, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([0, 0]), .init([1, 1])))
1339+
expectEqual(2, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1]), .init([0, 1])))
1340+
1341+
expectEqual(
1342+
3,
1343+
differential(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1, 1, 1]), .init([1, 1])))
1344+
expectEqual(
1345+
3,
1346+
differential(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1, 1, 0]), .init([0, 0])))
1347+
1348+
expectEqual(
1349+
3,
1350+
differential(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating)(.init([]), .init([1, 1, 1, 1])))
1351+
expectEqual(
1352+
0,
1353+
differential(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating)(.init([]), .init([0, 0, 0, 1])))
1354+
}
1355+
1356+
ForwardModeTests.test("Array.init(repeating:count:)") {
1357+
@differentiable
1358+
func repeating(_ x: Float) -> [Float] {
1359+
Array(repeating: x, count: 10)
1360+
}
1361+
expectEqual(Float(10), derivative(at: .zero) { x in
1362+
repeating(x).differentiableReduce(0, {$0 + $1})
1363+
})
1364+
expectEqual(Float(20), differential(at: .zero, in: { x in
1365+
repeating(x).differentiableReduce(0, {$0 + $1})
1366+
})(2))
1367+
}
1368+
1369+
ForwardModeTests.test("Array.DifferentiableView.init") {
1370+
@differentiable
1371+
func constructView(_ x: [Float]) -> Array<Float>.DifferentiableView {
1372+
return Array<Float>.DifferentiableView(x)
1373+
}
1374+
1375+
let forward = differential(at: [5, 6, 7, 8], in: constructView)
1376+
expectEqual(
1377+
FloatArrayTan([1, 2, 3, 4]),
1378+
forward(FloatArrayTan([1, 2, 3, 4])))
1379+
}
1380+
1381+
ForwardModeTests.test("Array.DifferentiableView.base") {
1382+
@differentiable
1383+
func accessBase(_ x: Array<Float>.DifferentiableView) -> [Float] {
1384+
return x.base
1385+
}
1386+
1387+
let forward = differential(
1388+
at: Array<Float>.DifferentiableView([5, 6, 7, 8]),
1389+
in: accessBase)
1390+
expectEqual(
1391+
FloatArrayTan([1, 2, 3, 4]),
1392+
forward(FloatArrayTan([1, 2, 3, 4])))
1393+
}
1394+
13261395
ForwardModeTests.test("Array.differentiableMap") {
13271396
let x: [Float] = [1, 2, 3]
13281397
let tan = Array<Float>.TangentVector([1, 1, 1])
@@ -1338,4 +1407,24 @@ ForwardModeTests.test("Array.differentiableMap") {
13381407
expectEqual([2, 4, 6], differential(at: x, in: squareMap)(tan))
13391408
}
13401409

1410+
ForwardModeTests.test("Array.differentiableReduce") {
1411+
let x: [Float] = [1, 2, 3]
1412+
let tan = Array<Float>.TangentVector([1, 1, 1])
1413+
1414+
func sumReduce(_ a: [Float]) -> Float {
1415+
return a.differentiableReduce(0, { $0 + $1 })
1416+
}
1417+
expectEqual(1 + 1 + 1, differential(at: x, in: sumReduce)(tan))
1418+
1419+
func productReduce(_ a: [Float]) -> Float {
1420+
return a.differentiableReduce(1, { $0 * $1 })
1421+
}
1422+
expectEqual(x[1] * x[2] + x[0] * x[2] + x[0] * x[1], differential(at: x, in: productReduce)(tan))
1423+
1424+
func sumOfSquaresReduce(_ a: [Float]) -> Float {
1425+
return a.differentiableReduce(0, { $0 + $1 * $1 })
1426+
}
1427+
expectEqual(2 * x[0] + 2 * x[1] + 2 * x[2], differential(at: x, in: sumOfSquaresReduce)(tan))
1428+
}
1429+
13411430
runAllTests()

0 commit comments

Comments
 (0)