Commit ec50086

Add variance() and variance(squeezingAxes:). (#23811)
* Add `variance()` and `variance(squeezingAxes:)`.
* Reorganize reduction ops: `sum`, `product`, `mean`, `variance`.
1 parent: 0f9f819

File tree

4 files changed (+170, -68 lines)

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 29 additions & 5 deletions
@@ -568,6 +568,11 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
 //===----------------------------------------------------------------------===//
 
 extension Tensor where Scalar : TensorFlowFloatingPoint {
+  @inlinable
+  func _vjpSum() -> (Tensor, (Tensor) -> Tensor) {
+    return (sum(), { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
+  }
+
   @inlinable
   func _vjpMean() -> (Tensor, (Tensor) -> Tensor) {
     return (mean(), { [shape = shapeTensor, count = scalarCountTensor] in
@@ -576,8 +581,15 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
   }
 
   @inlinable
-  func _vjpSum() -> (Tensor, (Tensor) -> Tensor) {
-    return (sum(), { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
+  func _vjpSum(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+    let value = sum(alongAxes: axes)
+    return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
+  }
+
+  @inlinable
+  func _vjpSum(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+    let value = sum(squeezingAxes: axes)
+    return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
   }
 
   @inlinable
@@ -590,9 +602,21 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
   }
 
   @inlinable
-  func _vjpSum(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
-    let value = sum(alongAxes: axes)
-    return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
+  func _vjpMean(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+    let value = mean(squeezingAxes: axes)
+    return (value, { [shape = shapeTensor,
+                      count = axes.map { shape[$0] }.reduce(1, *)] in
+      $0.broadcast(toShape: shape) / Tensor(Scalar(count))
+    })
+  }
+
+  @inlinable
+  func _vjpMean(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+    let value = mean(alongAxes: axes)
+    return (value, { [shape = shapeTensor,
+                      count = axes.map { shape[$0] }.reduce(1, *)] in
+      $0.broadcast(toShape: shape) / Tensor(Scalar(count))
+    })
   }
 }
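
Note on the pullbacks: each `_vjpSum` pullback simply broadcasts the incoming gradient back to the input's shape, and each `_vjpMean` pullback additionally divides by the number of reduced elements. A quick sanity check in the style of the tests below (a sketch, assuming a Swift for TensorFlow toolchain of this vintage):

let input: Tensor<Float> = [[1, 2], [3, 4]]
let sumGrad = gradient { (a: Tensor<Float>) in a.sum() }
let meanGrad = gradient { (a: Tensor<Float>) in a.mean() }
let expectedSumGrad: Tensor<Float> = [[1, 1], [1, 1]]
let expectedMeanGrad: Tensor<Float> = [[0.25, 0.25], [0.25, 0.25]]
expectEqual(expectedSumGrad, sumGrad(input))    // broadcast of the upstream 1
expectEqual(expectedMeanGrad, meanGrad(input))  // also divided by scalarCount (4)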

stdlib/public/TensorFlow/Ops.swift

Lines changed: 98 additions & 55 deletions
@@ -1251,18 +1251,6 @@ public extension Tensor where Scalar : Numeric & Comparable {
 }
 
 public extension Tensor where Scalar : Numeric {
-  // NOTE: This overload is necessary, otherwise `mean()` would refer
-  // to the variadic method `mean(squeezingAxes:)` with zero indices.
-  @differentiable(
-    wrt: self, vjp: _vjpMean()
-    where Scalar : TensorFlowFloatingPoint
-  )
-  @inlinable @inline(__always)
-  func mean() -> Tensor {
-    let axes = Tensor<Int32>(rangeFrom: 0, to: rank, stride: 1)
-    return Raw.mean(self, reductionIndices: axes)
-  }
-
   // NOTE: This overload is necessary, otherwise `sum()` would refer
   // to the variadic method `sum(squeezingAxes:)` with zero indices.
   @inlinable @inline(__always)
@@ -1283,30 +1271,37 @@ public extension Tensor where Scalar : Numeric {
     return Raw.prod(self, reductionIndices: axes)
   }
 
-  /// Returns the arithmetic mean along the specified axes. The reduced
-  /// dimensions are removed.
-  /// - Parameter axes: The dimensions to reduce.
-  /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
+  // NOTE: This overload is necessary, otherwise `mean()` would refer
+  // to the variadic method `mean(squeezingAxes:)` with zero indices.
+  @differentiable(
+    wrt: self, vjp: _vjpMean()
+    where Scalar : TensorFlowFloatingPoint
+  )
   @inlinable @inline(__always)
-  func mean(squeezingAxes axes: [Int32]) -> Tensor {
-    return Raw.mean(self, reductionIndices: Tensor<Int32>(axes),
-                    keepDims: false)
+  func mean() -> Tensor {
+    let axes = Tensor<Int32>(rangeFrom: 0, to: rank, stride: 1)
+    return Raw.mean(self, reductionIndices: axes)
   }
 
-  /// Returns the arithmetic mean along the specified axes. The reduced
-  /// dimensions are removed.
-  /// - Parameter axes: The dimensions to reduce.
-  /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
+  // NOTE: This overload is necessary, otherwise `mean()` would refer
+  // to the variadic method `mean(squeezingAxes:)` with zero indices.
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   @inlinable @inline(__always)
-  func mean(squeezingAxes axes: Int32...) -> Tensor {
-    return mean(squeezingAxes: axes)
+  func variance() -> Tensor {
+    let mean = self.mean()
+    let squaredDiff = (self - mean).squared()
+    return squaredDiff.mean()
   }
 
   /// Returns the sum along the specified axes. The reduced dimensions are
   /// removed.
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
   @inlinable @inline(__always)
+  @differentiable(
+    wrt: self, vjp: _vjpSum(squeezingAxes:)
+    where Scalar : TensorFlowFloatingPoint
+  )
   func sum(squeezingAxes axes: [Int32]) -> Tensor {
     return Raw.sum(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
   }
@@ -1340,36 +1335,48 @@ public extension Tensor where Scalar : Numeric {
   }
 
   /// Returns the arithmetic mean along the specified axes. The reduced
-  /// dimensions are retained with value 1.
+  /// dimensions are removed.
   /// - Parameter axes: The dimensions to reduce.
-  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
   @inlinable @inline(__always)
   @differentiable(
-    wrt: self, vjp: _vjpMean(alongAxes:)
+    wrt: self, vjp: _vjpMean(squeezingAxes:)
     where Scalar : TensorFlowFloatingPoint
   )
-  func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
-    return Raw.mean(self, reductionIndices: axes, keepDims: true)
+  func mean(squeezingAxes axes: [Int32]) -> Tensor {
+    return Raw.mean(self, reductionIndices: Tensor<Int32>(axes),
+                    keepDims: false)
   }
 
   /// Returns the arithmetic mean along the specified axes. The reduced
-  /// dimensions are retained with value 1.
+  /// dimensions are removed.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
+  @inlinable @inline(__always)
+  func mean(squeezingAxes axes: Int32...) -> Tensor {
+    return mean(squeezingAxes: axes)
+  }
+
+  /// Returns the variance along the specified axes. The reduced dimensions are
+  /// retained with value 1. Does not apply Bessel's correction.
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
   @inlinable @inline(__always)
   @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
-  func mean(alongAxes axes: [Int32]) -> Tensor {
-    return mean(alongAxes: Tensor<Int32>(axes))
+  func variance(squeezingAxes axes: Int32...) -> Tensor {
+    return variance(squeezingAxes: axes)
   }
 
-  /// Returns the arithmetic mean along the specified axes. The reduced
-  /// dimensions are retained with value 1.
+  /// Returns the variance along the specified axes. The reduced dimensions are
+  /// removed. Does not apply Bessel's correction.
   /// - Parameter axes: The dimensions to reduce.
   /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
   @inlinable @inline(__always)
   @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
-  func mean(alongAxes axes: Int32...) -> Tensor {
-    return mean(alongAxes: axes)
+  func variance(squeezingAxes axes: [Int32]) -> Tensor {
+    let mean = self.mean(alongAxes: axes)
+    let squaredDiff = (self - mean).squared()
+    return squaredDiff.mean(squeezingAxes: axes)
   }
 
   /// Returns the sum along the specified axes. The reduced dimensions are
@@ -1395,6 +1402,60 @@ public extension Tensor where Scalar : Numeric {
     return sum(alongAxes: axes)
   }
 
+  /// Returns the product along the specified axes. The reduced dimensions are
+  /// retained with value 1.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  func product(alongAxes axes: [Int32]) -> Tensor {
+    return Raw.prod(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
+  }
+
+  /// Returns the product along the specified axes. The reduced dimensions are
+  /// retained with value 1.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  func product(alongAxes axes: Int32...) -> Tensor {
+    return product(alongAxes: axes)
+  }
+
+  /// Returns the arithmetic mean along the specified axes. The reduced
+  /// dimensions are retained with value 1.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  @differentiable(
+    wrt: self, vjp: _vjpMean(alongAxes:)
+    where Scalar : TensorFlowFloatingPoint
+  )
+  func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
+    return Raw.mean(self, reductionIndices: axes, keepDims: true)
+  }
+
+  /// Returns the arithmetic mean along the specified axes. The reduced
+  /// dimensions are retained with value 1.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  @differentiable(
+    wrt: self, vjp: _vjpMean(alongAxes:)
+    where Scalar : TensorFlowFloatingPoint
+  )
+  func mean(alongAxes axes: [Int32]) -> Tensor {
+    return mean(alongAxes: Tensor<Int32>(axes))
+  }
+
+  /// Returns the arithmetic mean along the specified axes. The reduced
+  /// dimensions are retained with value 1.
+  /// - Parameter axes: The dimensions to reduce.
+  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
+  @inlinable @inline(__always)
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
+  func mean(alongAxes axes: Int32...) -> Tensor {
+    return mean(alongAxes: axes)
+  }
+
   /// Returns the variance along the specified axes. The reduced dimensions are
   /// retained with value 1. Does not apply Bessel's correction.
   /// - Parameter axes: The dimensions to reduce.
@@ -1426,24 +1487,6 @@ public extension Tensor where Scalar : Numeric {
   func variance(alongAxes axes: [Int32]) -> Tensor {
     return variance(alongAxes: Tensor<Int32>(axes))
   }
-
-  /// Returns the product along the specified axes. The reduced dimensions are
-  /// retained with value 1.
-  /// - Parameter axes: The dimensions to reduce.
-  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-  @inlinable @inline(__always)
-  func product(alongAxes axes: [Int32]) -> Tensor {
-    return Raw.prod(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
-  }
-
-  /// Returns the product along the specified axes. The reduced dimensions are
-  /// retained with value 1.
-  /// - Parameter axes: The dimensions to reduce.
-  /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-  @inlinable @inline(__always)
-  func product(alongAxes axes: Int32...) -> Tensor {
-    return product(alongAxes: axes)
-  }
 }
 
 //===----------------------------------------------------------------------===//
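
Note: `variance` here is the population variance, mean((x - mean(x))^2); as the doc comments say, Bessel's correction is not applied (division by n rather than n - 1). In `variance(squeezingAxes:)`, the inner `mean(alongAxes:)` keeps the reduced dimensions so that `self - mean` broadcasts correctly, and only the final `mean(squeezingAxes:)` drops them. A usage sketch (assuming a Swift for TensorFlow toolchain of this vintage):

let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
// The mean is 3 and each row's squared deviations are [4, 1, 0, 1, 4],
// so the population variance over all 10 scalars is 20 / 10 = 2.
print(x.variance())                  // 2.0
// Reducing along axis 1 and squeezing it away yields one variance per row.
print(x.variance(squeezingAxes: 1))  // [2.0, 2.0]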

test/TensorFlowRuntime/tensor.swift

Lines changed: 31 additions & 8 deletions
@@ -178,14 +178,37 @@ TensorTests.testAllBackends("Reduction") {
 #if !TPU
   // 2 x 5
   let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
-  expectEqual(ShapedArray(shape: [5], scalars: [2, 4, 6, 8, 10]),
-              x.sum(squeezingAxes: 0).toHost(shape: []).array)
-  expectEqual(ShapedArray(shape: [1, 5], scalars: [2, 4, 6, 8, 10]),
-              x.sum(alongAxes: 0).toHost(shape: []).array)
-  expectEqual(ShapedArray(shape: [5], scalars: [1, 4, 9, 16, 25]),
-              x.product(squeezingAxes: 0).toHost(shape: []).array)
-  expectEqual(ShapedArray(shape: [1, 5], scalars: [1, 4, 9, 16, 25]),
-              x.product(alongAxes: 0).toHost(shape: []).array)
+  expectEqual(Tensor(30), x.sum().toHost(shape: []))
+  expectEqual(Tensor(shape: [5], scalars: [2, 4, 6, 8, 10]),
+              x.sum(squeezingAxes: 0).toHost(shape: []))
+  expectEqual(Tensor(shape: [1, 5], scalars: [2, 4, 6, 8, 10]),
+              x.sum(alongAxes: 0).toHost(shape: []))
+
+  expectEqual(Tensor(14400), x.product().toHost(shape: []))
+  expectEqual(Tensor(shape: [5], scalars: [1, 4, 9, 16, 25]),
+              x.product(squeezingAxes: 0).toHost(shape: []))
+  expectEqual(Tensor(shape: [1, 5], scalars: [1, 4, 9, 16, 25]),
+              x.product(alongAxes: 0).toHost(shape: []))
+
+  expectEqual(Tensor(3), x.mean().toHost(shape: []))
+  expectEqual(Tensor(shape: [5], scalars: [1, 2, 3, 4, 5]),
+              x.mean(squeezingAxes: 0).toHost(shape: []))
+  expectEqual(Tensor(shape: [5], scalars: [1, 2, 3, 4, 5]),
+              x.mean(alongAxes: 0).toHost(shape: []))
+  expectEqual(Tensor(shape: [2], scalars: [3, 3]),
+              x.mean(squeezingAxes: 1).toHost(shape: []))
+  expectEqual(Tensor(shape: [1, 2], scalars: [3, 3]),
+              x.mean(alongAxes: 1).toHost(shape: []))
+
+  expectEqual(Tensor(2), x.variance().toHost(shape: []))
+  expectEqual(Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]),
+              x.variance(squeezingAxes: 0).toHost(shape: []))
+  expectEqual(Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]),
+              x.variance(alongAxes: 0).toHost(shape: []))
+  expectEqual(Tensor(shape: [2], scalars: [2, 2]),
+              x.variance(squeezingAxes: 1).toHost(shape: []))
+  expectEqual(Tensor(shape: [1, 2], scalars: [2, 2]),
+              x.variance(alongAxes: 1).toHost(shape: []))
 #endif // !TPU
 }
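
The expected scalar values can be re-derived by hand from `x` (plain Swift, independent of TensorFlow; a sketch, not part of the test file):

let scalars: [Float] = [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
let total = scalars.reduce(0, +)                    // 30
let product = scalars.reduce(1, *)                  // 120 * 120 = 14400
let mean = total / Float(scalars.count)             // 3
let variance = scalars
  .map { ($0 - mean) * ($0 - mean) }
  .reduce(0, +) / Float(scalars.count)              // 20 / 10 = 2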

test/TensorFlowRuntime/tensor_autodiff_runtime.swift

Lines changed: 12 additions & 0 deletions
@@ -121,6 +121,18 @@ TensorADTests.testAllBackends("mean") {
   expectEqual(expected, meanGradAlongAxes(input))
 }
 
+TensorADTests.testAllBackends("variance") {
+  let varianceGradScalar = gradient { (a: Tensor<Float>) in a.variance() }
+  // let varianceGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.variance(squeezingAxes: 0, 1) }
+  let varianceGradAlongAxes = gradient { (a: Tensor<Float>) in a.variance(alongAxes: 0, 1) }
+
+  let input: Tensor<Float> = [[1, 2], [3, 4]]
+  let expected: Tensor<Float> = [[-0.75, -0.25], [0.25, 0.75]]
+  expectEqual(expected, varianceGradScalar(input))
+  // expectEqual(expected, varianceGradSqueezingAxes(input))
+  expectEqual(expected, varianceGradAlongAxes(input))
+}
+
 TensorADTests.testAllBackends("expandingShape") {
   let f1 = { (a: Tensor<Float>) in a.expandingShape(at: 0).squared() }
   let f2 = { (a: Tensor<Float>) in a.squared().expandingShape(at: 0) }
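
The expected gradient checks out analytically: for the population variance v = (1/n) * Σ (x_i - μ)^2, the μ terms cancel because Σ (x_i - μ) = 0, leaving ∂v/∂x_i = 2(x_i - μ)/n. With input [[1, 2], [3, 4]], μ = 2.5 and n = 4, so the gradient is 2(x - 2.5)/4 = [[-0.75, -0.25], [0.25, 0.75]]. Recomputed in plain Swift (a sketch, not part of the test file):

let xs: [Float] = [1, 2, 3, 4]
let n = Float(xs.count)
let mu = xs.reduce(0, +) / n              // 2.5
let grad = xs.map { 2 * ($0 - mu) / n }   // [-0.75, -0.25, 0.25, 0.75]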
