@@ -33,7 +33,7 @@ final class LazyTensorTraceTests: XCTestCase {
33
33
let b = Tensor < Float > ( 2.0 )
34
34
let c = Tensor < Float > ( 3.0 )
35
35
let w = a + b * c
36
- XCTAssertEqual ( lazyTrace ( w) ! . description,
36
+ XCTAssertEqual ( lazyTrace ( w) . description,
37
37
"""
38
38
lazyTrace_5() -> (%4) {
39
39
%0 = Const[dtype: float, value: 10.0]()
@@ -55,7 +55,7 @@ final class LazyTensorTraceTests: XCTestCase {
55
55
let w = a + b + c
56
56
let y = w * c
57
57
let z = y / ( w - c)
58
- XCTAssertEqual ( lazyTrace ( z) ! . description,
58
+ XCTAssertEqual ( lazyTrace ( z) . description,
59
59
"""
60
60
lazyTrace_8() -> (%4, %5, %7) {
61
61
%0 = Const[dtype: float, value: 10.0]()
@@ -71,7 +71,7 @@ final class LazyTensorTraceTests: XCTestCase {
71
71
72
72
// Note that we only pick operations on which the lazy tensor in
73
73
// question depends on.
74
- XCTAssertEqual ( lazyTrace ( y) ! . description,
74
+ XCTAssertEqual ( lazyTrace ( y) . description,
75
75
"""
76
76
lazyTrace_6() -> (%4, %5) {
77
77
%0 = Const[dtype: float, value: 10.0]()
@@ -84,21 +84,43 @@ final class LazyTensorTraceTests: XCTestCase {
84
84
""" )
85
85
}
86
86
87
+ func testMultipleTargets( ) {
88
+ let a = Tensor < Float > ( 1.0 )
89
+ let b = Tensor < Float > ( 2.0 )
90
+ let c = Tensor < Float > ( 3.0 )
91
+ let d = Tensor < Float > ( 4.0 )
92
+ let w = a + b
93
+ let x = c + d
94
+ let lazyOps = [ w, x] . map { self . lazyTensorOperation ( $0) ! }
95
+ XCTAssertEqual ( LazyTensorTrace ( lazyOps) . description,
96
+ """
97
+ lazyTrace_6() -> (%2, %5) {
98
+ %0 = Const[dtype: float, value: 1.0]()
99
+ %1 = Const[dtype: float, value: 2.0]()
100
+ %2 = Add[T: float](%0, %1)
101
+ %3 = Const[dtype: float, value: 3.0]()
102
+ %4 = Const[dtype: float, value: 4.0]()
103
+ %5 = Add[T: float](%3, %4)
104
+ }
105
+ """ )
106
+ }
107
+
108
+
87
109
func testSimpleControlFlow( ) {
88
110
let a = Tensor < Float > ( 5.0 )
89
111
let addOrMul = { ( useAdd: Bool , a: Tensor < Float > ) in
90
112
useAdd ? ( a + a) : ( a * a)
91
113
}
92
114
let add = addOrMul ( /*useAdd:*/true , a)
93
- XCTAssertEqual ( lazyTrace ( add) ! . description,
115
+ XCTAssertEqual ( lazyTrace ( add) . description,
94
116
"""
95
117
lazyTrace_2() -> (%1) {
96
118
%0 = Const[dtype: float, value: 5.0]()
97
119
%1 = Add[T: float](%0, %0)
98
120
}
99
121
""" )
100
122
let mul = addOrMul ( /*useAdd:*/false , a)
101
- XCTAssertEqual ( lazyTrace ( mul) ! . description,
123
+ XCTAssertEqual ( lazyTrace ( mul) . description,
102
124
"""
103
125
lazyTrace_2() -> (%1) {
104
126
%0 = Const[dtype: float, value: 5.0]()
@@ -115,7 +137,7 @@ final class LazyTensorTraceTests: XCTestCase {
115
137
// be burnt into the trace as a constant.
116
138
let lazyA = a. _concreteLazyTensor
117
139
let w1 = lazyA * b
118
- let w1Trace = lazyTrace ( w1) !
140
+ let w1Trace = lazyTrace ( w1)
119
141
XCTAssertEqual ( w1Trace. description,
120
142
"""
121
143
lazyTrace_3() -> (%2) {
@@ -130,7 +152,7 @@ final class LazyTensorTraceTests: XCTestCase {
130
152
// be promoted to an input for the trace.
131
153
let inputLazyA = a. _concreteInputLazyTensor
132
154
let w2 = inputLazyA * b
133
- let w2Trace = lazyTrace ( w2) !
155
+ let w2Trace = lazyTrace ( w2)
134
156
XCTAssertEqual ( w2Trace. description,
135
157
"""
136
158
lazyTrace_3(%0: float) -> (%2) {
@@ -151,7 +173,7 @@ final class LazyTensorTraceTests: XCTestCase {
151
173
let z = y * c
152
174
153
175
XCTAssertEqual (
154
- lazyTrace ( y) ! . description,
176
+ lazyTrace ( y) . description,
155
177
"""
156
178
lazyTrace_3() -> (%2) {
157
179
%0 = Const[dtype: float, value: 1.0]()
@@ -163,7 +185,7 @@ final class LazyTensorTraceTests: XCTestCase {
163
185
164
186
/// Now that `y` is materialized and a constant,
165
187
/// the trace for `z` will use that as a constant.
166
- let zTrace = lazyTrace ( z) !
188
+ let zTrace = lazyTrace ( z)
167
189
XCTAssertEqual (
168
190
zTrace. description,
169
191
"""
@@ -178,9 +200,9 @@ final class LazyTensorTraceTests: XCTestCase {
178
200
XCTAssertEqual ( z. scalarized ( ) , 9.0 )
179
201
}
180
202
181
- private func lazyTrace < T: TensorFlowScalar > (
203
+ private func lazyTensorOperation < T: TensorFlowScalar > (
182
204
_ input: Tensor < T >
183
- ) -> LazyTensorTrace ? {
205
+ ) -> LazyTensorOperation ? {
184
206
let tensor = input. handle. handle
185
207
guard let lazyTensor = tensor as? LazyTensorHandle else {
186
208
XCTFail ( " Trying to get lazy trace for a non-lazy tensor. " )
@@ -190,12 +212,17 @@ final class LazyTensorTraceTests: XCTestCase {
190
212
XCTFail ( " Cannot get lazy trace for a concrete tensor. " )
191
213
return nil
192
214
}
193
- return LazyTensorTrace ( lazyOp)
215
+ return lazyOp
216
+ }
217
+
218
+ private func lazyTrace< T: TensorFlowScalar > ( _ input: Tensor < T > ) -> LazyTensorTrace {
219
+ return LazyTensorTrace ( lazyTensorOperation ( input) !)
194
220
}
195
221
196
222
static var allTests = [
197
223
( " testSingleLiveTensor " , testSingleLiveTensor) ,
198
224
( " testMultipleLiveTensors " , testMultipleLiveTensors) ,
225
+ ( " testMultipleTargets " , testMultipleTargets) ,
199
226
( " testSimpleControlFlow " , testSimpleControlFlow) ,
200
227
( " testManualConstPromotion " , testManualConstPromotion) ,
201
228
( " testConstPromotion " , testConstPromotion)
0 commit comments