Skip to content

Commit 49ff2dd

Browse files
committed
[TF] Re-implement tensor shape/rank operators using TFE runtime APIs.
1 parent 51f82c5 commit 49ff2dd

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ public final class _ExecutionContext {
554554
@usableFromInline let eagerContext: CTFEContext
555555

556556
/// The status for checking TensorFlow errors.
557-
private let status: CTFStatus = TF_NewStatus()
557+
@usableFromInline let status: CTFStatus = TF_NewStatus()
558558

559559
/// The mutex for preventing potential concurrent access.
560560
private var mutex: pthread_mutex_t = pthread_mutex_t()

stdlib/public/TensorFlow/Tensor.swift

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,10 @@ public extension Tensor {
367367
var rank: Int {
368368
@_semantics("autodiff.nonvarying")
369369
get {
370-
return Int(rankTensor.scalar!)
370+
let status = _ExecutionContext.global.status
371+
let rank = TFE_TensorHandleNumDims(handle._cTensorHandle, status)
372+
checkOk(status)
373+
return Int(rank)
371374
}
372375
}
373376

@@ -376,14 +379,26 @@ public extension Tensor {
376379
var shape: TensorShape {
377380
@_semantics("autodiff.nonvarying")
378381
get {
379-
return TensorShape(shapeTensor.scalars.map(Int.init))
382+
let status = _ExecutionContext.global.status
383+
let dims: [Int] = (0..<Int32(rank)).map { i in
384+
let dim = TFE_TensorHandleDim(self.handle._cTensorHandle, i, status)
385+
checkOk(status)
386+
return Int(dim)
387+
}
388+
return TensorShape(dims)
380389
}
381390
}
382391

383392
/// The number of scalars in the `Tensor`.
384393
@inlinable
385394
var scalarCount: Int {
386-
return Int(scalarCountTensor.scalar!)
395+
@_semantics("autodiff.nonvarying")
396+
get {
397+
let status = _ExecutionContext.global.status
398+
let size = TFE_TensorHandleNumElements(handle._cTensorHandle, status)
399+
checkOk(status)
400+
return Int(size)
401+
}
387402
}
388403
}
389404

0 commit comments

Comments
 (0)