Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add an API to get the underlying LazyTensorOperation of a LazyTensorHandle. #405

Merged
merged 3 commits into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class LazyTensorHandle: _AnyTensorHandle {
/// The shape of the underlying `Tensor`.
@inlinable
var shape: TensorShape { _tfeTensorHandle.shape }

/// Returns the underlying `LazyTensorOperation` if this is a symbolic `LazyTensorHandle`.
var lazyTensorOperation: LazyTensorOperation? {
switch handle {
case .symbolic(let op, _, _): return op
case .concrete(_): return nil
}
}

// Liveness tracking for LazyTensorOperations
//
Expand Down
16 changes: 16 additions & 0 deletions Tests/TensorFlowTests/LazyTensorHandleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ final class LazyTensorHandleTests: XCTestCase {
XCTAssertEqual(liveSymTensor.description, "%0.0*")
}

func testLazyTensorOperationProperty() {
let zero = Tensor<Float>(0.0)
let zeroTFEHandle = zero.handle.handle._tfeTensorHandle
let concTensor = LazyTensorHandle(zeroTFEHandle)
XCTAssertNil(concTensor.lazyTensorOperation)

let op = LazyTensorOperation(
_id: "0", name: "IdentityN", outputCount: 3)
let symTensor = LazyTensorHandle(_lazy: op, index: 0)
let lazyTensorOperation = symTensor.lazyTensorOperation
XCTAssertNotNil(lazyTensorOperation)
// Checks that returned value is the same as the one that we passed in.
XCTAssertTrue(lazyTensorOperation === op)
}

func testLivenessTracking() {
func assertLive(_ expectedLive: [LazyTensorOperation]) {
var actualLiveOps: Set<LazyTensorOperationRef> = []
Expand Down Expand Up @@ -152,6 +167,7 @@ final class LazyTensorHandleTests: XCTestCase {

static var allTests = [
("testConstructions", testConstructions),
("testLazyTensorOperationProperty", testLazyTensorOperationProperty),
("testLivenessTracking", testLivenessTracking),
("testTensorToLazyTensorConversions", testTensorToLazyTensorConversions)
]
Expand Down