Skip to content

Commit 7af00ee

Browse files
authored
[AutoDiff] Add negative tests for TF-954. (#28223)
Add negative tests for TF-954: addresses incorrectly not marked useful. - Activity analysis test for TF-954 minimal reproducer. - Correctness tests for TF-954 minimal reproducer and existing reproducer.
1 parent 439808d commit 7af00ee

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

test/AutoDiff/activity_analysis.swift

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func testNondifferentiableTupleElementAddr<T>(_ x: T) -> T {
5757
// CHECK: [ACTIVE] %56 = tuple_element_addr %55 : $*(Int, Int, (T, Int), Int), 2
5858
// CHECK: [ACTIVE] %57 = tuple_element_addr %56 : $*(T, Int), 0
5959

60-
// Check activity analysis for `array.uninitialized_intrinsic` applications.
60+
// Check `array.uninitialized_intrinsic` applications.
6161

6262
@differentiable
6363
func testArrayUninitializedIntrinsic(_ x: Float, _ y: Float) -> [Float] {
@@ -93,10 +93,10 @@ func testArrayUninitializedIntrinsicGeneric<T>(_ x: T, _ y: T) -> [T] {
9393
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
9494
// CHECK: [VARIED] %12 = index_addr %9 : $*T, %11 : $Builtin.Word
9595

96-
// TF-781: check activity analysis for active local address + nested conditionals.
96+
// TF-781: check active local address + nested conditionals.
9797

9898
@differentiable(wrt: x)
99-
func TF_781_function(_ x: Float, _ y: Float) -> Float {
99+
func TF_781(_ x: Float, _ y: Float) -> Float {
100100
var result = y
101101
if true {
102102
if true {
@@ -106,7 +106,7 @@ func TF_781_function(_ x: Float, _ y: Float) -> Float {
106106
return result
107107
}
108108

109-
// CHECK-LABEL: [AD] Activity info for ${{.*}}TF_781_function{{.*}} at (source=0 parameters=(0))
109+
// CHECK-LABEL: [AD] Activity info for ${{.*}}TF_781{{.*}} at (source=0 parameters=(0))
110110
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
111111
// CHECK: [USEFUL] %1 = argument of bb0 : $Float
112112
// CHECK: [ACTIVE] %4 = alloc_stack $Float, var, name "result"
@@ -116,3 +116,41 @@ func TF_781_function(_ x: Float, _ y: Float) -> Float {
116116
// CHECK: [ACTIVE] %24 = begin_access [modify] [static] %4 : $*Float
117117
// CHECK: [ACTIVE] %31 = begin_access [read] [static] %4 : $*Float
118118
// CHECK: [ACTIVE] %32 = load [trivial] %31 : $*Float
119+
120+
// TF-954: check nested conditionals and addresses.
121+
122+
@differentiable
123+
func TF_954(_ x: Float) -> Float {
124+
var outer = x
125+
outerIf: if true {
126+
var inner = outer
127+
inner = inner * x // check activity of this `apply`
128+
if false {
129+
break outerIf
130+
}
131+
outer = inner
132+
}
133+
return outer
134+
}
135+
136+
// CHECK-LABEL: [AD] Activity info for ${{.*}}TF_954{{.*}} at (source=0 parameters=(0))
137+
// CHECK: bb0:
138+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
139+
// CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "outer"
140+
// CHECK: bb1:
141+
// CHECK: [ACTIVE] %10 = alloc_stack $Float, var, name "inner"
142+
// CHECK: [ACTIVE] %11 = begin_access [read] [static] %2 : $*Float
143+
// CHECK: [NONE] %14 = metatype $@thin Float.Type
144+
// CHECK: [ACTIVE] %15 = begin_access [read] [static] %10 : $*Float
145+
// CHECK: [VARIED] %16 = load [trivial] %15 : $*Float
146+
// CHECK: [NONE] // function_ref static Float.* infix(_:_:)
147+
// CHECK: %18 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
148+
// CHECK: [VARIED] %19 = apply %18(%16, %0, %14) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
149+
// CHECK: [ACTIVE] %20 = begin_access [modify] [static] %10 : $*Float
150+
// CHECK: bb3:
151+
// CHECK: [ACTIVE] %31 = begin_access [read] [static] %10 : $*Float
152+
// CHECK: [ACTIVE] %32 = load [trivial] %31 : $*Float
153+
// CHECK: [ACTIVE] %34 = begin_access [modify] [static] %2 : $*Float
154+
// CHECK: bb5:
155+
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
156+
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float

test/AutoDiff/control_flow.swift

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,23 @@ ControlFlowTests.test("Conditionals") {
5656
expectEqual((5, 4), gradient(at: 4, 5, in: cond3))
5757
expectEqual((-1, 1), gradient(at: -3, -2, in: cond3))
5858

59+
func cond4_var(_ x: Float) -> Float {
60+
var outer = x
61+
outerIf: if true {
62+
var inner = outer
63+
inner = inner * x
64+
if false {
65+
break outerIf
66+
}
67+
outer = inner
68+
}
69+
return outer
70+
}
71+
// FIXME(TF-954): Investigate incorrect derivative related to addresses and
72+
// nested control flow.
73+
// expectEqual((9, 6), valueWithGradient(at: 3, in: cond4_var))
74+
expectEqual((9, 0), valueWithGradient(at: 3, in: cond4_var))
75+
5976
func cond_tuple(_ x: Float) -> Float {
6077
// Convoluted function returning `x + x`.
6178
let y: (Float, Float) = (x, x)
@@ -690,8 +707,16 @@ ControlFlowTests.test("Loops") {
690707
}
691708
return outer
692709
}
710+
// FIXME(TF-954): Investigate incorrect derivative related to addresses and
711+
// nested control flow.
712+
// expectEqual((6, 5), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
713+
// expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
714+
// expectEqual((52, 80), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
715+
// expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
716+
expectEqual((6, 0), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
717+
expectEqual((20, 0), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
718+
expectEqual((52, 26), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
693719
expectEqual((24, 12), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
694-
expectEqual((16, 8), valueWithGradient(at: 4, in: { x in nested_loop2(x, count: 4) }))
695720
}
696721

697722
runAllTests()

0 commit comments

Comments
 (0)