Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit cf2791a

Browse files
committed
Change to use iterator so that we can check the result of execution.
1 parent 4961c03 commit cf2791a

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

Tests/TensorFlowTests/LazyTensorEvaluationTests.swift

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,37 @@ final class LazyTensorEvaluationTests: XCTestCase {
8686
XCTAssertTrue(isMaterialized(sum))
8787
}
8888

89+
struct SimpleOutput: TensorGroup {
90+
let a: TensorHandle<Int32>
91+
let b: TensorHandle<Int32>
92+
}
93+
8994
func testNoOutputOperations() {
90-
let a = StringTensor("Hello ")
91-
let b = StringTensor("World")
92-
let c = Raw.add(a, b)
93-
// c is not materialized yet.
94-
XCTAssertFalse(isMaterialized(c.handle.handle))
95-
Raw.printV2(c)
96-
// c is materialized now as printV2 would be executed.
97-
XCTAssertTrue(isMaterialized(c.handle.handle))
95+
let elements1: Tensor<Int32> = [0, 1, 2]
96+
let elements2: Tensor<Int32> = [10, 11, 12]
97+
let outputTypes = [Int32.tensorFlowDataType, Int32.tensorFlowDataType]
98+
let outputShapes: [TensorShape?] = [nil, nil]
99+
let dataset: VariantHandle = Raw.tensorSliceDataset(
100+
components: [elements1, elements2],
101+
outputShapes: outputShapes
102+
)
103+
let iterator: ResourceHandle = Raw.iteratorV2(sharedName: "blah",
104+
container: "earth", outputTypes: outputTypes, outputShapes: outputShapes
105+
)
106+
// `dataset` and `iterator` should not be materialized yet.
107+
XCTAssertFalse(isMaterialized(dataset.handle))
108+
XCTAssertFalse(isMaterialized(iterator.handle))
109+
Raw.makeIterator(dataset: dataset, iterator: iterator)
110+
111+
// `dataset` and `iterator` should be materialized now as
112+
// makeIterator executes.
113+
XCTAssertTrue(isMaterialized(dataset.handle))
114+
XCTAssertTrue(isMaterialized(iterator.handle))
115+
let next: SimpleOutput = Raw.iteratorGetNext(
116+
iterator: iterator, outputShapes: outputShapes
117+
)
118+
XCTAssertEqual(Tensor(handle: next.a).scalarized(), 0)
119+
XCTAssertEqual(Tensor(handle: next.b).scalarized(), 10)
98120
}
99121

100122
private func isMaterialized<T: TensorFlowScalar>(_ input: Tensor<T>) -> Bool {

0 commit comments

Comments
 (0)