Skip to content

Commit 57e295b

Browse files
committed
[TF API] Mark some more reduction methods @differentiable where
1 parent 6cc95e3 commit 57e295b

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

stdlib/public/TensorFlow/Ops.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,7 @@ public extension Tensor where Scalar : Numeric {
12941294
/// - Parameter axes: The dimensions to reduce.
12951295
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
12961296
@inlinable @inline(__always)
1297+
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
12971298
func mean(alongAxes axes: Int32...) -> Tensor {
12981299
return mean(alongAxes: axes)
12991300
}
@@ -1316,6 +1317,7 @@ public extension Tensor where Scalar : Numeric {
13161317
/// - Parameter axes: The dimensions to reduce.
13171318
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
13181319
@inlinable @inline(__always)
1320+
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
13191321
func sum(alongAxes axes: Int32...) -> Tensor {
13201322
return sum(alongAxes: axes)
13211323
}
@@ -1325,6 +1327,7 @@ public extension Tensor where Scalar : Numeric {
13251327
/// - Parameter axes: The dimensions to reduce.
13261328
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
13271329
@inlinable @inline(__always)
1330+
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
13281331
func variance(alongAxes axes: Int32...) -> Tensor {
13291332
return variance(alongAxes: axes)
13301333
}
@@ -1334,6 +1337,7 @@ public extension Tensor where Scalar : Numeric {
13341337
/// - Parameter axes: The dimensions to reduce.
13351338
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
13361339
@inlinable @inline(__always)
1340+
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
13371341
func variance(alongAxes axes: [Int32]) -> Tensor {
13381342
let mean = self.mean(alongAxes: axes)
13391343
let squaredDiff = (self - mean).squared()

0 commit comments

Comments
 (0)