Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8d0b1d8

Browse files
authored
Added a couple missing 'Tensor.moments' functions. (#363)
1 parent 33b3b5b commit 8d0b1d8

File tree

1 file changed

+48
-6
lines changed

1 file changed

+48
-6
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2371,18 +2371,40 @@ public struct Moments<Scalar: TensorFlowFloatingPoint>: Differentiable {
23712371
public extension Tensor where Scalar: TensorFlowFloatingPoint {
23722372
/// Returns the mean and variance of this tensor along the specified axes. The reduced
23732373
/// dimensions are removed.
2374+
///
2375+
/// - Parameter axes: The dimensions to reduce.
2376+
/// - Precondition: `axes` must have rank `1`.
2377+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
23742378
@inlinable
23752379
@differentiable(wrt: self)
2376-
func moments(squeezingAxes axes: [Int]) -> Moments<Scalar> {
2380+
func moments(squeezingAxes axes: Tensor<Int32>) -> Moments<Scalar> {
23772381
let mean = self.mean(alongAxes: axes)
2378-
let variance = squaredDifference(self, mean).mean(alongAxes: axes)
2379-
return Moments<Scalar>(
2380-
mean: mean.squeezingShape(at: axes),
2381-
variance: variance.squeezingShape(at: axes))
2382+
let variance = squaredDifference(self, mean).mean(squeezingAxes: axes)
2383+
return Moments(
2384+
// The following is required because `Tensor.squeezingShape(at:)` does not accept
2385+
// `Tensor<Int32>`-valued arguments.
2386+
mean: mean.sum(squeezingAxes: axes),
2387+
variance: variance)
2388+
}
2389+
2390+
/// Returns the mean and variance of this tensor along the specified axes. The reduced
2391+
/// dimensions are removed.
2392+
///
2393+
/// - Parameter axes: The dimensions to reduce.
2394+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2395+
@inlinable
2396+
@differentiable(wrt: self)
2397+
func moments(squeezingAxes axes: [Int]) -> Moments<Scalar> {
2398+
// TODO(TF-433): Remove workaround for differentiating `map`.
2399+
let axes = {axes.map(Int32.init)}()
2400+
return moments(squeezingAxes: Tensor<Int32>(axes))
23822401
}
23832402

23842403
/// Returns the mean and variance of this tensor along the specified axes. The reduced
23852404
/// dimensions are removed.
2405+
///
2406+
/// - Parameter axes: The dimensions to reduce.
2407+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
23862408
@inlinable
23872409
@differentiable(wrt: self)
23882410
func moments(squeezingAxes axes: Int...) -> Moments<Scalar> {
@@ -2398,16 +2420,36 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
23982420

23992421
/// Returns the mean and variance of this tensor along the specified axes. The reduced
24002422
/// dimensions are retained with value `1`.
2423+
///
2424+
/// - Parameter axes: The dimensions to reduce.
2425+
/// - Precondition: `axes` must have rank `1`.
2426+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
24012427
@inlinable
24022428
@differentiable(wrt: self)
2403-
func moments(alongAxes axes: [Int]) -> Moments<Scalar> {
2429+
func moments(alongAxes axes: Tensor<Int32>) -> Moments<Scalar> {
24042430
let mean = self.mean(alongAxes: axes)
24052431
let variance = squaredDifference(self, mean).mean(alongAxes: axes)
24062432
return Moments<Scalar>(mean: mean, variance: variance)
24072433
}
24082434

24092435
/// Returns the mean and variance of this tensor along the specified axes. The reduced
24102436
/// dimensions are retained with value `1`.
2437+
///
2438+
/// - Parameter axes: The dimensions to reduce.
2439+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2440+
@inlinable
2441+
@differentiable(wrt: self)
2442+
func moments(alongAxes axes: [Int]) -> Moments<Scalar> {
2443+
// TODO(TF-433): Remove workaround for differentiating `map`.
2444+
let axes = {axes.map(Int32.init)}()
2445+
return moments(alongAxes: Tensor<Int32>(axes))
2446+
}
2447+
2448+
/// Returns the mean and variance of this tensor along the specified axes. The reduced
2449+
/// dimensions are retained with value `1`.
2450+
///
2451+
/// - Parameter axes: The dimensions to reduce.
2452+
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
24112453
@inlinable
24122454
@differentiable(wrt: self)
24132455
func moments(alongAxes axes: Int...) -> Moments<Scalar> {

0 commit comments

Comments
 (0)