Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c1b18f7

Browse files
eaplataniosrxwei
authored andcommitted
Minor tweak to allow min and max to be differentiable. (#159)
1 parent 0e3b3f4 commit c1b18f7

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
930930
// NOTE: This overload is necessary, otherwise `min()` would refer to the variadic method
931931
// `min(squeezingAxes:)` with zero indices.
932932
@inlinable
933+
@differentiable(where Scalar: TensorFlowFloatingPoint)
933934
func min() -> Tensor {
934935
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(rank), stride: 1)
935936
return min(squeezingAxes: axes)
@@ -938,6 +939,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
938939
// NOTE: This overload is necessary, otherwise `max()` would refer to the variadic method
939940
// `max(squeezingAxes:)` with zero indices.
940941
@inlinable
942+
@differentiable(where Scalar: TensorFlowFloatingPoint)
941943
func max() -> Tensor {
942944
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(rank), stride: 1)
943945
return max(squeezingAxes: axes)
@@ -958,15 +960,18 @@ public extension Tensor where Scalar: Numeric & Comparable {
958960
/// - Parameter axes: The dimensions to reduce.
959961
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
960962
@inlinable
963+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
961964
func max(squeezingAxes axes: [Int]) -> Tensor {
962-
let axes = axes.map(Int32.init)
965+
// TODO(TF-433): Remove workaround for differentiating `map`.
966+
let axes = {axes.map(Int32.init)}()
963967
return max(squeezingAxes: Tensor<Int32>(axes))
964968
}
965969

966970
/// Returns the maximum values along the specified axes. The reduced dimensions are removed.
967971
/// - Parameter axes: The dimensions to reduce.
968972
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
969973
@inlinable
974+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
970975
func max(squeezingAxes axes: Int...) -> Tensor {
971976
return max(squeezingAxes: axes)
972977
}
@@ -986,15 +991,18 @@ public extension Tensor where Scalar: Numeric & Comparable {
986991
/// - Parameter axes: The dimensions to reduce.
987992
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
988993
@inlinable
994+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
989995
func min(squeezingAxes axes: [Int]) -> Tensor {
990-
let axes = axes.map(Int32.init)
996+
// TODO(TF-433): Remove workaround for differentiating `map`.
997+
let axes = {axes.map(Int32.init)}()
991998
return min(squeezingAxes: Tensor<Int32>(axes))
992999
}
9931000

9941001
/// Returns the minimum values along the specified axes. The reduced dimensions are removed.
9951002
/// - Parameter axes: The dimensions to reduce.
9961003
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
9971004
@inlinable
1005+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
9981006
func min(squeezingAxes axes: Int...) -> Tensor {
9991007
return min(squeezingAxes: axes)
10001008
}

0 commit comments

Comments
 (0)