
Commit b7f2a06

jon-tow authored and dan-zheng committed

Update min(_:_:) and max(_:_:) gradients to match Python TensorFlow (#480)

1 parent c7595c4

4 files changed, +103 −29 lines

Sources/TensorFlow/Loss.swift (2 additions, 4 deletions)

@@ -287,8 +287,6 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
 ) -> Tensor<Scalar> {
     // This numerically stable implementation is based on the TensorFlow Python API.
     let maxLogitsWithZero = max(logits, Tensor(0))
-    // Note: `result` is split into two lines to avoid the "compiler is unable to type-check this
-    // expression in reasonable time" error.
-    let result = log(1 + exp(-abs(logits)))
-    return reduction(maxLogitsWithZero - logits * labels + result)
+    let negAbsLogits = max(logits, -logits) // Custom `abs` to compute gradients at `0`.
+    return reduction(maxLogitsWithZero - logits * labels + log1p(exp(-negAbsLogits)))
 }
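
For intuition, the expression above is the standard numerically stable form of sigmoid cross entropy, max(x, 0) - x*z + log1p(exp(-|x|)): it never exponentiates a large positive number, so it cannot overflow, and log1p preserves precision when exp(-|x|) is tiny. A minimal scalar sketch in plain Swift (illustrative helper name, not the library's Tensor API):

import Foundation

// Scalar sketch of the stable formula used in the diff above;
// `x` is a logit and `z` a label in [0, 1].
func stableSigmoidCrossEntropy(logit x: Double, label z: Double) -> Double {
    return max(x, 0) - x * z + log1p(exp(-abs(x)))
}

print(stableSigmoidCrossEntropy(logit: 2, label: 1))     // ≈ 0.126928, i.e. -log(sigmoid(2))
print(stableSigmoidCrossEntropy(logit: -100, label: 0))  // ≈ 0, with no overflow from exp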

Sources/TensorFlow/Operators/Math.swift (11 additions, 6 deletions)

@@ -1344,7 +1344,9 @@ internal func _vjpMax<T: TensorFlowFloatingPoint>(
     _ y: Tensor<T>
 ) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
     let value = max(x, y)
-    return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, seed: v) })
+    return (value, { v in
+        _vjpMinMaxHelper(x, y, originalValue: value, seed: v, comparisonOperation: .>=)
+    })
 }

 /// Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.

@@ -1375,7 +1377,9 @@ internal func _vjpMin<T: TensorFlowFloatingPoint>(
     _ y: Tensor<T>
 ) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
     let value = min(x, y)
-    return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, seed: v) })
+    return (value, { v in
+        _vjpMinMaxHelper(x, y, originalValue: value, seed: v, comparisonOperation: .<=)
+    })
 }

 /// Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.

@@ -1397,11 +1401,12 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
     _ x: Tensor<T>,
     _ y: Tensor<T>,
     originalValue: Tensor<T>,
-    seed: Tensor<T>
+    seed: Tensor<T>,
+    comparisonOperation: (Tensor<T>, Tensor<T>) -> Tensor<Bool>
 ) -> (Tensor<T>, Tensor<T>) {
-    let denominator = 1 + Tensor<T>(x .== y)
-    let lhsGrad = seed * Tensor<T>(x .== originalValue) / denominator
-    let rhsGrad = seed * Tensor<T>(y .== originalValue) / denominator
+    let mask = Tensor<T>(comparisonOperation(x, y))
+    let lhsGrad = seed * mask
+    let rhsGrad = seed * (1 - mask)
     let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
     let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
     return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
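
The behavioral change is easiest to see at ties. The old helper split the incoming gradient evenly between x and y wherever x .== y; the new mask routes it entirely to the argument that satisfies the comparison (.>= for max, .<= for min), matching the gradients of tf.maximum and tf.minimum in Python TensorFlow. A standalone sketch of the max case on plain Swift arrays (illustrative names, not the Tensor API, and omitting the broadcasting step):

// `seed` is the incoming gradient; the mask sends each element's gradient
// to x where x >= y (including ties) and to y otherwise.
func vjpMaxSketch(_ x: [Float], _ y: [Float], seed: [Float]) -> ([Float], [Float]) {
    let mask: [Float] = zip(x, y).map { $0 >= $1 ? 1 : 0 }
    let lhsGrad = zip(seed, mask).map(*)
    let rhsGrad = zip(seed, mask).map { $0 * (1 - $1) }
    return (lhsGrad, rhsGrad)
}

// At the tie (4, 4), the whole gradient flows to the first argument;
// the old implementation would have given 0.5 to each.
print(vjpMaxSketch([4, 5, 3], [4, 2, 6], seed: [1, 1, 1]))
// ([1.0, 1.0, 0.0], [0.0, 0.0, 1.0])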

Tests/TensorFlowTests/LossTests.swift (17 additions, 19 deletions)

@@ -196,27 +196,25 @@ final class LossTests: XCTestCase {
         assertEqual(loss, Tensor(expectedLoss), accuracy: 1e-6)
     }

-    func testSigmoidCrossEntropyGrad() {
-        let logits = Tensor<Float>(
-            shape: [2, 4],
-            scalars: [-100, -2, -2, 0, 2, 2, 2, 100])
-
-        let labels = Tensor<Float>(
-            shape: [2, 4],
-            scalars: [0, 0, 1, 0, 0, 1, 0.5, 1])
+    func testSigmoidCrossEntropyGradient() {
+        let logits = Tensor<Float>(shape: [2, 4], scalars: [-100, -2, -2, 0, 0, 2, 2, 100])
+        let labels = Tensor<Float>(shape: [2, 4], scalars: [0, 0, 1, 0, 1, 1, 0.5, 1])

-        // For each element x in logits and y in labels, the gradient is sigmoid(x) - y.
-        let expectedGradientsBeforeMean = Tensor<Float>(
-            shape: [2, 4],
-            scalars: [0.00, 0.11920291, -0.8807971, 0.5,
-                      0.8807971, -0.11920291, 0.3807971, 0.0])
-
-        // As the loss is mean loss, we should scale the golden gradient numbers.
-        let expectedGradients = expectedGradientsBeforeMean / Float(logits.scalars.count)
-        let gradients = gradient(
+        let computedGradient = gradient(
             at: logits,
             in: { sigmoidCrossEntropy(logits: $0, labels: labels) })
-        assertEqual(gradients, expectedGradients, accuracy: 1e-6)
+        // The expected value of the gradient was computed using Python TensorFlow 1.14 with
+        // the following code:
+        // ```
+        // with tf.GradientTape() as t:
+        //     t.watch([logits])
+        //     y = tf.losses.sigmoid_cross_entropy(labels, logits, reduction="weighted_mean")
+        //     print(t.gradient(y, [logits]))
+        // ```
+        let expectedGradient = Tensor<Float>([
+            [0.0, 0.01490036, -0.11009964, 0.0625],
+            [-0.0625, -0.01490036, 0.04759964, 0.0]])
+        assertEqual(computedGradient, expectedGradient, accuracy: 1e-6)
     }

     static var allTests = [

@@ -238,6 +236,6 @@ final class LossTests: XCTestCase {
         ("testSoftmaxCrossEntropyWithProbabilitiesGrad",
          testSoftmaxCrossEntropyWithProbabilitiesGrad),
         ("testSigmoidCrossEntropyLoss", testSigmoidCrossEntropyLoss),
-        ("testSigmoidCrossEntropyGrad", testSigmoidCrossEntropyGrad),
+        ("testSigmoidCrossEntropyGradient", testSigmoidCrossEntropyGradient),
     ]
 }
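
The new expected values agree with the analytic gradient the old test documented: for the mean-reduced loss over N elements, each entry is (sigmoid(x) - y) / N. A quick plain-Swift check (not the Tensor API) reproduces the 2×4 expected tensor above, flattened:

import Foundation

func sigmoid(_ x: Double) -> Double { 1 / (1 + exp(-x)) }

let logits: [Double] = [-100, -2, -2, 0, 0, 2, 2, 100]
let labels: [Double] = [0, 0, 1, 0, 1, 1, 0.5, 1]
let n = Double(logits.count)
// Per-element gradient of the mean-reduced sigmoid cross entropy.
let grads = zip(logits, labels).map { (sigmoid($0) - $1) / n }
print(grads)
// ≈ [0.0, 0.0149, -0.1101, 0.0625, -0.0625, -0.0149, 0.0476, 0.0]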

Tests/TensorFlowTests/TensorAutoDiffTests.swift (73 additions, 0 deletions)

@@ -153,6 +153,77 @@ final class TensorAutoDiffTests: XCTestCase {
         XCTAssertEqual(varianceGradAlongAxes(input), expected)
     }

+    func testMin() {
+        // The expected gradient values were computed using the following TensorFlow 2.0 Beta1
+        // Python code with respective `a` and `b` tensors:
+        // ```
+        // with tf.GradientTape() as t:
+        //     t.watch([a, b])
+        //     y = tf.math.reduce_sum(tf.minimum(a, b))
+        //     print(t.gradient(y, [a, b]))
+        // ```
+        do {
+            let a = Tensor<Float>([4, 5, 3])
+            let b = Tensor<Float>([4, 2, 6])
+            let computedGradient1 = gradient(at: a, b) { a, b in min(a, b).sum() }
+            let expectedGradient1: (Tensor<Float>, Tensor<Float>) = (
+                [1.0, 0.0, 1.0], [0.0, 1.0, 0.0])
+            XCTAssertEqual(computedGradient1.0, expectedGradient1.0)
+            XCTAssertEqual(computedGradient1.1, expectedGradient1.1)
+
+            let computedGradient2 = gradient(at: a, b) { a, b in min(b, a).sum() }
+            let expectedGradient2: (Tensor<Float>, Tensor<Float>) = (
+                [0.0, 0.0, 1.0], [1.0, 1.0, 0.0])
+            XCTAssertEqual(computedGradient2.0, expectedGradient2.0)
+            XCTAssertEqual(computedGradient2.1, expectedGradient2.1)
+        }
+
+        do {
+            let a = Tensor<Float>([[3.0, -2.0], [0.3, 10.0]])
+            let b = Tensor<Float>([9.0, -3.0])
+            let computedGradient = gradient(at: a, b) { a, b in min(a, b).sum() }
+            let expectedGradient: (Tensor<Float>, Tensor<Float>) = (
+                [[1.0, 0.0], [1.0, 0.0]], [0.0, 2.0])
+            XCTAssertEqual(computedGradient.0, expectedGradient.0)
+            XCTAssertEqual(computedGradient.1, expectedGradient.1)
+        }
+    }
+
+    func testMax() {
+        // The expected gradient values were computed using the following TensorFlow 2.0 Beta1
+        // Python code with respective `a` and `b` tensors:
+        // ```
+        // with tf.GradientTape() as t:
+        //     t.watch([a, b])
+        //     y = tf.math.reduce_sum(tf.maximum(a, b))
+        //     print(t.gradient(y, [a, b]))
+        // ```
+        do {
+            let a = Tensor<Float>([4, 5, 3])
+            let b = Tensor<Float>([4, 2, 6])
+            let computedGradient1 = gradient(at: a, b) { a, b in max(a, b).sum() }
+            let expectedGradient1: (Tensor<Float>, Tensor<Float>) = (
+                [1.0, 1.0, 0.0], [0.0, 0.0, 1.0])
+            XCTAssertEqual(computedGradient1.0, expectedGradient1.0)
+            XCTAssertEqual(computedGradient1.1, expectedGradient1.1)
+
+            let computedGradient2 = gradient(at: a, b) { a, b in max(b, a).sum() }
+            let expectedGradient2: (Tensor<Float>, Tensor<Float>) = (
+                [0.0, 1.0, 0.0], [1.0, 0.0, 1.0])
+            XCTAssertEqual(computedGradient2.0, expectedGradient2.0)
+            XCTAssertEqual(computedGradient2.1, expectedGradient2.1)
+        }
+        do {
+            let a = Tensor<Float>([[3.0, -2.0], [0.3, 10.0]])
+            let b = Tensor<Float>([9.0, -3.0])
+            let computedGradient = gradient(at: a, b) { a, b in max(a, b).sum() }
+            let expectedGradient: (Tensor<Float>, Tensor<Float>) = (
+                [[0.0, 1.0], [0.0, 1.0]], [2.0, 0.0])
+            XCTAssertEqual(computedGradient.0, expectedGradient.0)
+            XCTAssertEqual(computedGradient.1, expectedGradient.1)
+        }
+    }
+
     /*TODO:(https://bugs.swift.org/browse/TF-771): Disabling this case as assertions fail.
     func testTensorInitStacking() {
         let a1 = Tensor<Float>([1, 2, 3, 4, 5])

@@ -449,6 +520,8 @@ final class TensorAutoDiffTests: XCTestCase {
         ("testSum", testSum),
         ("testMean", testMean),
         ("testVariance", testVariance),
+        ("testMin", testMin),
+        ("testMax", testMax),
         // TODO(https://bugs.swift.org/browse/TF-771): Disabling the failing test.
         // ("testTensorInitStacking", testTensorInitStacking),
         ("testExpandingShape", testExpandingShape),
