Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit a0a025c

Browse files
committed
Allow lazy traces to be extracted from multiple targets.
1 parent 97ca18b commit a0a025c

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class LazyTensorTrace {
2424
var outputs: [LazyTensorOperation] = []
2525
var originalOutputs: [LazyTensorOperation] = []
2626

27-
init(_ lazyOp: LazyTensorOperation) {
27+
init(_ lazyOperations: [LazyTensorOperation]) {
2828
// TODO: We only pick operations on which `lazyOp` depends on. Note that
2929
// there may be other live tensors that could also be materialized at
3030
// this time. e.g.,
@@ -34,10 +34,16 @@ class LazyTensorTrace {
3434
// `y = x + c` into the trace so that we don't have the overhead of creating
3535
// another trace when we need to materialize `y`.
3636
//
37-
_ = collectLazyOperation(lazyOp)
37+
for lazyOp in lazyOperations {
38+
_ = collectLazyOperation(lazyOp)
39+
}
3840
lazyOpsCache.removeAll()
3941
}
4042

43+
convenience init(_ lazyOp: LazyTensorOperation) {
44+
self.init([lazyOp])
45+
}
46+
4147
var signature: String {
4248
let inputsDesc: [String] = inputs.map { input in
4349
let dtypeAttr = input.attributes["dtype"]!

Tests/TensorFlowTests/LazyTensorTraceTests.swift

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ final class LazyTensorTraceTests: XCTestCase {
3333
let b = Tensor<Float>(2.0)
3434
let c = Tensor<Float>(3.0)
3535
let w = a + b * c
36-
XCTAssertEqual(lazyTrace(w)!.description,
36+
XCTAssertEqual(lazyTrace(w).description,
3737
"""
3838
lazyTrace_5() -> (%4) {
3939
%0 = Const[dtype: float, value: 10.0]()
@@ -55,7 +55,7 @@ final class LazyTensorTraceTests: XCTestCase {
5555
let w = a + b + c
5656
let y = w * c
5757
let z = y / (w - c)
58-
XCTAssertEqual(lazyTrace(z)!.description,
58+
XCTAssertEqual(lazyTrace(z).description,
5959
"""
6060
lazyTrace_8() -> (%4, %5, %7) {
6161
%0 = Const[dtype: float, value: 10.0]()
@@ -71,7 +71,7 @@ final class LazyTensorTraceTests: XCTestCase {
7171

7272
// Note that we only pick operations on which the lazy tensor in
7373
// question depends on.
74-
XCTAssertEqual(lazyTrace(y)!.description,
74+
XCTAssertEqual(lazyTrace(y).description,
7575
"""
7676
lazyTrace_6() -> (%4, %5) {
7777
%0 = Const[dtype: float, value: 10.0]()
@@ -84,21 +84,43 @@ final class LazyTensorTraceTests: XCTestCase {
8484
""")
8585
}
8686

87+
func testMultipleTargets() {
88+
let a = Tensor<Float>(1.0)
89+
let b = Tensor<Float>(2.0)
90+
let c = Tensor<Float>(3.0)
91+
let d = Tensor<Float>(4.0)
92+
let w = a + b
93+
let x = c + d
94+
let lazyOps = [w, x].map { self.lazyTensorOperation($0)! }
95+
XCTAssertEqual(LazyTensorTrace(lazyOps).description,
96+
"""
97+
lazyTrace_6() -> (%2, %5) {
98+
%0 = Const[dtype: float, value: 1.0]()
99+
%1 = Const[dtype: float, value: 2.0]()
100+
%2 = Add[T: float](%0, %1)
101+
%3 = Const[dtype: float, value: 3.0]()
102+
%4 = Const[dtype: float, value: 4.0]()
103+
%5 = Add[T: float](%3, %4)
104+
}
105+
""")
106+
}
107+
108+
87109
func testSimpleControlFlow() {
88110
let a = Tensor<Float>(5.0)
89111
let addOrMul = { (useAdd: Bool, a: Tensor<Float>) in
90112
useAdd ? (a + a) : (a * a)
91113
}
92114
let add = addOrMul(/*useAdd:*/true, a)
93-
XCTAssertEqual(lazyTrace(add)!.description,
115+
XCTAssertEqual(lazyTrace(add).description,
94116
"""
95117
lazyTrace_2() -> (%1) {
96118
%0 = Const[dtype: float, value: 5.0]()
97119
%1 = Add[T: float](%0, %0)
98120
}
99121
""")
100122
let mul = addOrMul(/*useAdd:*/false, a)
101-
XCTAssertEqual(lazyTrace(mul)!.description,
123+
XCTAssertEqual(lazyTrace(mul).description,
102124
"""
103125
lazyTrace_2() -> (%1) {
104126
%0 = Const[dtype: float, value: 5.0]()
@@ -115,7 +137,7 @@ final class LazyTensorTraceTests: XCTestCase {
115137
// be burnt into the trace as a constant.
116138
let lazyA = a._concreteLazyTensor
117139
let w1 = lazyA * b
118-
let w1Trace = lazyTrace(w1)!
140+
let w1Trace = lazyTrace(w1)
119141
XCTAssertEqual(w1Trace.description,
120142
"""
121143
lazyTrace_3() -> (%2) {
@@ -130,7 +152,7 @@ final class LazyTensorTraceTests: XCTestCase {
130152
// be promoted to an input for the trace.
131153
let inputLazyA = a._concreteInputLazyTensor
132154
let w2 = inputLazyA * b
133-
let w2Trace = lazyTrace(w2)!
155+
let w2Trace = lazyTrace(w2)
134156
XCTAssertEqual(w2Trace.description,
135157
"""
136158
lazyTrace_3(%0: float) -> (%2) {
@@ -151,7 +173,7 @@ final class LazyTensorTraceTests: XCTestCase {
151173
let z = y * c
152174

153175
XCTAssertEqual(
154-
lazyTrace(y)!.description,
176+
lazyTrace(y).description,
155177
"""
156178
lazyTrace_3() -> (%2) {
157179
%0 = Const[dtype: float, value: 1.0]()
@@ -163,7 +185,7 @@ final class LazyTensorTraceTests: XCTestCase {
163185

164186
/// Now that `y` is materialized and a constant,
165187
/// the trace for `z` will use that as a constant.
166-
let zTrace = lazyTrace(z)!
188+
let zTrace = lazyTrace(z)
167189
XCTAssertEqual(
168190
zTrace.description,
169191
"""
@@ -178,9 +200,9 @@ final class LazyTensorTraceTests: XCTestCase {
178200
XCTAssertEqual(z.scalarized(), 9.0)
179201
}
180202

181-
private func lazyTrace<T: TensorFlowScalar>(
203+
private func lazyTensorOperation<T: TensorFlowScalar>(
182204
_ input: Tensor<T>
183-
) -> LazyTensorTrace? {
205+
) -> LazyTensorOperation? {
184206
let tensor = input.handle.handle
185207
guard let lazyTensor = tensor as? LazyTensorHandle else {
186208
XCTFail("Trying to get lazy trace for a non-lazy tensor.")
@@ -190,12 +212,17 @@ final class LazyTensorTraceTests: XCTestCase {
190212
XCTFail("Cannot get lazy trace for a concrete tensor.")
191213
return nil
192214
}
193-
return LazyTensorTrace(lazyOp)
215+
return lazyOp
216+
}
217+
218+
private func lazyTrace<T: TensorFlowScalar>(_ input: Tensor<T>) -> LazyTensorTrace {
219+
return LazyTensorTrace(lazyTensorOperation(input)!)
194220
}
195221

196222
static var allTests = [
197223
("testSingleLiveTensor", testSingleLiveTensor),
198224
("testMultipleLiveTensors", testMultipleLiveTensors),
225+
("testMultipleTargets", testMultipleTargets),
199226
("testSimpleControlFlow", testSimpleControlFlow),
200227
("testManualConstPromotion", testManualConstPromotion),
201228
("testConstPromotion", testConstPromotion)

0 commit comments

Comments
 (0)