Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 89b8548

Browse files
eaplataniosrxwei
authored andcommitted
Re-implement tensor shape/rank operators using TFE runtime APIs. (#141)
Ported from swiftlang/swift#24949. Gives a 98% speedup for `Tensor.rank` and 85% speedup for `Tensor.shape` (for a 2x3 tensor) on a MacBook Pro.
1 parent 19ed1e9 commit 89b8548

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ public final class _ExecutionContext {
551551
@usableFromInline let eagerContext: CTFEContext
552552

553553
/// The status for checking TensorFlow errors.
554-
private let status: CTFStatus = TF_NewStatus()
554+
@usableFromInline let status: CTFStatus = TF_NewStatus()
555555

556556
/// The mutex for preventing potential concurrent access.
557557
private var mutex: pthread_mutex_t = pthread_mutex_t()

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import CTensorFlow
16+
1517
infix operator .==: ComparisonPrecedence
1618
infix operator .!=: ComparisonPrecedence
1719

@@ -50,7 +52,10 @@ public extension Tensor {
5052
var rank: Int {
5153
@_semantics("autodiff.nonvarying")
5254
get {
53-
return Int(rankTensor.scalar!)
55+
let status = _ExecutionContext.global.status
56+
let rank = TFE_TensorHandleNumDims(handle._cTensorHandle, status)
57+
checkOk(status)
58+
return Int(rank)
5459
}
5560
}
5661

@@ -59,15 +64,25 @@ public extension Tensor {
5964
var shape: TensorShape {
6065
@_semantics("autodiff.nonvarying")
6166
get {
62-
return TensorShape(shapeTensor.scalars.map(Int.init))
67+
let status = _ExecutionContext.global.status
68+
let dims: [Int] = (0..<Int32(rank)).map { i in
69+
let dim = TFE_TensorHandleDim(self.handle._cTensorHandle, i, status)
70+
checkOk(status)
71+
return Int(dim)
72+
}
73+
return TensorShape(dims)
6374
}
6475
}
6576

6677
/// The number of scalars in the `Tensor`.
6778
@inlinable
6879
var scalarCount: Int {
80+
@_semantics("autodiff.nonvarying")
6981
get {
70-
return Int(scalarCountTensor.scalar!)
82+
let status = _ExecutionContext.global.status
83+
let size = TFE_TensorHandleNumElements(handle._cTensorHandle, status)
84+
checkOk(status)
85+
return Int(size)
7186
}
7287
}
7388

0 commit comments

Comments
 (0)