Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5ffe1c3

Browse files
authored
Add an API to get the underlying LazyTensorOperation of a LazyTensorHandle. (#405)
1 parent 05fad16 commit 5ffe1c3

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ class LazyTensorHandle: _AnyTensorHandle {
7373
/// The shape of the underlying `Tensor`.
7474
@inlinable
7575
var shape: TensorShape { _tfeTensorHandle.shape }
76+
77+
/// Returns the underlying `LazyTensorOperation` if this is a symbolic `LazyTensorHandle`.
78+
var lazyTensorOperation: LazyTensorOperation? {
79+
switch handle {
80+
case .symbolic(let op, _, _): return op
81+
case .concrete(_): return nil
82+
}
83+
}
7684

7785
// Liveness tracking for LazyTensorOperations
7886
//

Tests/TensorFlowTests/LazyTensorHandleTests.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,21 @@ final class LazyTensorHandleTests: XCTestCase {
5252
XCTAssertEqual(liveSymTensor.description, "%0.0*")
5353
}
5454

55+
func testLazyTensorOperationProperty() {
56+
let zero = Tensor<Float>(0.0)
57+
let zeroTFEHandle = zero.handle.handle._tfeTensorHandle
58+
let concTensor = LazyTensorHandle(zeroTFEHandle)
59+
XCTAssertNil(concTensor.lazyTensorOperation)
60+
61+
let op = LazyTensorOperation(
62+
_id: "0", name: "IdentityN", outputCount: 3)
63+
let symTensor = LazyTensorHandle(_lazy: op, index: 0)
64+
let lazyTensorOperation = symTensor.lazyTensorOperation
65+
XCTAssertNotNil(lazyTensorOperation)
66+
// Checks that returned value is the same as the one that we passed in.
67+
XCTAssertTrue(lazyTensorOperation === op)
68+
}
69+
5570
func testLivenessTracking() {
5671
func assertLive(_ expectedLive: [LazyTensorOperation]) {
5772
var actualLiveOps: Set<LazyTensorOperationRef> = []
@@ -152,6 +167,7 @@ final class LazyTensorHandleTests: XCTestCase {
152167

153168
static var allTests = [
154169
("testConstructions", testConstructions),
170+
("testLazyTensorOperationProperty", testLazyTensorOperationProperty),
155171
("testLivenessTracking", testLivenessTracking),
156172
("testTensorToLazyTensorConversions", testTensorToLazyTensorConversions)
157173
]

0 commit comments

Comments
 (0)