@@ -41,6 +41,116 @@ final class TensorAutoDiffTests: XCTestCase {
41
41
XCTAssertEqual ( gradient ( at: Tensor ( [ 0.1 , 0.2 , 0.3 ] ) , in: square) , [ 0.2 , 0.4 , 0.6 ] )
42
42
}
43
43
44
+ func testConditionals( ) {
45
+ func condNestedTupleVar( _ x: Tensor < Float > ) -> Tensor < Float > {
46
+ // Convoluted function returning `x + x`.
47
+ var y : ( Tensor < Float > , Tensor < Float > ) = ( x + x, x - x)
48
+ var z : ( ( Tensor < Float > , Tensor < Float > ) , Tensor < Float > ) = ( y, x)
49
+ if ( x .> 0 ) . all ( ) {
50
+ let w = ( x, x)
51
+ y. 0 = w. 1
52
+ y. 1 = w. 0
53
+ z. 0 . 0 = z. 0 . 0 - y. 0
54
+ z. 0 . 1 = z. 0 . 1 + y. 0
55
+ } else {
56
+ z = ( ( y. 0 - x, y. 1 + x) , x)
57
+ }
58
+ return y. 0 + y. 1 - z. 0 . 0 + z. 0 . 1
59
+ }
60
+ XCTAssertTrue ( ( value: Tensor ( 8 ) , gradient: Tensor ( 2 ) ) == valueWithGradient ( at: Tensor ( 4 ) , in: condNestedTupleVar) )
61
+ XCTAssertTrue ( ( value: Tensor ( - 20 ) , gradient: Tensor ( 2 ) ) == valueWithGradient ( at: Tensor ( - 10 ) , in: condNestedTupleVar) )
62
+ XCTAssertTrue ( ( value: Tensor ( - 2674 ) , gradient: Tensor ( 2 ) ) == valueWithGradient ( at: Tensor ( - 1337 ) , in: condNestedTupleVar) )
63
+
64
+ func guard2Var( _ x: Tensor < Float > , _ y: Tensor < Float > ) -> Tensor < Float > {
65
+ var z = y
66
+ guard ( x .> 0 ) . all ( ) else {
67
+ if ( y .> 0 ) . all ( ) {
68
+ z = z * x
69
+ } else if x == Tensor ( - 1337 ) {
70
+ z = x
71
+ z = z * z
72
+ } else {
73
+ z = Tensor ( 0 )
74
+ }
75
+ return z
76
+ }
77
+ return z * y
78
+ }
79
+ XCTAssertTrue ( ( Tensor ( 0 ) , Tensor ( 10 ) ) == gradient ( at: Tensor ( 4 ) , Tensor ( 5 ) , in: guard2Var) )
80
+ XCTAssertTrue ( ( Tensor ( 5 ) , Tensor ( - 1337 ) ) == gradient ( at: Tensor ( - 1337 ) , Tensor ( 5 ) , in: guard2Var) )
81
+ XCTAssertTrue ( ( Tensor ( - 2674 ) , Tensor ( 0 ) ) == gradient ( at: Tensor ( - 1337 ) , Tensor ( - 5 ) , in: guard2Var) )
82
+ XCTAssertTrue ( ( Tensor ( 2 ) , Tensor ( - 3 ) ) == gradient ( at: Tensor ( - 3 ) , Tensor ( 2 ) , in: guard2Var) )
83
+ }
84
+
85
+ func testNestedConditionals( ) {
86
+ // Test tensor-tensor ops.
87
+ func condNested1( _ x: Tensor < Float > , _ y: Tensor < Float > ) -> Tensor < Float > {
88
+ if ( x .> 0 ) . all ( ) {
89
+ if ( y .> 10 ) . all ( ) {
90
+ let z = x * y
91
+ if ( z .> 100 ) . all ( ) {
92
+ return x + z
93
+ } else if y == Tensor ( 20 ) {
94
+ return z + z
95
+ }
96
+ } else {
97
+ return x + y
98
+ }
99
+ }
100
+ return - y
101
+ }
102
+ XCTAssertTrue ( ( Tensor ( 40 ) , Tensor ( 8 ) ) == gradient ( at: Tensor ( 4 ) , Tensor ( 20 ) , in: condNested1) )
103
+ XCTAssertTrue ( ( Tensor ( 0 ) , Tensor ( - 1 ) ) == gradient ( at: Tensor ( 4 ) , Tensor ( 21 ) , in: condNested1) )
104
+ XCTAssertTrue ( ( Tensor ( 1 ) , Tensor ( 1 ) ) == gradient ( at: Tensor ( 4 ) , Tensor ( 5 ) , in: condNested1) )
105
+ XCTAssertTrue ( ( Tensor ( 0 ) , Tensor ( - 1 ) ) == gradient ( at: Tensor ( - 3 ) , Tensor ( - 2 ) , in: condNested1) )
106
+
107
+ // Test tensor-scalar ops.
108
+ func condNested2( _ x: Tensor < Float > , _ y: Float ) -> Tensor < Float > {
109
+ if ( x .> 0 ) . all ( ) {
110
+ if y > 10 {
111
+ let z = x * y
112
+ if ( z .> 100 ) . all ( ) {
113
+ return x + z
114
+ } else if y == 20 {
115
+ return z + z
116
+ }
117
+ } else {
118
+ return x + y
119
+ }
120
+ }
121
+ return Tensor ( - y)
122
+ }
123
+ XCTAssertTrue ( ( Tensor ( 40 ) , 8 ) == gradient ( at: Tensor ( 4 ) , 20 , in: condNested2) )
124
+ XCTAssertTrue ( ( Tensor ( 0 ) , - 1 ) == gradient ( at: Tensor ( 4 ) , 21 , in: condNested2) )
125
+ XCTAssertTrue ( ( Tensor ( 1 ) , 1 ) == gradient ( at: Tensor ( 4 ) , 5 , in: condNested2) )
126
+ XCTAssertTrue ( ( Tensor ( 0 ) , - 1 ) == gradient ( at: Tensor ( - 3 ) , - 2 , in: condNested2) )
127
+ }
128
+
129
+ func testRecursion( ) {
130
+ func factorial( _ x: Tensor < Float > ) -> Tensor < Float > {
131
+ if x == Tensor ( 1 ) {
132
+ return Tensor ( 1 )
133
+ }
134
+ return x * factorial( x - 1 )
135
+ }
136
+ XCTAssertEqual ( gradient ( at: Tensor ( 1 ) , in: factorial) , Tensor ( 0 ) )
137
+ XCTAssertEqual ( gradient ( at: Tensor ( 2 ) , in: factorial) , Tensor ( 1 ) )
138
+ XCTAssertEqual ( gradient ( at: Tensor ( 3 ) , in: factorial) , Tensor ( 5 ) )
139
+ XCTAssertEqual ( gradient ( at: Tensor ( 4 ) , in: factorial) , Tensor ( 26 ) )
140
+ XCTAssertEqual ( gradient ( at: Tensor ( 5 ) , in: factorial) , Tensor ( 154 ) )
141
+
142
+ func product( _ x: Tensor < Float > , count: Int ) -> Tensor < Float > {
143
+ precondition ( count > 0 )
144
+ if count == 1 {
145
+ return x
146
+ }
147
+ return x * product( x, count: count - 1 )
148
+ }
149
+ XCTAssertEqual ( gradient ( at: Tensor ( - 10 ) , in: { x in product ( x, count: 2 ) } ) , Tensor ( - 20 ) )
150
+ XCTAssertEqual ( gradient ( at: Tensor ( 10 ) , in: { x in product ( x, count: 3 ) } ) , Tensor ( 300 ) )
151
+ XCTAssertEqual ( gradient ( at: Tensor ( 100 ) , in: { x in product ( x, count: 1 ) } ) , Tensor ( 1 ) )
152
+ }
153
+
44
154
func testScalarGenericGrad( ) {
45
155
// Tests TF-287.
46
156
func negate< T : TensorFlowFloatingPoint > ( _ x: Tensor < T > ) -> Tensor < T > {
@@ -567,6 +677,9 @@ final class TensorAutoDiffTests: XCTestCase {
567
677
static var allTests = [
568
678
( " testSimpleGrad " , testSimpleGrad) ,
569
679
( " testGenericGrad " , testGenericGrad) ,
680
+ ( " testConditionals " , testConditionals) ,
681
+ ( " testNestedConditionals " , testNestedConditionals) ,
682
+ ( " testRecursion " , testRecursion) ,
570
683
( " testScalarGenericGrad " , testScalarGenericGrad) ,
571
684
( " testScalarized " , testScalarized) ,
572
685
( " testScalars " , testScalars) ,
0 commit comments