Skip to content

Commit 18d73fa

Browse files
authored
[AutoDiff] Expose failing control flow + recursion + buffer AD test. (#25252)
Documented at TF-554.
1 parent 6bf823b commit 18d73fa

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

test/AutoDiff/control_flow.swift

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,21 +338,47 @@ ControlFlowTests.test("Recursion") {
338338
expectEqual(26, gradient(at: 4, in: factorial))
339339
expectEqual(154, gradient(at: 5, in: factorial))
340340

341-
func factorial_var(_ x: Float) -> Float {
341+
func factorial_var1(_ x: Float) -> Float {
342+
var y: Float = x
343+
if x == 1 {
344+
y = 1
345+
} else {
346+
y = x
347+
y = y * factorial_var1(y - 1)
348+
}
349+
return y
350+
}
351+
expectEqual(0, gradient(at: 1, in: factorial_var1))
352+
expectEqual(1, gradient(at: 2, in: factorial_var1))
353+
expectEqual(5, gradient(at: 3, in: factorial_var1))
354+
expectEqual(26, gradient(at: 4, in: factorial_var1))
355+
expectEqual(154, gradient(at: 5, in: factorial_var1))
356+
357+
func factorial_var2(_ x: Float) -> Float {
358+
// Next line is the only difference with `factorial_var1`.
342359
var y: Float = 1
343360
if x == 1 {
344361
y = 1
345362
} else {
346363
y = x
347-
y = y * factorial(y - 1)
364+
y = y * factorial_var2(y - 1)
348365
}
349366
return y
350367
}
351-
expectEqual(0, gradient(at: 1, in: factorial))
352-
expectEqual(1, gradient(at: 2, in: factorial))
353-
expectEqual(5, gradient(at: 3, in: factorial))
354-
expectEqual(26, gradient(at: 4, in: factorial))
355-
expectEqual(154, gradient(at: 5, in: factorial))
368+
// FIXME: Fix zero gradients (related to activity analysis).
369+
// See `factorial_var1` for the working version.
370+
/*
371+
expectEqual(0, gradient(at: 1, in: factorial_var2))
372+
expectEqual(1, gradient(at: 2, in: factorial_var2))
373+
expectEqual(5, gradient(at: 3, in: factorial_var2))
374+
expectEqual(26, gradient(at: 4, in: factorial_var2))
375+
expectEqual(154, gradient(at: 5, in: factorial_var2))
376+
*/
377+
expectEqual(0, gradient(at: 1, in: factorial_var2))
378+
expectEqual(0, gradient(at: 2, in: factorial_var2))
379+
expectEqual(0, gradient(at: 3, in: factorial_var2))
380+
expectEqual(0, gradient(at: 4, in: factorial_var2))
381+
expectEqual(0, gradient(at: 5, in: factorial_var2))
356382

357383
func product(_ x: Float, count: Int) -> Float {
358384
precondition(count > 0)

0 commit comments

Comments
 (0)