Skip to content

Commit 32f6629

Browse files
committed
Add tests.
Add tests for repeat-while loops and break/continue statements. Expose TF-584: incorrect derivative computation for repeat-while loops.
1 parent 16dc9a2 commit 32f6629

File tree

3 files changed

+83
-5
lines changed

3 files changed

+83
-5
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,7 @@ static void dumpActivityInfo(SILFunction &fn,
17361736
for (auto &inst : bb)
17371737
for (auto res : inst.getResults())
17381738
dumpActivityInfo(res, indices, activityInfo, s);
1739-
s << "\n";
1739+
s << '\n';
17401740
}
17411741
}
17421742

test/AutoDiff/control_flow.swift

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -542,23 +542,89 @@ ControlFlowTests.test("Loops") {
542542
expectEqual((8, 12), valueWithGradient(at: 2, in: while_loop))
543543
expectEqual((27, 27), valueWithGradient(at: 3, in: while_loop))
544544

545-
func nested_loop(_ x: Float) -> Float {
545+
func repeat_while_loop(_ x: Float) -> Float {
546+
var result = x
547+
var i = 1
548+
repeat {
549+
result = result * x
550+
i += 1
551+
} while i < 3
552+
return result
553+
}
554+
// FIXME(TF-584): Investigate incorrect (too big) gradient values
555+
// for repeat-while loops.
556+
// expectEqual((8, 12), valueWithGradient(at: 2, in: repeat_while_loop))
557+
// expectEqual((27, 27), valueWithGradient(at: 3, in: repeat_while_loop))
558+
expectEqual((8, 18), valueWithGradient(at: 2, in: repeat_while_loop))
559+
expectEqual((27, 36), valueWithGradient(at: 3, in: repeat_while_loop))
560+
561+
func loop_continue(_ x: Float) -> Float {
562+
var result = x
563+
for i in 1..<10 {
564+
if i.isMultiple(of: 2) {
565+
continue
566+
}
567+
result = result * x
568+
}
569+
return result
570+
}
571+
expectEqual((64, 192), valueWithGradient(at: 2, in: loop_continue))
572+
expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_continue))
573+
574+
func loop_break(_ x: Float) -> Float {
575+
var result = x
576+
for i in 1..<10 {
577+
if i.isMultiple(of: 2) {
578+
continue
579+
}
580+
result = result * x
581+
}
582+
return result
583+
}
584+
expectEqual((64, 192), valueWithGradient(at: 2, in: loop_break))
585+
expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_break))
586+
587+
func nested_loop1(_ x: Float) -> Float {
546588
var outer = x
547589
for _ in 1..<3 {
548590
outer = outer * x
549591

550592
var inner = outer
551593
var i = 1
552594
while i < 3 {
553-
inner = inner / x
595+
inner = inner + x
554596
i += 1
555597
}
556598
outer = inner
557599
}
558600
return outer
559601
}
560-
expectEqual((0.5, -0.25), valueWithGradient(at: 2, in: nested_loop))
561-
expectEqual((0.25, -0.0625), valueWithGradient(at: 4, in: nested_loop))
602+
expectEqual((20, 22), valueWithGradient(at: 2, in: nested_loop1))
603+
expectEqual((104, 66), valueWithGradient(at: 4, in: nested_loop1))
604+
605+
func nested_loop2(_ x: Float, count: Int) -> Float {
606+
var outer = x
607+
outerLoop: for _ in 1..<count {
608+
outer = outer * x
609+
610+
var inner = outer
611+
var i = 1
612+
while i < count {
613+
inner = inner + x
614+
i += 1
615+
616+
switch Int(inner.truncatingRemainder(dividingBy: 7)) {
617+
case 0: break outerLoop
618+
case 1: break
619+
default: continue
620+
}
621+
}
622+
outer = inner
623+
}
624+
return outer
625+
}
626+
expectEqual((24, 12), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 5) }))
627+
expectEqual((16, 8), valueWithGradient(at: 4, in: { x in nested_loop2(x, count: 5) }))
562628
}
563629

564630
runAllTests()

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,15 @@ enum Tree : Differentiable & AdditiveArithmetic {
141141
}
142142
}
143143
}
144+
145+
// expected-error @+1 {{function is not differentiable}}
146+
@differentiable
147+
// expected-note @+1 {{when differentiating this function definition}}
148+
func loop_array(_ array: [Float]) -> Float {
149+
var result: Float = 1
150+
// expected-note @+1 {{differentiating enum values is not yet supported}}
151+
for x in array {
152+
result = result * x
153+
}
154+
return result
155+
}

0 commit comments

Comments
 (0)