Skip to content

Commit d599105

Browse files
committed
[AutoDiff upstream] Fix stdlib differentiation tests.
Temporarily disable not-yet-supported differentiation tests: - Forward-mode differentiation. - TF-1237: to be upstreamed. - `Differentiable.zeroTangentVector`. - TF-1238: to be upstreamed. - `SIMD.sum` differentiation. - TF-1103: `@_alwaysEmitIntoClient` derivative registration.
1 parent 2c11214 commit d599105

File tree

5 files changed

+45
-32
lines changed

5 files changed

+45
-32
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,6 @@ extension Array: Differentiable where Element: Differentiable {
164164
view.move(along: direction)
165165
self = view.base
166166
}
167-
168-
/// A closure that produces a `TangentVector` of zeros with the same
169-
/// `count` as `self`.
170-
public var zeroTangentVectorInitializer: () -> TangentVector {
171-
{ [count = self.count] in
172-
TangentVector(.init(repeating: .zero, count: count))
173-
}
174-
}
175167
}
176168

177169
//===----------------------------------------------------------------------===//

test/AutoDiff/stdlib/array.swift

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,10 @@ ArrayAutoDiffTests.test("Array.init(repeating:count:)") {
368368
Array(repeating: x, count: 10)
369369
}
370370
expectEqual(Float(10), gradient(at: .zero) { x in
371-
repeating(x).differentiableReduce(0, {$0 + $1}).value
371+
repeating(x).differentiableReduce(0, {$0 + $1})
372372
})
373373
expectEqual(Float(20), pullback(at: .zero, in: { x in
374-
repeating(x).differentiableReduce(0, {$0 + $1}).value
374+
repeating(x).differentiableReduce(0, {$0 + $1})
375375
})(2))
376376
}
377377

@@ -401,22 +401,13 @@ ArrayAutoDiffTests.test("Array.DifferentiableView.base") {
401401
backprop(FloatArrayTan([1, 2, 3, 4])))
402402
}
403403

404-
ArrayAutoDiffTests.test("Array.DifferentiableView : KeyPathIterable") {
405-
struct Container : KeyPathIterable {
406-
let a: Array<Float>.DifferentiableView
407-
}
408-
let container = Container(a: Array<Float>.DifferentiableView([1, 2, 3]))
409-
expectEqual(
410-
[1, 2, 3],
411-
container.recursivelyAllKeyPaths(to: Float.self).map {
412-
container[keyPath: $0]
413-
})
414-
}
415-
416-
ArrayAutoDiffTests.test("Array.zeroTangentVectorInitializer") {
404+
// TODO: Upstream `Differentiable.zeroTangentVector` and implementations.
405+
/*
406+
ArrayAutoDiffTests.test("Array.zeroTangentVector") {
417407
let count = 10
418408
let array: [Float] = Array((0..<count).map(Float.init))
419409
expectEqual(array.zeroTangentVector.base, Array(repeating: 0, count: count))
420410
}
411+
*/
421412

422413
runAllTests()

test/AutoDiff/stdlib/floating_point.swift.gyb

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,33 @@ FloatingPointDerivativeTests.test("${Self}.+") {
3838
expectEqual((1, 1), gradient(at: ${Self}(4), ${Self}(5), in: +))
3939
expectEqual((10, 10), pullback(at: ${Self}(4), ${Self}(5), in: +)(${Self}(10)))
4040

41-
expectEqual(2, derivative(at: ${Self}(4), ${Self}(5), in: +))
42-
expectEqual(20, differential(at: ${Self}(4), ${Self}(5), in: +)(${Self}(10), ${Self}(10)))
41+
// TODO(TF-1237): Upstream forward-mode differentiation.
42+
expectCrash {
43+
expectEqual(2, derivative(at: ${Self}(4), ${Self}(5), in: +))
44+
expectEqual(20, differential(at: ${Self}(4), ${Self}(5), in: +)(${Self}(10), ${Self}(10)))
45+
}
4346
}
4447

4548
FloatingPointDerivativeTests.test("${Self}.-") {
4649
expectEqual((1, -1), gradient(at: ${Self}(4), ${Self}(5), in: -))
4750
expectEqual((10, -10), pullback(at: ${Self}(4), ${Self}(5), in: -)(${Self}(10)))
4851

49-
expectEqual(0, derivative(at: ${Self}(4), ${Self}(5), in: -))
50-
expectEqual(-5, differential(at: ${Self}(4), ${Self}(5), in: -)(${Self}(5), ${Self}(10)))
52+
// TODO(TF-1237): Upstream forward-mode differentiation.
53+
expectCrash {
54+
expectEqual(0, derivative(at: ${Self}(4), ${Self}(5), in: -))
55+
expectEqual(-5, differential(at: ${Self}(4), ${Self}(5), in: -)(${Self}(5), ${Self}(10)))
56+
}
5157
}
5258

5359
FloatingPointDerivativeTests.test("${Self}.*") {
5460
expectEqual((5, 4), gradient(at: ${Self}(4), ${Self}(5), in: *))
5561
expectEqual((50, 40), pullback(at: ${Self}(4), ${Self}(5), in: *)(${Self}(10)))
5662

57-
expectEqual(9, derivative(at: ${Self}(4), ${Self}(5), in: *))
58-
expectEqual(90, differential(at: ${Self}(4), ${Self}(5), in: *)(${Self}(10), ${Self}(10)))
63+
// TODO(TF-1237): Upstream forward-mode differentiation.
64+
expectCrash {
65+
expectEqual(9, derivative(at: ${Self}(4), ${Self}(5), in: *))
66+
expectEqual(90, differential(at: ${Self}(4), ${Self}(5), in: *)(${Self}(10), ${Self}(10)))
67+
}
5968
}
6069

6170
FloatingPointDerivativeTests.test("${Self}./") {
@@ -70,8 +79,11 @@ FloatingPointDerivativeTests.test("${Self}./") {
7079
expectEqualWithTolerance(-1.6, dy)
7180
}
7281

73-
expectEqualWithTolerance(0.04, derivative(at: ${Self}(4), ${Self}(5), in: /))
74-
expectEqual(90, differential(at: ${Self}(4), ${Self}(5), in: *)(${Self}(10), ${Self}(10)))
82+
// TODO(TF-1237): Upstream forward-mode differentiation.
83+
expectCrash {
84+
expectEqualWithTolerance(0.04, derivative(at: ${Self}(4), ${Self}(5), in: /))
85+
expectEqual(90, differential(at: ${Self}(4), ${Self}(5), in: *)(${Self}(10), ${Self}(10)))
86+
}
7587
}
7688

7789
FloatingPointDerivativeTests.test("${Self}.squareRoot") {

test/AutoDiff/stdlib/simd.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ SIMDTests.test("init(repeating:)") {
2222
expectEqual(8, bp1(g))
2323
}
2424

25+
// FIXME(TF-1103): Derivative registration does not yet support
26+
// `@_alwaysEmitIntoClient` original functions.
27+
/*
2528
SIMDTests.test("Sum") {
2629
let a = SIMD4<Float>(1, 2, 3, 4)
2730

@@ -32,6 +35,7 @@ SIMDTests.test("Sum") {
3235
expectEqual(10, val1)
3336
expectEqual(SIMD4<Float>(3, 3, 3, 3), bp1(3))
3437
}
38+
*/
3539

3640
SIMDTests.test("Identity") {
3741
let a = SIMD4<Float>(1, 2, 3, 4)
@@ -259,6 +263,8 @@ SIMDTests.test("Generics") {
259263
expectEqual(SIMD3<Double>(5, 10, 15), val4)
260264
expectEqual((SIMD3<Double>(5, 5, 5), 6), bp4(g))
261265

266+
// FIXME(TF-1103): Derivative registration does not yet support
267+
/*
262268
func testSum<Scalar, SIMDType: SIMD>(x: SIMDType) -> Scalar
263269
where SIMDType.Scalar == Scalar,
264270
SIMDType : Differentiable,
@@ -271,6 +277,7 @@ SIMDTests.test("Generics") {
271277
let (val5, bp5) = valueWithPullback(at: a, in: simd3Sum)
272278
expectEqual(6, val5)
273279
expectEqual(SIMD3<Double>(7, 7, 7), bp5(7))
280+
*/
274281
}
275282

276283
runAllTests()

test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ where T == T.TangentVector {
5555
%for T in ['Float', 'Float80']:
5656

5757
DerivativeTests.test("${op}_${T}") {
58+
%if op == 'derivative':
59+
// TODO(TF-1237): Upstream forward-mode differentiation and uncomment the next line.
60+
expectCrash {
61+
%end
5862
expectEqualWithTolerance(7.3890560989306502274, ${op}(at: 2 as ${T}, in: exp))
5963
expectEqualWithTolerance(2.772588722239781145, ${op}(at: 2 as ${T}, in: exp2))
6064
expectEqualWithTolerance(7.3890560989306502274, ${op}(at: 2 as ${T}, in: expm1))
@@ -82,6 +86,10 @@ DerivativeTests.test("${op}_${T}") {
8286
expectEqualWithTolerance(0, ${op}(at: 2 as ${T}, in: { round($0) }))
8387
expectEqualWithTolerance(0, ${op}(at: 2 as ${T}, in: { trunc($0) }))
8488

89+
%if op == 'derivative':
90+
} // `expectCrash` for forward-mode differentiation
91+
%end
92+
8593
// Differential operator specific tests.
8694

8795
// fma
@@ -91,7 +99,10 @@ DerivativeTests.test("${op}_${T}") {
9199
expectEqualWithTolerance(4, dfma.1)
92100
expectEqualWithTolerance(1, dfma.2)
93101
%else: # if op == 'derivative'
94-
expectEqualWithTolerance(10, dfma)
102+
// TODO(TF-1237): Upstream forward-mode differentiation.
103+
expectCrash {
104+
expectEqualWithTolerance(10, dfma)
105+
}
95106
%end
96107

97108
// remainder, fmod

0 commit comments

Comments
 (0)