@@ -33,7 +33,7 @@ final class LazyTensorTraceTests: XCTestCase {
33
33
let b = Tensor < Float > ( 2.0 )
34
34
let c = Tensor < Float > ( 3.0 )
35
35
let w = a + b * c
36
- XCTAssertEqual ( lazyTrace ( w) ! . description,
36
+ XCTAssertEqual ( lazyTrace ( w) . description,
37
37
"""
38
38
lazyTrace_5() -> (%4) {
39
39
%0 = Const[dtype: float, value: 10.0]()
@@ -55,7 +55,7 @@ final class LazyTensorTraceTests: XCTestCase {
55
55
let w = a + b + c
56
56
let y = w * c
57
57
let z = y / ( w - c)
58
- XCTAssertEqual ( lazyTrace ( z) ! . description,
58
+ XCTAssertEqual ( lazyTrace ( z) . description,
59
59
"""
60
60
lazyTrace_8() -> (%4, %5, %7) {
61
61
%0 = Const[dtype: float, value: 10.0]()
@@ -71,7 +71,7 @@ final class LazyTensorTraceTests: XCTestCase {
71
71
72
72
// Note that we only pick operations on which the lazy tensor in
73
73
// question depends on.
74
- XCTAssertEqual ( lazyTrace ( y) ! . description,
74
+ XCTAssertEqual ( lazyTrace ( y) . description,
75
75
"""
76
76
lazyTrace_6() -> (%4, %5) {
77
77
%0 = Const[dtype: float, value: 10.0]()
@@ -84,21 +84,46 @@ final class LazyTensorTraceTests: XCTestCase {
84
84
""" )
85
85
}
86
86
87
+ func testMultipleTargets( ) {
88
+ // This test checks that *only* the operations that correspond to `w`,
89
+ // `y` and `z` are marked as outputs. Specifcally, the intermediate
90
+ // operations in the trace are not marked as outputs.
91
+ let a = Tensor < Float > ( 1.0 )
92
+ let b = Tensor < Float > ( 2.0 )
93
+ let c = Tensor < Float > ( 3.0 )
94
+ let d = Tensor < Float > ( 4.0 )
95
+ let w = a + b
96
+ let x = c + d
97
+ let lazyOps = [ w, x] . map { self . lazyTensorOperation ( $0) ! }
98
+ XCTAssertEqual ( LazyTensorTrace ( lazyOps) . description,
99
+ """
100
+ lazyTrace_6() -> (%2, %5) {
101
+ %0 = Const[dtype: float, value: 1.0]()
102
+ %1 = Const[dtype: float, value: 2.0]()
103
+ %2 = Add[T: float](%0, %1)
104
+ %3 = Const[dtype: float, value: 3.0]()
105
+ %4 = Const[dtype: float, value: 4.0]()
106
+ %5 = Add[T: float](%3, %4)
107
+ }
108
+ """ )
109
+ }
110
+
111
+
87
112
func testSimpleControlFlow( ) {
88
113
let a = Tensor < Float > ( 5.0 )
89
114
let addOrMul = { ( useAdd: Bool , a: Tensor < Float > ) in
90
115
useAdd ? ( a + a) : ( a * a)
91
116
}
92
117
let add = addOrMul ( /*useAdd:*/true , a)
93
- XCTAssertEqual ( lazyTrace ( add) ! . description,
118
+ XCTAssertEqual ( lazyTrace ( add) . description,
94
119
"""
95
120
lazyTrace_2() -> (%1) {
96
121
%0 = Const[dtype: float, value: 5.0]()
97
122
%1 = Add[T: float](%0, %0)
98
123
}
99
124
""" )
100
125
let mul = addOrMul ( /*useAdd:*/false , a)
101
- XCTAssertEqual ( lazyTrace ( mul) ! . description,
126
+ XCTAssertEqual ( lazyTrace ( mul) . description,
102
127
"""
103
128
lazyTrace_2() -> (%1) {
104
129
%0 = Const[dtype: float, value: 5.0]()
@@ -115,7 +140,7 @@ final class LazyTensorTraceTests: XCTestCase {
115
140
// be burnt into the trace as a constant.
116
141
let lazyA = a. _concreteLazyTensor
117
142
let w1 = lazyA * b
118
- let w1Trace = lazyTrace ( w1) !
143
+ let w1Trace = lazyTrace ( w1)
119
144
XCTAssertEqual ( w1Trace. description,
120
145
"""
121
146
lazyTrace_3() -> (%2) {
@@ -130,7 +155,7 @@ final class LazyTensorTraceTests: XCTestCase {
130
155
// be promoted to an input for the trace.
131
156
let inputLazyA = a. _concreteInputLazyTensor
132
157
let w2 = inputLazyA * b
133
- let w2Trace = lazyTrace ( w2) !
158
+ let w2Trace = lazyTrace ( w2)
134
159
XCTAssertEqual ( w2Trace. description,
135
160
"""
136
161
lazyTrace_3(%0: float) -> (%2) {
@@ -151,7 +176,7 @@ final class LazyTensorTraceTests: XCTestCase {
151
176
let z = y * c
152
177
153
178
XCTAssertEqual (
154
- lazyTrace ( y) ! . description,
179
+ lazyTrace ( y) . description,
155
180
"""
156
181
lazyTrace_3() -> (%2) {
157
182
%0 = Const[dtype: float, value: 1.0]()
@@ -163,7 +188,7 @@ final class LazyTensorTraceTests: XCTestCase {
163
188
164
189
/// Now that `y` is materialized and a constant,
165
190
/// the trace for `z` will use that as a constant.
166
- let zTrace = lazyTrace ( z) !
191
+ let zTrace = lazyTrace ( z)
167
192
XCTAssertEqual (
168
193
zTrace. description,
169
194
"""
@@ -178,9 +203,9 @@ final class LazyTensorTraceTests: XCTestCase {
178
203
XCTAssertEqual ( z. scalarized ( ) , 9.0 )
179
204
}
180
205
181
- private func lazyTrace < T: TensorFlowScalar > (
206
+ private func lazyTensorOperation < T: TensorFlowScalar > (
182
207
_ input: Tensor < T >
183
- ) -> LazyTensorTrace ? {
208
+ ) -> LazyTensorOperation ? {
184
209
let tensor = input. handle. handle
185
210
guard let lazyTensor = tensor as? LazyTensorHandle else {
186
211
XCTFail ( " Trying to get lazy trace for a non-lazy tensor. " )
@@ -190,12 +215,17 @@ final class LazyTensorTraceTests: XCTestCase {
190
215
XCTFail ( " Cannot get lazy trace for a concrete tensor. " )
191
216
return nil
192
217
}
193
- return LazyTensorTrace ( lazyOp)
218
+ return lazyOp
219
+ }
220
+
221
+ private func lazyTrace< T: TensorFlowScalar > ( _ input: Tensor < T > ) -> LazyTensorTrace {
222
+ return LazyTensorTrace ( lazyTensorOperation ( input) !)
194
223
}
195
224
196
225
static var allTests = [
197
226
( " testSingleLiveTensor " , testSingleLiveTensor) ,
198
227
( " testMultipleLiveTensors " , testMultipleLiveTensors) ,
228
+ ( " testMultipleTargets " , testMultipleTargets) ,
199
229
( " testSimpleControlFlow " , testSimpleControlFlow) ,
200
230
( " testManualConstPromotion " , testManualConstPromotion) ,
201
231
( " testConstPromotion " , testConstPromotion)
0 commit comments