Skip to content

Commit 64d0e76

Browse files
author
Marc Rasi
committed
improve differential operator tests
1 parent 0218844 commit 64d0e76

File tree

2 files changed

+79
-9
lines changed

2 files changed

+79
-9
lines changed

stdlib/public/Differentiation/DifferentiationUtilities.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ public func differentiableFunction<T, U, R>(
5252
/*vjp*/ vjp)
5353
}
5454

55+
/// Create a differentiable function from a vector-Jacobian products function.
56+
@inlinable
57+
public func differentiableFunction<T, U, V, R>(
58+
from vjp: @escaping (T, U, V)
59+
-> (value: R, pullback: (R.TangentVector)
60+
-> (T.TangentVector, U.TangentVector, V.TangentVector))
61+
) -> @differentiable (T, U, V) -> R {
62+
Builtin.differentiableFunction_arity3(
63+
/*original*/ { vjp($0, $1, $2).value },
64+
/*jvp*/ { _, _, _ in
65+
fatalError("""
66+
Functions formed with `differentiableFunction(from:)` cannot yet \
67+
be used with differential-producing differential operators.
68+
""")
69+
},
70+
/*vjp*/ vjp)
71+
}
72+
5573
/// Returns `x` like an identity function. When used in a context where `x` is
5674
/// being differentiated with respect to, this function will not produce any
5775
/// derivative at `x`.
Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,77 @@
11
// RUN: %empty-directory(%t)
2-
// RUN: %target-build-swift -enable-experimental-differentiable-programming %s -emit-sil
3-
// RUN: %target-build-swift -enable-experimental-differentiable-programming %s -o %t/differential_operators
2+
// RUN: %gyb %s -o %t/differential_operators.swift
3+
// RUN: %target-build-swift -enable-experimental-differentiable-programming %t/differential_operators.swift -o %t/differential_operators
44
// RUN: %target-run %t/differential_operators
55
// REQUIRES: executable_test
66

77
import _Differentiation
88

99
import StdlibUnittest
1010

11-
var DifferentialOperatorTestSuite = TestSuite("Tuple")
11+
var DifferentialOperatorTestSuite = TestSuite("DifferentialOperator")
1212

13-
DifferentialOperatorTestSuite.test("foo") {
14-
expectEqual(1, 1)
13+
% for arity in range(1, 3 + 1):
14+
15+
% params = ', '.join(['_ x%d: Float' % i for i in range(arity)])
16+
% pb_return_type = '(' + ', '.join(['Float' for _ in range(arity)]) + ')'
17+
func exampleVJP_${arity}(${params}) -> (Float, (Float) -> ${pb_return_type}) {
18+
(
19+
${' + '.join(['x%d * x%d' % (i, i) for i in range(arity)])},
20+
{ (${', '.join(['2 * x%d * $0' % i for i in range(arity)])}) }
21+
)
22+
}
23+
24+
% argValues = [i * 10 for i in range(1, arity + 1)]
25+
% args = ', '.join([str(v) for v in argValues])
26+
% expectedValue = sum([v * v for v in argValues])
27+
% expectedGradientValues = [2 * v for v in argValues]
28+
% expectedGradients = '(' + ', '.join([str(g) for g in expectedGradientValues]) + ')'
29+
30+
DifferentialOperatorTestSuite.test("differentiableFunction_callOriginal_${arity}") {
31+
let f = differentiableFunction(from: exampleVJP_${arity})
32+
expectEqual(${expectedValue}, f(${args}))
1533
}
1634

17-
func exampleVJP(_ x: Float) -> (Float, (Float) -> Float) {
18-
(x * x, { 2 * $0 })
35+
DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
36+
let f = differentiableFunction(from: exampleVJP_${arity})
37+
let (value, pb) = valueWithPullback(at: ${args}, in: f)
38+
expectEqual(${expectedValue}, value)
39+
expectEqual(${expectedGradients}, pb(1))
1940
}
2041

21-
DifferentialOperatorTestSuite.test("differentiableFunction_callOriginal") {
22-
expectEqual(differentiableFunction(from: exampleVJP)(10), 100)
42+
DifferentialOperatorTestSuite.test("pullback_${arity}") {
43+
let f = differentiableFunction(from: exampleVJP_${arity})
44+
let pb = pullback(at: ${args}, in: f)
45+
expectEqual(${expectedGradients}, pb(1))
2346
}
2447

48+
DifferentialOperatorTestSuite.test("gradient_${arity}") {
49+
let f = differentiableFunction(from: exampleVJP_${arity})
50+
let grad = gradient(at: ${args}, in: f)
51+
expectEqual(${expectedGradients}, grad)
52+
}
53+
54+
DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
55+
let f = differentiableFunction(from: exampleVJP_${arity})
56+
let (value, grad) = valueWithGradient(at: ${args}, in: f)
57+
expectEqual(${expectedValue}, value)
58+
expectEqual(${expectedGradients}, grad)
59+
}
60+
61+
DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
62+
let f = differentiableFunction(from: exampleVJP_${arity})
63+
let gradF = gradient(of: f)
64+
expectEqual(${expectedGradients}, gradF(${args}))
65+
}
66+
67+
DifferentialOperatorTestSuite.test("valueWithGradient_curried_${arity}") {
68+
let f = differentiableFunction(from: exampleVJP_${arity})
69+
let valueWithGradF = valueWithGradient(of: f)
70+
let (value, grad) = valueWithGradF(${args})
71+
expectEqual(${expectedValue}, value)
72+
expectEqual(${expectedGradients}, grad)
73+
}
74+
75+
% end
76+
2577
runAllTests()

0 commit comments

Comments
 (0)