Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ceccc9f

Browse files
committed
Change to use iterator so that we can check the result of execution.
1 parent 4961c03 commit ceccc9f

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

Tests/TensorFlowTests/LazyTensorEvaluationTests.swift

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,32 @@ final class LazyTensorEvaluationTests: XCTestCase {
8787
}
8888

8989
func testNoOutputOperations() {
90-
let a = StringTensor("Hello ")
91-
let b = StringTensor("World")
92-
let c = Raw.add(a, b)
93-
// c is not materialized yet.
94-
XCTAssertFalse(isMaterialized(c.handle.handle))
95-
Raw.printV2(c)
96-
// c is materialized now as printV2 would be executed.
97-
XCTAssertTrue(isMaterialized(c.handle.handle))
90+
let elements: Tensor<Int32> = [0, 1, 2]
91+
// let elements2: Tensor<Int32> = [10, 11, 12]
92+
let outputTypes = [Int32.tensorFlowDataType]
93+
let outputShapes: [TensorShape?] = [nil]
94+
let dataset: VariantHandle = Raw.tensorSliceDataset(
95+
components: [elements],
96+
outputShapes: [nil]
97+
)
98+
let iterator: ResourceHandle = Raw.iteratorV2(sharedName: "blah",
99+
container: "earth", outputTypes: outputTypes, outputShapes: outputShapes
100+
)
101+
// `dataset` and `iterator` should not be materialized yet.
102+
XCTAssertFalse(isMaterialized(dataset.handle))
103+
XCTAssertFalse(isMaterialized(iterator.handle))
104+
105+
Raw.makeIterator(dataset: dataset, iterator: iterator)
106+
// `dataset` and `iterator` should be materialized now as
107+
// makeIterator executes.
108+
XCTAssertTrue(isMaterialized(dataset.handle))
109+
XCTAssertTrue(isMaterialized(iterator.handle))
110+
111+
// The following won't work unless the makeIterator runs.
112+
let next: TensorHandle<Int32> = Raw.iteratorGetNext(
113+
iterator: iterator, outputShapes: outputShapes
114+
)
115+
XCTAssertEqual(Tensor(handle: next).scalarized(), 0)
98116
}
99117

100118
private func isMaterialized<T: TensorFlowScalar>(_ input: Tensor<T>) -> Bool {

0 commit comments

Comments
 (0)