Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Update min(_:_:) and max(_:_:) gradients to match Python TensorFlow #480

Merged
merged 10 commits into from
Aug 26, 2019
6 changes: 2 additions & 4 deletions Sources/TensorFlow/Loss.swift
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,6 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
) -> Tensor<Scalar> {
// This numerically stable implementation is based on the TensorFlow Python API.
let maxLogitsWithZero = max(logits, Tensor(0))
// Note: `result` is split into two lines to avoid the "compiler is unable to type-check this
// expression in reasonable time" error.
let result = log(1 + exp(-abs(logits)))
return reduction(maxLogitsWithZero - logits * labels + result)
let negAbsLogits = max(logits, -logits) // Custom `abs` to compute gradients at `0`.
return reduction(maxLogitsWithZero - logits * labels + log1p(exp(-negAbsLogits)))
}
17 changes: 11 additions & 6 deletions Sources/TensorFlow/Operators/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,9 @@ internal func _vjpMax<T: TensorFlowFloatingPoint>(
_ y: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
let value = max(x, y)
return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, seed: v) })
return (value, { v in
_vjpMinMaxHelper(x, y, originalValue: value, seed: v, comparisonOperation: .>=)
})
}

/// Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.
Expand Down Expand Up @@ -1375,7 +1377,9 @@ internal func _vjpMin<T: TensorFlowFloatingPoint>(
_ y: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
let value = min(x, y)
return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, seed: v) })
return (value, { v in
_vjpMinMaxHelper(x, y, originalValue: value, seed: v, comparisonOperation: .<=)
})
}

/// Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.
Expand All @@ -1397,11 +1401,12 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
_ x: Tensor<T>,
_ y: Tensor<T>,
originalValue: Tensor<T>,
seed: Tensor<T>
seed: Tensor<T>,
comparisonOperation: (Tensor<T>, Tensor<T>) -> Tensor<Bool>
) -> (Tensor<T>, Tensor<T>) {
let denominator = 1 + Tensor<T>(x .== y)
let lhsGrad = seed * Tensor<T>(x .== originalValue) / denominator
let rhsGrad = seed * Tensor<T>(y .== originalValue) / denominator
let mask = Tensor<T>(comparisonOperation(x, y))
let lhsGrad = seed * mask
let rhsGrad = seed * (1 - mask)
let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
Expand Down
36 changes: 17 additions & 19 deletions Tests/TensorFlowTests/LossTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,27 +196,25 @@ final class LossTests: XCTestCase {
assertEqual(loss, Tensor(expectedLoss), accuracy: 1e-6)
}

func testSigmoidCrossEntropyGrad() {
let logits = Tensor<Float>(
shape: [2, 4],
scalars: [-100, -2, -2, 0, 2, 2, 2, 100])

let labels = Tensor<Float>(
shape: [2, 4],
scalars: [0, 0, 1, 0, 0, 1, 0.5, 1])
func testSigmoidCrossEntropyGradient() {
let logits = Tensor<Float>(shape: [2, 4], scalars: [-100, -2, -2, 0, 0, 2, 2, 100])
let labels = Tensor<Float>(shape: [2, 4], scalars: [0, 0, 1, 0, 1, 1, 0.5, 1])

// For each element x in logits and y in labels, the gradient is sigmoid(x) - y.
let expectedGradientsBeforeMean = Tensor<Float>(
shape: [2, 4],
scalars: [0.00, 0.11920291, -0.8807971, 0.5,
0.8807971, -0.11920291, 0.3807971 , 0.0])

// As the loss is mean loss, we should scale the golden gradient numbers.
let expectedGradients = expectedGradientsBeforeMean / Float(logits.scalars.count)
let gradients = gradient(
let computedGradient = gradient(
at: logits,
in: { sigmoidCrossEntropy(logits: $0, labels: labels) })
assertEqual(gradients, expectedGradients, accuracy: 1e-6)
// The expected value of the gradient was computed using Python TensorFlow 1.14 with
// the following code:
// ```
// with tf.GradientTape() as t:
// t.watch([logits])
// y = tf.losses.sigmoid_cross_entropy(labels, logits, reduction="weighted_mean")
// print(t.gradient(y, [logits]))
// ```
let expectedGradient = Tensor<Float>([
[0.0, 0.01490036, -0.11009964, 0.0625],
[-0.0625, -0.01490036, 0.04759964, 0.0]])
assertEqual(computedGradient, expectedGradient, accuracy: 1e-6)
}

static var allTests = [
Expand All @@ -238,6 +236,6 @@ final class LossTests: XCTestCase {
("testSoftmaxCrossEntropyWithProbabilitiesGrad",
testSoftmaxCrossEntropyWithProbabilitiesGrad),
("testSigmoidCrossEntropyLoss", testSigmoidCrossEntropyLoss),
("testSigmoidCrossEntropyGrad", testSigmoidCrossEntropyGrad),
("testSigmoidCrossEntropyGradient", testSigmoidCrossEntropyGradient),
]
}
73 changes: 73 additions & 0 deletions Tests/TensorFlowTests/TensorAutoDiffTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,77 @@ final class TensorAutoDiffTests: XCTestCase {
XCTAssertEqual(varianceGradAlongAxes(input), expected)
}

func testMin() {
// The expected gradient values were computed using the following TensorFlow 2.0 Beta1
// Python code with respective `a` and `b` tensors:
// ```
// with tf.GradientTape() as t:
// t.watch([a, b])
// y = tf.math.reduce_sum(tf.minimum(a, b))
// print(t.gradient(y, [a, b]))
// ```
do {
let a = Tensor<Float>([4, 5, 3])
let b = Tensor<Float>([4, 2, 6])
let computedGradient1 = gradient(at: a, b) { a, b in min(a, b).sum() }
let expectedGradient1: (Tensor<Float>, Tensor<Float>) = (
[1.0, 0.0, 1.0], [0.0, 1.0, 0.0])
XCTAssertEqual(computedGradient1.0, expectedGradient1.0)
XCTAssertEqual(computedGradient1.1, expectedGradient1.1)

let computedGradient2 = gradient(at: a, b) { a, b in min(b, a).sum() }
let expectedGradient2: (Tensor<Float>, Tensor<Float>) = (
[0.0, 0.0, 1.0], [1.0, 1.0, 0.0])
XCTAssertEqual(computedGradient2.0, expectedGradient2.0)
XCTAssertEqual(computedGradient2.1, expectedGradient2.1)
}

do {
let a = Tensor<Float>([[3.0, -2.0], [0.3, 10.0]])
let b = Tensor<Float>([9.0, -3.0])
let computedGradient = gradient(at: a, b) { a, b in min(a, b).sum() }
let expectedGradient: (Tensor<Float>, Tensor<Float>) = (
[[1.0, 0.0], [1.0, 0.0]], [0.0, 2.0])
XCTAssertEqual(computedGradient.0, expectedGradient.0)
XCTAssertEqual(computedGradient.1, expectedGradient.1)
}
}

func testMax() {
// The expected gradient values were computed using the following TensorFlow 2.0 Beta1
// Python code with respective `a` and `b` tensors:
// ```
// with tf.GradientTape() as t:
// t.watch([a, b])
// y = tf.math.reduce_sum(tf.maximum(a, b))
// print(t.gradient(y, [a, b]))
// ```
do {
let a = Tensor<Float>([4, 5, 3])
let b = Tensor<Float>([4, 2, 6])
let computedGradient1 = gradient(at: a, b) { a, b in max(a, b).sum() }
let expectedGradient1: (Tensor<Float>, Tensor<Float>) = (
[1.0, 1.0, 0.0], [0.0, 0.0, 1.0])
XCTAssertEqual(computedGradient1.0, expectedGradient1.0)
XCTAssertEqual(computedGradient1.1, expectedGradient1.1)

let computedGradient2 = gradient(at: a, b) { a, b in max(b, a).sum() }
let expectedGradient2: (Tensor<Float>, Tensor<Float>) = (
[0.0, 1.0, 0.0], [1.0, 0.0, 1.0])
XCTAssertEqual(computedGradient2.0, expectedGradient2.0)
XCTAssertEqual(computedGradient2.1, expectedGradient2.1)
}
do {
let a = Tensor<Float>([[3.0, -2.0], [0.3, 10.0]])
let b = Tensor<Float>([9.0, -3.0])
let computedGradient = gradient(at: a, b) { a, b in max(a, b).sum() }
let expectedGradient: (Tensor<Float>, Tensor<Float>) = (
[[0.0, 1.0], [0.0, 1.0]], [2.0, 0.0])
XCTAssertEqual(computedGradient.0, expectedGradient.0)
XCTAssertEqual(computedGradient.1, expectedGradient.1)
}
}

/*TODO:(https://bugs.swift.org/browse/TF-771): Disabling this case as assertions fail.
func testTensorInitStacking() {
let a1 = Tensor<Float>([1, 2, 3, 4, 5])
Expand Down Expand Up @@ -449,6 +520,8 @@ final class TensorAutoDiffTests: XCTestCase {
("testSum", testSum),
("testMean", testMean),
("testVariance", testVariance),
("testMin", testMin),
("testMax", testMax),
// TODO(https://bugs.swift.org/browse/TF-771): Disabling the failing test.
// ("testTensorInitStacking", testTensorInitStacking),
("testExpandingShape", testExpandingShape),
Expand Down