@@ -2371,18 +2371,40 @@ public struct Moments<Scalar: TensorFlowFloatingPoint>: Differentiable {
2371
2371
public extension Tensor where Scalar: TensorFlowFloatingPoint {
2372
2372
/// Returns the mean and variance of this tensor along the specified axes. The reduced
2373
2373
/// dimensions are removed.
2374
+ ///
2375
+ /// - Parameter axes: The dimensions to reduce.
2376
+ /// - Precondition: `axes` must have rank `1`.
2377
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2374
2378
@inlinable
2375
2379
@differentiable ( wrt: self )
2376
- func moments( squeezingAxes axes: [ Int ] ) -> Moments < Scalar > {
2380
+ func moments( squeezingAxes axes: Tensor < Int32 > ) -> Moments < Scalar > {
2377
2381
let mean = self . mean ( alongAxes: axes)
2378
- let variance = squaredDifference ( self , mean) . mean ( alongAxes: axes)
2379
- return Moments < Scalar > (
2380
- mean: mean. squeezingShape ( at: axes) ,
2381
- variance: variance. squeezingShape ( at: axes) )
2382
+ let variance = squaredDifference ( self , mean) . mean ( squeezingAxes: axes)
2383
+ return Moments (
2384
+ // The following is required because `Tensor.squeezingShape(at:)` does not accept
2385
+ // `Tensor<Int32>`-valued arguments.
2386
+ mean: mean. sum ( squeezingAxes: axes) ,
2387
+ variance: variance)
2388
+ }
2389
+
2390
+ /// Returns the mean and variance of this tensor along the specified axes. The reduced
2391
+ /// dimensions are removed.
2392
+ ///
2393
+ /// - Parameter axes: The dimensions to reduce.
2394
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2395
+ @inlinable
2396
+ @differentiable ( wrt: self )
2397
+ func moments( squeezingAxes axes: [ Int ] ) -> Moments < Scalar > {
2398
+ // TODO(TF-433): Remove workaround for differentiating `map`.
2399
+ let axes = { axes. map ( Int32 . init) } ( )
2400
+ return moments ( squeezingAxes: Tensor < Int32 > ( axes) )
2382
2401
}
2383
2402
2384
2403
/// Returns the mean and variance of this tensor along the specified axes. The reduced
2385
2404
/// dimensions are removed.
2405
+ ///
2406
+ /// - Parameter axes: The dimensions to reduce.
2407
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2386
2408
@inlinable
2387
2409
@differentiable ( wrt: self )
2388
2410
func moments( squeezingAxes axes: Int ... ) -> Moments < Scalar > {
@@ -2398,16 +2420,36 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
2398
2420
2399
2421
/// Returns the mean and variance of this tensor along the specified axes. The reduced
2400
2422
/// dimensions are retained with value `1`.
2423
+ ///
2424
+ /// - Parameter axes: The dimensions to reduce.
2425
+ /// - Precondition: `axes` must have rank `1`.
2426
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2401
2427
@inlinable
2402
2428
@differentiable ( wrt: self )
2403
- func moments( alongAxes axes: [ Int ] ) -> Moments < Scalar > {
2429
+ func moments( alongAxes axes: Tensor < Int32 > ) -> Moments < Scalar > {
2404
2430
let mean = self . mean ( alongAxes: axes)
2405
2431
let variance = squaredDifference ( self , mean) . mean ( alongAxes: axes)
2406
2432
return Moments < Scalar > ( mean: mean, variance: variance)
2407
2433
}
2408
2434
2409
2435
/// Returns the mean and variance of this tensor along the specified axes. The reduced
2410
2436
/// dimensions are retained with value `1`.
2437
+ ///
2438
+ /// - Parameter axes: The dimensions to reduce.
2439
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2440
+ @inlinable
2441
+ @differentiable ( wrt: self )
2442
+ func moments( alongAxes axes: [ Int ] ) -> Moments < Scalar > {
2443
+ // TODO(TF-433): Remove workaround for differentiating `map`.
2444
+ let axes = { axes. map ( Int32 . init) } ( )
2445
+ return moments ( alongAxes: Tensor < Int32 > ( axes) )
2446
+ }
2447
+
2448
+ /// Returns the mean and variance of this tensor along the specified axes. The reduced
2449
+ /// dimensions are retained with value `1`.
2450
+ ///
2451
+ /// - Parameter axes: The dimensions to reduce.
2452
+ /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2411
2453
@inlinable
2412
2454
@differentiable ( wrt: self )
2413
2455
func moments( alongAxes axes: Int ... ) -> Moments < Scalar > {
0 commit comments