Skip to content

Commit ebe9223

Browse files
authored
[AutoDiff] Add more loop differentiation negative testcases for TF-933. (#27973)
Loop differentiation produces incorrect results when the reduction accumulation variable is not initialized with an active parameter. TF-933 tracks this issue. This patch is a follow-up to #27796, adding negative test cases for while and repeat-while loops.
1 parent d4ca1ac commit ebe9223

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

test/AutoDiff/control_flow.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,22 @@ ControlFlowTests.test("Loops") {
556556
expectEqual((8, 12), valueWithGradient(at: 2, in: while_loop))
557557
expectEqual((27, 27), valueWithGradient(at: 3, in: while_loop))
558558

559+
func while_loop_nonactive_initial_value(_ x: Float) -> Float {
560+
var result: Float = 1
561+
var i = 0
562+
while i < 2 {
563+
result = result * x
564+
i += 1
565+
}
566+
return result
567+
}
568+
// TODO(TF-933): Fix incorrect derivatives when `var result` is not initially
569+
// assigned to `x`.
570+
// expectEqual((4, 4), valueWithGradient(at: 2, in: while_loop_nonactive_initial_value))
571+
// expectEqual((9, 6), valueWithGradient(at: 3, in: while_loop_nonactive_initial_value))
572+
expectEqual((4, 2), valueWithGradient(at: 2, in: while_loop_nonactive_initial_value))
573+
expectEqual((9, 3), valueWithGradient(at: 3, in: while_loop_nonactive_initial_value))
574+
559575
func repeat_while_loop(_ x: Float) -> Float {
560576
var result = x
561577
var i = 0
@@ -572,6 +588,22 @@ ControlFlowTests.test("Loops") {
572588
expectEqual((8, 18), valueWithGradient(at: 2, in: repeat_while_loop))
573589
expectEqual((27, 36), valueWithGradient(at: 3, in: repeat_while_loop))
574590

591+
func repeat_while_loop_nonactive_initial_value(_ x: Float) -> Float {
592+
var result: Float = 1
593+
var i = 0
594+
repeat {
595+
result = result * x
596+
i += 1
597+
} while i < 2
598+
return result
599+
}
600+
// TODO(TF-584, TF-933): Fix incorrect derivatives when `var result` is not
601+
// initially assigned to `x`.
602+
// expectEqual((4, 4), valueWithGradient(at: 2, in: repeat_while_loop_nonactive_initial_value))
603+
// expectEqual((9, 6), valueWithGradient(at: 3, in: repeat_while_loop_nonactive_initial_value))
604+
expectEqual((4, 3), valueWithGradient(at: 2, in: repeat_while_loop_nonactive_initial_value))
605+
expectEqual((9, 4), valueWithGradient(at: 3, in: repeat_while_loop_nonactive_initial_value))
606+
575607
func loop_continue(_ x: Float) -> Float {
576608
var result = x
577609
for i in 1..<10 {

0 commit comments

Comments
 (0)