Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

TF-701 Small cleanup of _lazyTensor to _lazyTensorHandle #920

Merged
merged 1 commit into from
May 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Tests/TensorFlowTests/LazyTensorHandleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ final class LazyTensorHandleTests: XCTestCase {
private func checkConversions<T: _LazyTensorCompatible>(_ x: T) {
let concreteLazyX = x._concreteLazyTensor
let concreteInputLazyX = x._concreteInputLazyTensor
XCTAssertFalse(isSymbolic(concreteLazyX._lazyTensor))
XCTAssertFalse(isSymbolic(concreteInputLazyX._lazyTensor))
XCTAssertFalse(isMaterializedConcrete(concreteLazyX._lazyTensor))
XCTAssertTrue(isMaterializedConcrete(concreteInputLazyX._lazyTensor))
XCTAssertFalse(isSymbolic(concreteLazyX._lazyTensorHandle))
XCTAssertFalse(isSymbolic(concreteInputLazyX._lazyTensorHandle))
XCTAssertFalse(isMaterializedConcrete(concreteLazyX._lazyTensorHandle))
XCTAssertTrue(isMaterializedConcrete(concreteInputLazyX._lazyTensorHandle))
}

func testTensorToLazyTensorConversions() {
Expand Down
10 changes: 5 additions & 5 deletions Tests/TensorFlowTests/LazyTensorShapeInferenceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ final class LazyTensorShapeInferenceTests: LazyTensorTestCase {
let b = Tensor<Float>(shape: [1, 3], scalars: [1.0, 2.0, 3.0])
let c = Tensor<Float>(shape: [1, 3], scalars: [4.0, 5.0, 6.0])
let w = a * b
let wLazyTensorOperation = w._lazyTensor!.lazyTensorOperation!
let wLazyTensorOperation = w._lazyTensorHandle!.lazyTensorOperation!
let x = w * c
let xLazyTensorOperation = x._lazyTensor!.lazyTensorOperation!
let xLazyTensorOperation = x._lazyTensorHandle!.lazyTensorOperation!

// Make sure that `w` and `x` are not materialized.
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
Expand Down Expand Up @@ -60,7 +60,7 @@ final class LazyTensorShapeInferenceTests: LazyTensorTestCase {
let a = Tensor<Float>(shape: [3, 1], scalars: [1.0, 2.0, 3.0])
let b = a.reshaped(toShape: [1, 3])

let bLazyTensorOperation = b._lazyTensor!.lazyTensorOperation!
let bLazyTensorOperation = b._lazyTensorHandle!.lazyTensorOperation!
XCTAssertFalse(bLazyTensorOperation.isMaterialized)

let bShape = b.shape
Expand All @@ -69,7 +69,7 @@ final class LazyTensorShapeInferenceTests: LazyTensorTestCase {
XCTAssertFalse(bLazyTensorOperation.isMaterialized)

let c = Tensor<Float>(repeating: 5, shape: [4, 5, 6])
let cLazyTensorOperation = c._lazyTensor!.lazyTensorOperation!
let cLazyTensorOperation = c._lazyTensorHandle!.lazyTensorOperation!
XCTAssertFalse(cLazyTensorOperation.isMaterialized)

let cShape = c.shape
Expand All @@ -91,7 +91,7 @@ final class LazyTensorShapeInferenceTests: LazyTensorTestCase {
let dims = a + b
let m = _Raw.fill(dims: dims, value: Tensor<Float>(1.0))
let result = _Raw.matMul(m, m)
let mLazyTensorOperation = m._lazyTensor!.lazyTensorOperation!
let mLazyTensorOperation = m._lazyTensorHandle!.lazyTensorOperation!
// Note that we have not triggered materialization yet. So, it should not have happened
// implicitly during shape inference.
XCTAssertFalse(mLazyTensorOperation.isMaterialized)
Expand Down
14 changes: 7 additions & 7 deletions Tests/TensorFlowTests/LazyTensorTestHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class LazyTensorTestCase: XCTestCase {

protocol _LazyTensorCompatible {
/// The underlying `LazyTensorHandle` (if any).
var _lazyTensor: LazyTensorHandle? { get }
var _lazyTensorHandle: LazyTensorHandle? { get }

/// Returns `Self` that wraps a concrete `LazyTensorHandle`.
/// (Triggers materialization if needed.)
Expand All @@ -47,7 +47,7 @@ protocol _LazyTensorCompatible {
}

extension _AnyTensorHandle {
var _lazyTensor: LazyTensorHandle? {
var _lazyTensorHandle: LazyTensorHandle? {
if let handle = self as? LazyTensorHandle {
return handle
} else {
Expand All @@ -58,35 +58,35 @@ extension _AnyTensorHandle {
}

extension TensorHandle: _LazyTensorCompatible {
var _lazyTensor: LazyTensorHandle? { handle._lazyTensor }
var _lazyTensorHandle: LazyTensorHandle? { handle._lazyTensorHandle }
public var _concreteLazyTensor: TensorHandle {
TensorHandle(handle: handle._concreteLazyTensor)
}
}

extension Tensor: _LazyTensorCompatible {
var _lazyTensor: LazyTensorHandle? { handle._lazyTensor }
var _lazyTensorHandle: LazyTensorHandle? { handle._lazyTensorHandle }
public var _concreteLazyTensor: Tensor {
Tensor(handle: handle._concreteLazyTensor)
}
}

extension StringTensor: _LazyTensorCompatible {
var _lazyTensor: LazyTensorHandle? { handle._lazyTensor }
var _lazyTensorHandle: LazyTensorHandle? { handle._lazyTensorHandle }
public var _concreteLazyTensor: StringTensor {
StringTensor(handle: handle._concreteLazyTensor)
}
}

extension VariantHandle: _LazyTensorCompatible {
var _lazyTensor: LazyTensorHandle? { handle._lazyTensor }
var _lazyTensorHandle: LazyTensorHandle? { handle._lazyTensorHandle }
public var _concreteLazyTensor: VariantHandle {
VariantHandle(handle: handle._concreteLazyTensor)
}
}

extension ResourceHandle: _LazyTensorCompatible {
var _lazyTensor: LazyTensorHandle? { handle._lazyTensor }
var _lazyTensorHandle: LazyTensorHandle? { handle._lazyTensorHandle }
public var _concreteLazyTensor: ResourceHandle {
ResourceHandle(handle: handle._concreteLazyTensor)
}
Expand Down