Skip to content

Commit b79adfe

Browse files
authored
---
yaml --- r: 312283 b: refs/heads/tensorflow-merge c: 71cd119 h: refs/heads/master i: 312281: 204b1be 312279: c76e7bb
1 parent 2dce3f2 commit b79adfe

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ refs/heads/chase-my-tail: 8bb91443a9e81bbfac92a2621a0af887a1da8dbf
13791379
refs/heads/consider-outer-alternatives: 708bac749ec60a22a79e2eefbe734f9488a7370d
13801380
refs/heads/revert-25740-oops-i-linked-it-again: fdd41aeb682fc488572bdc1cf71b2ff6997ba576
13811381
refs/heads/swift-5.1-branch-06-12-2019: e63b7b2d3b93c48232d386099d0ec525d21d8f8d
1382-
refs/heads/tensorflow-merge: 052c301eea9b1f903e1ad18c3f82af4701c2ab4e
1382+
refs/heads/tensorflow-merge: 71cd119e77e565961224d0a92fe853c2945c7c81
13831383
refs/heads/update-checkout-sha-info: 5832743c5c2a842976c42a508a4c6dcceefb0aef
13841384
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-12-a: 228f0448d9bb909aacbba4afcb7c600a405d15da
13851385
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-14-a: 922861a77b5fc2bf46bc917da70ceb15eef76836

branches/tensorflow-merge/stdlib/public/TensorFlow/Ops.swift

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,16 +1259,14 @@ public extension Tensor where Scalar : Numeric {
12591259
where Scalar : TensorFlowFloatingPoint
12601260
)
12611261
func sum() -> Tensor {
1262-
let axes = Tensor<Int32>(rangeFrom: 0, to: rank, stride: 1)
1263-
return Raw.sum(self, reductionIndices: axes)
1262+
return Raw.sum(flattened(), reductionIndices: [0])
12641263
}
12651264

12661265
// NOTE: This overload is necessary, otherwise `sum()` would refer
12671266
// to the variadic method `sum(squeezingAxes:)` with zero indices.
12681267
@inlinable @inline(__always)
12691268
func product() -> Tensor {
1270-
let axes = Tensor<Int32>(rangeFrom: 0, to: rank, stride: 1)
1271-
return Raw.prod(self, reductionIndices: axes)
1269+
return Raw.prod(flattened(), reductionIndices: [0])
12721270
}
12731271

12741272
// NOTE: This overload is necessary, otherwise `mean()` would refer
@@ -1279,8 +1277,7 @@ public extension Tensor where Scalar : Numeric {
12791277
)
12801278
@inlinable @inline(__always)
12811279
func mean() -> Tensor {
1282-
let axes = Tensor<Int32>(rangeFrom: 0, to: rank, stride: 1)
1283-
return Raw.mean(self, reductionIndices: axes)
1280+
return Raw.mean(flattened(), reductionIndices: [0])
12841281
}
12851282

12861283
// NOTE: This overload is necessary, otherwise `mean()` would refer

branches/tensorflow-merge/test/TensorFlowRuntime/tensor.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ TensorTests.testAllBackends("Reduction") {
179179
// 2 x 5
180180
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
181181
expectEqual(Tensor(30), x.sum().toHost(shape: []))
182+
expectEqual(0, x.rank)
182183
expectEqual(Tensor(shape: [5], scalars: [2, 4, 6, 8, 10]),
183184
x.sum(squeezingAxes: 0).toHost(shape: []))
184185
expectEqual(Tensor(shape: [1, 5], scalars: [2, 4, 6, 8, 10]),

0 commit comments

Comments
 (0)