Skip to content

Commit 7cae13b

Browse files
authored
Add mean and variance ops that take a Tensor<Int32> for axes. (#23864)
1 parent b71fb98 commit 7cae13b

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -581,11 +581,11 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
581581
}
582582

583583
@inlinable
584-
func _vjpMean(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
584+
func _vjpMean(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
585585
let value = mean(alongAxes: axes)
586-
return (value, { [shape = shapeTensor,
587-
count = axes.map { shape[$0] }.reduce(1, *)] in
588-
$0.broadcast(toShape: shape) / Tensor(Scalar(count))
586+
let count = Raw.gather(params: shapeTensor, indices: axes).product()
587+
return (value, { [shape = shapeTensor] in
588+
$0.broadcast(toShape: shape) / Tensor(count)
589589
})
590590
}
591591

stdlib/public/TensorFlow/Ops.swift

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,8 +1348,18 @@ public extension Tensor where Scalar : Numeric {
13481348
wrt: self, vjp: _vjpMean(alongAxes:)
13491349
where Scalar : TensorFlowFloatingPoint
13501350
)
1351+
func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
1352+
return Raw.mean(self, reductionIndices: axes, keepDims: true)
1353+
}
1354+
1355+
/// Returns the arithmetic mean along the specified axes. The reduced
1356+
/// dimensions are retained with value 1.
1357+
/// - Parameter axes: The dimensions to reduce.
1358+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1359+
@inlinable @inline(__always)
1360+
@differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
13511361
func mean(alongAxes axes: [Int32]) -> Tensor {
1352-
return Raw.mean(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
1362+
return mean(alongAxes: Tensor<Int32>(axes))
13531363
}
13541364

13551365
/// Returns the arithmetic mean along the specified axes. The reduced
@@ -1401,12 +1411,22 @@ public extension Tensor where Scalar : Numeric {
14011411
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
14021412
@inlinable @inline(__always)
14031413
@differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
1404-
func variance(alongAxes axes: [Int32]) -> Tensor {
1414+
func variance(alongAxes axes: Tensor<Int32>) -> Tensor {
14051415
let mean = self.mean(alongAxes: axes)
14061416
let squaredDiff = (self - mean).squared()
14071417
return squaredDiff.mean(alongAxes: axes)
14081418
}
14091419

1420+
/// Returns the variance along the specified axes. The reduced dimensions are
1421+
/// retained with value 1. Does not apply Bessel's correction.
1422+
/// - Parameter axes: The dimensions to reduce.
1423+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1424+
@inlinable @inline(__always)
1425+
@differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
1426+
func variance(alongAxes axes: [Int32]) -> Tensor {
1427+
return variance(alongAxes: Tensor<Int32>(axes))
1428+
}
1429+
14101430
/// Returns the product along the specified axes. The reduced dimensions are
14111431
/// retained with value 1.
14121432
/// - Parameter axes: The dimensions to reduce.

0 commit comments

Comments
 (0)