@@ -542,23 +542,89 @@ ControlFlowTests.test("Loops") {
542
542
expectEqual ( ( 8 , 12 ) , valueWithGradient ( at: 2 , in: while_loop) )
543
543
expectEqual ( ( 27 , 27 ) , valueWithGradient ( at: 3 , in: while_loop) )
544
544
545
- func nested_loop( _ x: Float ) -> Float {
545
+ func repeat_while_loop( _ x: Float ) -> Float {
546
+ var result = x
547
+ var i = 1
548
+ repeat {
549
+ result = result * x
550
+ i += 1
551
+ } while i < 3
552
+ return result
553
+ }
554
+ // FIXME(TF-584): Investigate incorrect (too big) gradient values
555
+ // for repeat-while loops.
556
+ // expectEqual((8, 12), valueWithGradient(at: 2, in: repeat_while_loop))
557
+ // expectEqual((27, 27), valueWithGradient(at: 3, in: repeat_while_loop))
558
+ expectEqual ( ( 8 , 18 ) , valueWithGradient ( at: 2 , in: repeat_while_loop) )
559
+ expectEqual ( ( 27 , 36 ) , valueWithGradient ( at: 3 , in: repeat_while_loop) )
560
+
561
+ func loop_continue( _ x: Float ) -> Float {
562
+ var result = x
563
+ for i in 1 ..< 10 {
564
+ if i. isMultiple ( of: 2 ) {
565
+ continue
566
+ }
567
+ result = result * x
568
+ }
569
+ return result
570
+ }
571
+ expectEqual ( ( 64 , 192 ) , valueWithGradient ( at: 2 , in: loop_continue) )
572
+ expectEqual ( ( 729 , 1458 ) , valueWithGradient ( at: 3 , in: loop_continue) )
573
+
574
+ func loop_break( _ x: Float ) -> Float {
575
+ var result = x
576
+ for i in 1 ..< 10 {
577
+ if i. isMultiple ( of: 2 ) {
578
+ continue
579
+ }
580
+ result = result * x
581
+ }
582
+ return result
583
+ }
584
+ expectEqual ( ( 64 , 192 ) , valueWithGradient ( at: 2 , in: loop_break) )
585
+ expectEqual ( ( 729 , 1458 ) , valueWithGradient ( at: 3 , in: loop_break) )
586
+
587
+ func nested_loop1( _ x: Float ) -> Float {
546
588
var outer = x
547
589
for _ in 1 ..< 3 {
548
590
outer = outer * x
549
591
550
592
var inner = outer
551
593
var i = 1
552
594
while i < 3 {
553
- inner = inner / x
595
+ inner = inner + x
554
596
i += 1
555
597
}
556
598
outer = inner
557
599
}
558
600
return outer
559
601
}
560
- expectEqual ( ( 0.5 , - 0.25 ) , valueWithGradient ( at: 2 , in: nested_loop) )
561
- expectEqual ( ( 0.25 , - 0.0625 ) , valueWithGradient ( at: 4 , in: nested_loop) )
602
+ expectEqual ( ( 20 , 22 ) , valueWithGradient ( at: 2 , in: nested_loop1) )
603
+ expectEqual ( ( 104 , 66 ) , valueWithGradient ( at: 4 , in: nested_loop1) )
604
+
605
+ func nested_loop2( _ x: Float , count: Int ) -> Float {
606
+ var outer = x
607
+ outerLoop: for _ in 1 ..< count {
608
+ outer = outer * x
609
+
610
+ var inner = outer
611
+ var i = 1
612
+ while i < count {
613
+ inner = inner + x
614
+ i += 1
615
+
616
+ switch Int ( inner. truncatingRemainder ( dividingBy: 7 ) ) {
617
+ case 0 : break outerLoop
618
+ case 1 : break
619
+ default : continue
620
+ }
621
+ }
622
+ outer = inner
623
+ }
624
+ return outer
625
+ }
626
+ expectEqual ( ( 24 , 12 ) , valueWithGradient ( at: 2 , in: { x in nested_loop2 ( x, count: 5 ) } ) )
627
+ expectEqual ( ( 16 , 8 ) , valueWithGradient ( at: 4 , in: { x in nested_loop2 ( x, count: 5 ) } ) )
562
628
}
563
629
564
630
runAllTests ( )
0 commit comments