@@ -153,6 +153,77 @@ final class TensorAutoDiffTests: XCTestCase {
153
153
XCTAssertEqual ( varianceGradAlongAxes ( input) , expected)
154
154
}
155
155
156
+ func testMin( ) {
157
+ // The expected gradient values were computed using the following TensorFlow 2.0 Beta1
158
+ // Python code with respective `a` and `b` tensors:
159
+ // ```
160
+ // with tf.GradientTape() as t:
161
+ // t.watch([a, b])
162
+ // y = tf.math.reduce_sum(tf.minimum(a, b))
163
+ // print(t.gradient(y, [a, b]))
164
+ // ```
165
+ do {
166
+ let a = Tensor < Float > ( [ 4 , 5 , 3 ] )
167
+ let b = Tensor < Float > ( [ 4 , 2 , 6 ] )
168
+ let computedGradient1 = gradient ( at: a, b) { a, b in min ( a, b) . sum ( ) }
169
+ let expectedGradient1 : ( Tensor < Float > , Tensor < Float > ) = (
170
+ [ 1.0 , 0.0 , 1.0 ] , [ 0.0 , 1.0 , 0.0 ] )
171
+ XCTAssertEqual ( computedGradient1. 0 , expectedGradient1. 0 )
172
+ XCTAssertEqual ( computedGradient1. 1 , expectedGradient1. 1 )
173
+
174
+ let computedGradient2 = gradient ( at: a, b) { a, b in min ( b, a) . sum ( ) }
175
+ let expectedGradient2 : ( Tensor < Float > , Tensor < Float > ) = (
176
+ [ 0.0 , 0.0 , 1.0 ] , [ 1.0 , 1.0 , 0.0 ] )
177
+ XCTAssertEqual ( computedGradient2. 0 , expectedGradient2. 0 )
178
+ XCTAssertEqual ( computedGradient2. 1 , expectedGradient2. 1 )
179
+ }
180
+
181
+ do {
182
+ let a = Tensor < Float > ( [ [ 3.0 , - 2.0 ] , [ 0.3 , 10.0 ] ] )
183
+ let b = Tensor < Float > ( [ 9.0 , - 3.0 ] )
184
+ let computedGradient = gradient ( at: a, b) { a, b in min ( a, b) . sum ( ) }
185
+ let expectedGradient : ( Tensor < Float > , Tensor < Float > ) = (
186
+ [ [ 1.0 , 0.0 ] , [ 1.0 , 0.0 ] ] , [ 0.0 , 2.0 ] )
187
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient. 0 )
188
+ XCTAssertEqual ( computedGradient. 1 , expectedGradient. 1 )
189
+ }
190
+ }
191
+
192
+ func testMax( ) {
193
+ // The expected gradient values were computed using the following TensorFlow 2.0 Beta1
194
+ // Python code with respective `a` and `b` tensors:
195
+ // ```
196
+ // with tf.GradientTape() as t:
197
+ // t.watch([a, b])
198
+ // y = tf.math.reduce_sum(tf.maximum(a, b))
199
+ // print(t.gradient(y, [a, b]))
200
+ // ```
201
+ do {
202
+ let a = Tensor < Float > ( [ 4 , 5 , 3 ] )
203
+ let b = Tensor < Float > ( [ 4 , 2 , 6 ] )
204
+ let computedGradient1 = gradient ( at: a, b) { a, b in max ( a, b) . sum ( ) }
205
+ let expectedGradient1 : ( Tensor < Float > , Tensor < Float > ) = (
206
+ [ 1.0 , 1.0 , 0.0 ] , [ 0.0 , 0.0 , 1.0 ] )
207
+ XCTAssertEqual ( computedGradient1. 0 , expectedGradient1. 0 )
208
+ XCTAssertEqual ( computedGradient1. 1 , expectedGradient1. 1 )
209
+
210
+ let computedGradient2 = gradient ( at: a, b) { a, b in max ( b, a) . sum ( ) }
211
+ let expectedGradient2 : ( Tensor < Float > , Tensor < Float > ) = (
212
+ [ 0.0 , 1.0 , 0.0 ] , [ 1.0 , 0.0 , 1.0 ] )
213
+ XCTAssertEqual ( computedGradient2. 0 , expectedGradient2. 0 )
214
+ XCTAssertEqual ( computedGradient2. 1 , expectedGradient2. 1 )
215
+ }
216
+ do {
217
+ let a = Tensor < Float > ( [ [ 3.0 , - 2.0 ] , [ 0.3 , 10.0 ] ] )
218
+ let b = Tensor < Float > ( [ 9.0 , - 3.0 ] )
219
+ let computedGradient = gradient ( at: a, b) { a, b in max ( a, b) . sum ( ) }
220
+ let expectedGradient : ( Tensor < Float > , Tensor < Float > ) = (
221
+ [ [ 0.0 , 1.0 ] , [ 0.0 , 1.0 ] ] , [ 2.0 , 0.0 ] )
222
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient. 0 )
223
+ XCTAssertEqual ( computedGradient. 1 , expectedGradient. 1 )
224
+ }
225
+ }
226
+
156
227
/*TODO:(https://bugs.swift.org/browse/TF-771): Disabling this case as assertions fail.
157
228
func testTensorInitStacking() {
158
229
let a1 = Tensor<Float>([1, 2, 3, 4, 5])
@@ -449,6 +520,8 @@ final class TensorAutoDiffTests: XCTestCase {
449
520
( " testSum " , testSum) ,
450
521
( " testMean " , testMean) ,
451
522
( " testVariance " , testVariance) ,
523
+ ( " testMin " , testMin) ,
524
+ ( " testMax " , testMax) ,
452
525
// TODO(https://bugs.swift.org/browse/TF-771): Disabling the failing test.
453
526
// ("testTensorInitStacking", testTensorInitStacking),
454
527
( " testExpandingShape " , testExpandingShape) ,
0 commit comments