
Move the computation of shape and rank to TFETensorHandle. #401

Merged 1 commit on Jul 30, 2019
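In effect, the three diffs below widen `_AnyTensorHandle` with `rank` and `shape` requirements, implement them once on `TFETensorHandle` against the eager C API, and reduce every other accessor to plain forwarding. A condensed sketch of the resulting delegation chain (an illustration distilled from the diffs, not the full source):

// Condensed sketch; attributes and bodies elided.
public protocol _AnyTensorHandle: class {
    var _tfeTensorHandle: TFETensorHandle { get }
    var rank: Int { get }            // new requirement
    var shape: TensorShape { get }   // new requirement
}

// TFETensorHandle is now the only type that talks to the C API
// (TFE_TensorHandleNumDims / TFE_TensorHandleDim). Everything else forwards:
//   LazyTensorHandle.rank     -> _tfeTensorHandle.rank
//   TensorHandle<Scalar>.rank -> handle.rank
//   Tensor<Scalar>.rank       -> handle.rank
// (and likewise for shape).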
8 changes: 8 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorOperation.swift
@@ -65,6 +65,14 @@ class LazyTensorHandle: _AnyTensorHandle {
             LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
         }
     }
+
+    /// The number of dimensions of the underlying `Tensor`.
+    @inlinable
+    var rank: Int { _tfeTensorHandle.rank }
+
+    /// The shape of the underlying `Tensor`.
+    @inlinable
+    var shape: TensorShape { _tfeTensorHandle.shape }
 
     // Liveness tracking for LazyTensorOperations
     //
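With `LazyTensorHandle` forwarding in place, code that only sees an `_AnyTensorHandle` can query rank and shape without knowing whether the handle is eager or symbolic. A small illustrative helper (not part of this PR; it also assumes that obtaining `_tfeTensorHandle` for a symbolic handle may first materialize the pending trace, as the property name suggests):

// Illustrative only: works uniformly for TFETensorHandle and LazyTensorHandle,
// since both now satisfy the rank/shape requirements of _AnyTensorHandle.
func describe(_ handle: _AnyTensorHandle) -> String {
    "rank \(handle.rank), shape \(handle.shape)"
}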
17 changes: 2 additions & 15 deletions Sources/TensorFlow/Core/Tensor.swift
@@ -54,27 +54,14 @@ public extension Tensor {
     @inlinable
     var rank: Int {
         @_semantics("autodiff.nonvarying")
-        get {
-            let status = _ExecutionContext.global.status
-            let rank = TFE_TensorHandleNumDims(handle._cTensorHandle, status)
-            checkOk(status)
-            return Int(rank)
-        }
+        get { handle.rank }
     }
 
     /// The shape of the `Tensor`.
     @inlinable
     var shape: TensorShape {
         @_semantics("autodiff.nonvarying")
-        get {
-            let status = _ExecutionContext.global.status
-            let dims: [Int] = (0..<Int32(rank)).map { i in
-                let dim = TFE_TensorHandleDim(self.handle._cTensorHandle, i, status)
-                checkOk(status)
-                return Int(dim)
-            }
-            return TensorShape(dims)
-        }
+        get { handle.shape }
     }
 
     /// The number of scalars in the `Tensor`.
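The public surface of `Tensor` is unchanged by this hunk; only the getter bodies move. A quick usage check with the standard API:

import TensorFlow

// rank and shape behave exactly as before; they now resolve through the
// handle chain instead of calling TFE_TensorHandleNumDims/Dim directly.
let matrix: Tensor<Float> = [[1, 2, 3], [4, 5, 6]]
print(matrix.rank)   // 2
print(matrix.shape)  // [2, 3]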
48 changes: 46 additions & 2 deletions Sources/TensorFlow/Core/TensorHandle.swift
@@ -22,6 +22,8 @@ import CTensorFlow
 // protocol to workaround bug TF-527. When it is fixed, we should remove `: class`.
 public protocol _AnyTensorHandle: class {
     var _tfeTensorHandle: TFETensorHandle { get }
+    var rank: Int { get }
+    var shape: TensorShape { get }
 }
 
 extension _AnyTensorHandle {
@@ -49,14 +51,40 @@ public class TFETensorHandle: _AnyTensorHandle {
         Context.local.globalTensorCount -= 1
         debugLog("Returning from deinit of TensorHandle.")
     }
-}
+
+    /// The number of dimensions of the underlying `Tensor`.
+    @inlinable
+    public var rank: Int {
+        @_semantics("autodiff.nonvarying")
+        get {
+            let status = _ExecutionContext.global.status
+            let rank = TFE_TensorHandleNumDims(_cTensorHandle, status)
+            checkOk(status)
+            return Int(rank)
+        }
+    }
+
+    /// The shape of the underlying `Tensor`.
+    @inlinable
+    public var shape: TensorShape {
+        @_semantics("autodiff.nonvarying")
+        get {
+            let status = _ExecutionContext.global.status
+            let dims: [Int] = (0..<Int32(rank)).map { i in
+                let dim = TFE_TensorHandleDim(_cTensorHandle, i, status)
+                checkOk(status)
+                return Int(dim)
+            }
+            return TensorShape(dims)
+        }
+    }
+}
 
 /// `TensorHandle` is the type used by ops. It includes a `Scalar` type, which
 /// compiler internals can use to determine the datatypes of parameters when
 /// they are extracted into a tensor program.
 public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
-    let handle: _AnyTensorHandle
+    @usableFromInline let handle: _AnyTensorHandle
 
     public var _cTensorHandle: CTensorHandle { handle._cTensorHandle }
 
@@ -128,6 +156,22 @@ extension TensorHandle where Scalar: TensorFlowScalar {
     }
 }
 
+extension TensorHandle {
+    /// The number of dimensions of the `Tensor`.
+    @inlinable
+    public var rank: Int {
+        @_semantics("autodiff.nonvarying")
+        get { handle.rank }
+    }
+
+    /// The shape of the `Tensor`.
+    @inlinable
+    public var shape: TensorShape {
+        @_semantics("autodiff.nonvarying")
+        get { handle.shape }
+    }
+}
+
 internal extension TensorHandle {
     /// Create a `ShapedArray` with contents of the underlying `TensorHandle`. If the `TensorHandle`
     /// is on the accelerator, it will be copied to the host.
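Because `_AnyTensorHandle` gained two requirements, any other conformer must now supply `rank` and `shape` as well. A hypothetical conformer (not part of this PR) showing that plain forwarding, as `LazyTensorHandle` does above, satisfies the widened protocol:

// Hypothetical conformer, for illustration only.
final class ForwardingHandle: _AnyTensorHandle {
    private let base: TFETensorHandle
    init(base: TFETensorHandle) { self.base = base }

    var _tfeTensorHandle: TFETensorHandle { base }
    var rank: Int { base.rank }            // forward to the eager handle
    var shape: TensorShape { base.shape }
}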