|
1 | 1 | // RUN: %empty-directory(%t)
|
2 |
| -// RUN: %target-build-swift -enable-experimental-differentiable-programming %s -emit-sil |
3 |
| -// RUN: %target-build-swift -enable-experimental-differentiable-programming %s -o %t/differential_operators |
| 2 | +// RUN: %gyb %s -o %t/differential_operators.swift |
| 3 | +// RUN: %target-build-swift -enable-experimental-differentiable-programming %t/differential_operators.swift -o %t/differential_operators |
4 | 4 | // RUN: %target-run %t/differential_operators
|
5 | 5 | // REQUIRES: executable_test
|
6 | 6 |
|
7 | 7 | import _Differentiation
|
8 | 8 |
|
9 | 9 | import StdlibUnittest
|
10 | 10 |
|
11 |
| -var DifferentialOperatorTestSuite = TestSuite("Tuple") |
| 11 | +var DifferentialOperatorTestSuite = TestSuite("DifferentialOperator") |
12 | 12 |
|
13 |
| -DifferentialOperatorTestSuite.test("foo") { |
14 |
| - expectEqual(1, 1) |
| 13 | +% for arity in range(1, 3 + 1): |
| 14 | + |
| 15 | +% params = ', '.join(['_ x%d: Float' % i for i in range(arity)]) |
| 16 | +% pb_return_type = '(' + ', '.join(['Float' for _ in range(arity)]) + ')' |
| 17 | +func exampleVJP_${arity}(${params}) -> (Float, (Float) -> ${pb_return_type}) { |
| 18 | + ( |
| 19 | + ${' + '.join(['x%d * x%d' % (i, i) for i in range(arity)])}, |
| 20 | + { (${', '.join(['2 * x%d * $0' % i for i in range(arity)])}) } |
| 21 | + ) |
| 22 | +} |
| 23 | + |
| 24 | +% argValues = [i * 10 for i in range(1, arity + 1)] |
| 25 | +% args = ', '.join([str(v) for v in argValues]) |
| 26 | +% expectedValue = sum([v * v for v in argValues]) |
| 27 | +% expectedGradientValues = [2 * v for v in argValues] |
| 28 | +% expectedGradients = '(' + ', '.join([str(g) for g in expectedGradientValues]) + ')' |
| 29 | + |
| 30 | +DifferentialOperatorTestSuite.test("differentiableFunction_callOriginal_${arity}") { |
| 31 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 32 | + expectEqual(${expectedValue}, f(${args})) |
15 | 33 | }
|
16 | 34 |
|
17 |
| -func exampleVJP(_ x: Float) -> (Float, (Float) -> Float) { |
18 |
| - (x * x, { 2 * $0 }) |
| 35 | +DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") { |
| 36 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 37 | + let (value, pb) = valueWithPullback(at: ${args}, in: f) |
| 38 | + expectEqual(${expectedValue}, value) |
| 39 | + expectEqual(${expectedGradients}, pb(1)) |
19 | 40 | }
|
20 | 41 |
|
21 |
| -DifferentialOperatorTestSuite.test("differentiableFunction_callOriginal") { |
22 |
| - expectEqual(differentiableFunction(from: exampleVJP)(10), 100) |
| 42 | +DifferentialOperatorTestSuite.test("pullback_${arity}") { |
| 43 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 44 | + let pb = pullback(at: ${args}, in: f) |
| 45 | + expectEqual(${expectedGradients}, pb(1)) |
23 | 46 | }
|
24 | 47 |
|
| 48 | +DifferentialOperatorTestSuite.test("gradient_${arity}") { |
| 49 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 50 | + let grad = gradient(at: ${args}, in: f) |
| 51 | + expectEqual(${expectedGradients}, grad) |
| 52 | +} |
| 53 | + |
| 54 | +DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") { |
| 55 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 56 | + let (value, grad) = valueWithGradient(at: ${args}, in: f) |
| 57 | + expectEqual(${expectedValue}, value) |
| 58 | + expectEqual(${expectedGradients}, grad) |
| 59 | +} |
| 60 | + |
| 61 | +DifferentialOperatorTestSuite.test("gradient_curried_${arity}") { |
| 62 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 63 | + let gradF = gradient(of: f) |
| 64 | + expectEqual(${expectedGradients}, gradF(${args})) |
| 65 | +} |
| 66 | + |
| 67 | +DifferentialOperatorTestSuite.test("valueWithGradient_curried_${arity}") { |
| 68 | + let f = differentiableFunction(from: exampleVJP_${arity}) |
| 69 | + let valueWithGradF = valueWithGradient(of: f) |
| 70 | + let (value, grad) = valueWithGradF(${args}) |
| 71 | + expectEqual(${expectedValue}, value) |
| 72 | + expectEqual(${expectedGradients}, grad) |
| 73 | +} |
| 74 | + |
| 75 | +% end |
| 76 | + |
25 | 77 | runAllTests()
|
0 commit comments