
Commit 728a2d6

Move the computation of shape and rank to TFETensorHandle. (#401)
This is a necessary step to incorporate shape inference into `LazyTensorHandle`.
1 parent 03d8ff4

File tree: 3 files changed, +56 −17 lines
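At a glance, the commit layers the two properties as follows. The sketch below is a self-contained stand-in, not the library source: `TensorShape` is reduced to a wrapper around `[Int]`, and the eager handle's C-API queries are replaced by a stored array, so the snippet compiles on its own while mirroring the delegation chain that the diffs below introduce.

    // Structural sketch of the post-commit layering (stand-in types, not the
    // real swift-apis implementations).

    struct TensorShape: CustomStringConvertible {
        var dimensions: [Int]
        var description: String { "\(dimensions)" }
    }

    // The protocol gains two new requirements (see the TensorHandle.swift diff).
    protocol _AnyTensorHandle: AnyObject {
        var rank: Int { get }
        var shape: TensorShape { get }
    }

    // The eager handle is now the single place where rank/shape are computed;
    // in the real code this queries TFE_TensorHandleNumDims / TFE_TensorHandleDim.
    final class TFETensorHandle: _AnyTensorHandle {
        private let dims: [Int]   // stand-in for the C-API queries
        init(dims: [Int]) { self.dims = dims }
        var rank: Int { dims.count }
        var shape: TensorShape { TensorShape(dimensions: dims) }
    }

    // The lazy handle forwards to its concrete eager handle
    // (see the LazyTensorOperation.swift diff).
    final class LazyTensorHandle: _AnyTensorHandle {
        let _tfeTensorHandle: TFETensorHandle
        init(_ base: TFETensorHandle) { _tfeTensorHandle = base }
        var rank: Int { _tfeTensorHandle.rank }
        var shape: TensorShape { _tfeTensorHandle.shape }
    }

    // Tensor and TensorHandle<Scalar> (see the Tensor.swift diff) now simply
    // read these properties off whichever handle they hold.
    let handle: _AnyTensorHandle = LazyTensorHandle(TFETensorHandle(dims: [2, 3]))
    print(handle.rank, handle.shape)   // 2 [2, 3]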

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 8 additions & 0 deletions
@@ -65,6 +65,14 @@ class LazyTensorHandle: _AnyTensorHandle {
             LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
         }
     }
+
+    /// The number of dimensions of the underlying `Tensor`.
+    @inlinable
+    var rank: Int { _tfeTensorHandle.rank }
+
+    /// The shape of the underlying `Tensor`.
+    @inlinable
+    var shape: TensorShape { _tfeTensorHandle.shape }
 
     // Liveness tracking for LazyTensorOperations
     //

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 2 additions & 15 deletions
@@ -54,27 +54,14 @@ public extension Tensor {
     @inlinable
     var rank: Int {
         @_semantics("autodiff.nonvarying")
-        get {
-            let status = _ExecutionContext.global.status
-            let rank = TFE_TensorHandleNumDims(handle._cTensorHandle, status)
-            checkOk(status)
-            return Int(rank)
-        }
+        get { handle.rank }
     }
 
     /// The shape of the `Tensor`.
     @inlinable
     var shape: TensorShape {
         @_semantics("autodiff.nonvarying")
-        get {
-            let status = _ExecutionContext.global.status
-            let dims: [Int] = (0..<Int32(rank)).map { i in
-                let dim = TFE_TensorHandleDim(self.handle._cTensorHandle, i, status)
-                checkOk(status)
-                return Int(dim)
-            }
-            return TensorShape(dims)
-        }
+        get { handle.shape }
     }
 
     /// The number of scalars in the `Tensor`.
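From the caller's side nothing changes: `Tensor.rank` and `Tensor.shape` keep their signatures and `@_semantics("autodiff.nonvarying")` getters, and only delegate to the handle instead of calling the TFE C API directly. Assuming the standard swift-apis `Tensor(zeros:)` initializer, usage is still:

    import TensorFlow

    let t = Tensor<Float>(zeros: [2, 3])
    print(t.rank)    // 2
    print(t.shape)   // [2, 3]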

Sources/TensorFlow/Core/TensorHandle.swift

Lines changed: 46 additions & 2 deletions
@@ -22,6 +22,8 @@ import CTensorFlow
 // protocol to workaround bug TF-527. When it is fixed, we should remove `: class`.
 public protocol _AnyTensorHandle: class {
     var _tfeTensorHandle: TFETensorHandle { get }
+    var rank: Int { get }
+    var shape: TensorShape { get }
 }
 
 extension _AnyTensorHandle {

@@ -49,14 +51,40 @@ public class TFETensorHandle: _AnyTensorHandle {
         Context.local.globalTensorCount -= 1
         debugLog("Returning from deinit of TensorHandle.")
     }
-}
 
+    /// The number of dimensions of the underlying `Tensor`.
+    @inlinable
+    public var rank: Int {
+        @_semantics("autodiff.nonvarying")
+        get {
+            let status = _ExecutionContext.global.status
+            let rank = TFE_TensorHandleNumDims(_cTensorHandle, status)
+            checkOk(status)
+            return Int(rank)
+        }
+    }
+
+    /// The shape of the underlying `Tensor`.
+    @inlinable
+    public var shape: TensorShape {
+        @_semantics("autodiff.nonvarying")
+        get {
+            let status = _ExecutionContext.global.status
+            let dims: [Int] = (0..<Int32(rank)).map { i in
+                let dim = TFE_TensorHandleDim(_cTensorHandle, i, status)
+                checkOk(status)
+                return Int(dim)
+            }
+            return TensorShape(dims)
+        }
+    }
+}
 
 /// `TensorHandle` is the type used by ops. It includes a `Scalar` type, which
 /// compiler internals can use to determine the datatypes of parameters when
 /// they are extracted into a tensor program.
 public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
-    let handle: _AnyTensorHandle
+    @usableFromInline let handle: _AnyTensorHandle
 
     public var _cTensorHandle: CTensorHandle { handle._cTensorHandle }
 

@@ -128,6 +156,22 @@ extension TensorHandle where Scalar: TensorFlowScalar {
     }
 }
 
+extension TensorHandle {
+    /// The number of dimensions of the `Tensor`.
+    @inlinable
+    public var rank: Int {
+        @_semantics("autodiff.nonvarying")
+        get { handle.rank }
+    }
+
+    /// The shape of the `Tensor`.
+    @inlinable
+    public var shape: TensorShape {
+        @_semantics("autodiff.nonvarying")
+        get { handle.shape }
+    }
+}
+
 internal extension TensorHandle {
     /// Create a `ShapedArray` with contents of the underlying `TensorHandle`. If the `TensorHandle`
     /// is on the accelerator, it will be copied to the host.
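Since `rank` and `shape` now live on `_AnyTensorHandle` itself, code that only holds the type-erased handle can ask for them without reaching into TFE C types, and, as the commit message notes, a later change can let `LazyTensorHandle` answer them via shape inference instead of forwarding to `_tfeTensorHandle`. A hypothetical helper (not part of the library) illustrating the new protocol surface:

    import TensorFlow

    // Hypothetical: works for both eager (TFETensorHandle) and lazy handles.
    func describeHandle(_ handle: _AnyTensorHandle) -> String {
        "rank \(handle.rank), shape \(handle.shape)"
    }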
