@@ -86,8 +86,44 @@ final class LazyTensorEvaluationTests: XCTestCase {
86
86
XCTAssertTrue ( isMaterialized ( sum) )
87
87
}
88
88
89
+ struct SimpleOutput : TensorGroup {
90
+ let a : TensorHandle < Int32 >
91
+ let b : TensorHandle < Int32 >
92
+ }
93
+
94
+ func testNoOutputOperations( ) {
95
+ let elements1 : Tensor < Int32 > = [ 0 , 1 , 2 ]
96
+ let elements2 : Tensor < Int32 > = [ 10 , 11 , 12 ]
97
+ let outputTypes = [ Int32 . tensorFlowDataType, Int32 . tensorFlowDataType]
98
+ let outputShapes : [ TensorShape ? ] = [ nil , nil ]
99
+ let dataset : VariantHandle = Raw . tensorSliceDataset (
100
+ components: [ elements1, elements2] ,
101
+ outputShapes: outputShapes
102
+ )
103
+ let iterator : ResourceHandle = Raw . iteratorV2 ( sharedName: " blah " ,
104
+ container: " earth " , outputTypes: outputTypes, outputShapes: outputShapes
105
+ )
106
+ // `dataset` and `iterator` should not be materialized yet.
107
+ XCTAssertFalse ( isMaterialized ( dataset. handle) )
108
+ XCTAssertFalse ( isMaterialized ( iterator. handle) )
109
+ Raw . makeIterator ( dataset: dataset, iterator: iterator)
110
+
111
+ // `dataset` and `iterator` should be materialized now as
112
+ // makeIterator executes.
113
+ XCTAssertTrue ( isMaterialized ( dataset. handle) )
114
+ XCTAssertTrue ( isMaterialized ( iterator. handle) )
115
+ let next : SimpleOutput = Raw . iteratorGetNext (
116
+ iterator: iterator, outputShapes: outputShapes
117
+ )
118
+ XCTAssertEqual ( Tensor ( handle: next. a) . scalarized ( ) , 0 )
119
+ XCTAssertEqual ( Tensor ( handle: next. b) . scalarized ( ) , 10 )
120
+ }
121
+
89
122
private func isMaterialized< T: TensorFlowScalar > ( _ input: Tensor < T > ) -> Bool {
90
- let tensor = input. handle. handle
123
+ return isMaterialized ( input. handle. handle)
124
+ }
125
+
126
+ private func isMaterialized( _ tensor: _AnyTensorHandle ) -> Bool {
91
127
guard let lazyTensor = tensor as? LazyTensor else { return true }
92
128
switch lazyTensor. handle {
93
129
case . symbolic( let op, _, _) : return op. outputs != nil
@@ -100,6 +136,7 @@ final class LazyTensorEvaluationTests: XCTestCase {
100
136
( " testMultipleMaterializations " , testMultipleMaterializations) ,
101
137
( " testSimpleControlFlow " , testSimpleControlFlow) ,
102
138
( " testSimpleLoop " , testSimpleLoop) ,
139
+ ( " testNoOutputOperations " , testNoOutputOperations)
103
140
]
104
141
}
105
142
0 commit comments