Skip to content

Commit 6deffd0

Browse files
authored
TF-404: Add .standardDeviation() to Tensor. (#23763)
This change adds `standardDeviation()` to compute the standard deviation of a Tensor. It supports optionally specifying the `alongAxes` parameter, which allows users to reduce only along specified axes.
1 parent bbe43ed commit 6deffd0

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

stdlib/public/TensorFlow/Ops.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,45 @@ public extension Tensor where Scalar : FloatingPoint & Equatable {
577577
}
578578
}
579579

580+
public extension Tensor where Scalar : TensorFlowFloatingPoint {
581+
// TODO: standardDeviation() should handle non floating point Tensors.
582+
583+
/// Returns the standard deviation of the elements along the specified axes.
584+
/// The reduced dimensions are retained with value `1`. Does not apply
585+
/// Bessel's correction.
586+
///
587+
/// - Parameter axes: The dimensions to reduce.
588+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
589+
@differentiable(wrt: self)
590+
func standardDeviation() -> Tensor {
591+
// Reduce along all dimensions.
592+
return standardDeviation(alongAxes: Array(0..<shape.rank))
593+
}
594+
595+
/// Returns the standard deviation of the elements along the specified axes.
596+
/// The reduced dimensions are retained with value `1`. Does not apply
597+
/// Bessel's correction.
598+
///
599+
/// - Parameter axes: The dimensions to reduce.
600+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
601+
@differentiable(wrt: self)
602+
func standardDeviation(alongAxes axes: Int32...) -> Tensor {
603+
return standardDeviation(alongAxes: axes)
604+
}
605+
606+
/// Returns the standard deviation of the elements along the specified axes.
607+
/// The reduced dimensions are retained with value `1`. Does not apply
608+
/// Bessel's correction.
609+
///
610+
/// - Parameter axes: The dimensions to reduce.
611+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
612+
@inlinable @inline(__always)
613+
@differentiable(wrt: self)
614+
func standardDeviation(alongAxes axes: [Int32]) -> Tensor {
615+
return sqrt(variance(alongAxes: axes))
616+
}
617+
}
618+
580619
public extension Tensor where Scalar == Bool {
581620
/// Computes `!self` element-wise.
582621
@inlinable @inline(__always)

test/TensorFlowRuntime/tensor.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,26 @@ TensorTests.testAllBackends("SimpleMath") {
276276
byError: 0.0001)
277277
}
278278

279+
TensorTests.testAllBackends("StandardDeviation") {
280+
expectEqual(0, Tensor<Float>([1]).standardDeviation().scalarized())
281+
expectEqual(
282+
0.5,
283+
Tensor<Float>([0, 1]).standardDeviation(alongAxes: 0).scalarized())
284+
expectEqual(0.5, Tensor<Float>([0, 1]).standardDeviation().scalarized())
285+
expectNearlyEqual(
286+
2.87228132,
287+
Tensor<Float>(rangeFrom: 0, to: 10, stride: 1).standardDeviation().scalarized(),
288+
byError: 0.001)
289+
let matrix = Tensor<Float>(rangeFrom: 0, to: 10, stride: 1).reshaped(to: [2, 5])
290+
expectNearlyEqual(2.87228132,
291+
matrix.standardDeviation().scalarized(),
292+
byError: 0.001)
293+
expectPointwiseNearlyEqual(
294+
[1.4142, 1.4142],
295+
matrix.standardDeviation(alongAxes: 1).array.scalars,
296+
byError: 0.001)
297+
}
298+
279299
TensorTests.testAllBackends("ReductionToScalar") {
280300
let _: Tensor<Float> = [1, 2, 3, 4, 5]
281301
// expectEqual(x.mean(), 3)

0 commit comments

Comments
 (0)