Skip to content

Commit b6d7653

Browse files
committed
[AutoDiff upstream] Test differentiable collection higher-order functions.
Test `Array.differentiableMap(_:)` and `Array.differentiableReduce(_:_:)`.
1 parent d599105 commit b6d7653

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import _Differentiation
5+
import StdlibUnittest
6+
7+
// Test differentiable collection higher order functions:
8+
// `differentiableMap(_:)` and `differentiableReduce(_:_:)`.
9+
10+
var CollectionHOFTests = TestSuite("CollectionHigherOrderFunctions")
11+
12+
let array: [Float] = [1, 2, 3, 4, 5]
13+
14+
CollectionHOFTests.test("differentiableMap(_:)") {
15+
func double(_ array: [Float]) -> [Float] {
16+
array.differentiableMap { $0 * $0 }
17+
}
18+
expectEqual([], pullback(at: array, in: double)([]))
19+
expectEqual([0], pullback(at: array, in: double)([0]))
20+
expectEqual([2], pullback(at: array, in: double)([1]))
21+
expectEqual([2, 4, 6, 8, 10], pullback(at: array, in: double)([1, 1, 1, 1, 1]))
22+
}
23+
24+
CollectionHOFTests.test("differentiableReduce(_:_:)") {
25+
func product(_ array: [Float]) -> Float {
26+
array.differentiableReduce(1) { $0 * $1 }
27+
}
28+
expectEqual([1], gradient(at: [0], in: product))
29+
expectEqual([1], gradient(at: [1], in: product))
30+
expectEqual([120, 60, 40, 30, 24], gradient(at: array, in: product))
31+
}
32+
33+
runAllTests()

0 commit comments

Comments
 (0)