Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 26167ff

Browse files
committed
Allow lazy traces to be extracted from multiple targets.
1 parent 97ca18b commit 26167ff

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class LazyTensorTrace {
2424
var outputs: [LazyTensorOperation] = []
2525
var originalOutputs: [LazyTensorOperation] = []
2626

27-
init(_ lazyOp: LazyTensorOperation) {
27+
init(_ lazyOperations: [LazyTensorOperation]) {
2828
// TODO: We only pick operations on which `lazyOp` depends on. Note that
2929
// there may be other live tensors that could also be materialized at
3030
// this time. e.g.,
@@ -34,10 +34,16 @@ class LazyTensorTrace {
3434
// `y = x + c` into the trace so that we don't have the overhead of creating
3535
// another trace when we need to materialize `y`.
3636
//
37-
_ = collectLazyOperation(lazyOp)
37+
for lazyOp in lazyOperations {
38+
_ = collectLazyOperation(lazyOp)
39+
}
3840
lazyOpsCache.removeAll()
3941
}
4042

43+
convenience init(_ lazyOp: LazyTensorOperation) {
44+
self.init([lazyOp])
45+
}
46+
4147
var signature: String {
4248
let inputsDesc: [String] = inputs.map { input in
4349
let dtypeAttr = input.attributes["dtype"]!

Tests/TensorFlowTests/LazyTensorTraceTests.swift

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ final class LazyTensorTraceTests: XCTestCase {
3333
let b = Tensor<Float>(2.0)
3434
let c = Tensor<Float>(3.0)
3535
let w = a + b * c
36-
XCTAssertEqual(lazyTrace(w)!.description,
36+
XCTAssertEqual(lazyTrace(w).description,
3737
"""
3838
lazyTrace_5() -> (%4) {
3939
%0 = Const[dtype: float, value: 10.0]()
@@ -55,7 +55,7 @@ final class LazyTensorTraceTests: XCTestCase {
5555
let w = a + b + c
5656
let y = w * c
5757
let z = y / (w - c)
58-
XCTAssertEqual(lazyTrace(z)!.description,
58+
XCTAssertEqual(lazyTrace(z).description,
5959
"""
6060
lazyTrace_8() -> (%4, %5, %7) {
6161
%0 = Const[dtype: float, value: 10.0]()
@@ -71,7 +71,7 @@ final class LazyTensorTraceTests: XCTestCase {
7171

7272
// Note that we only pick operations on which the lazy tensor in
7373
// question depends on.
74-
XCTAssertEqual(lazyTrace(y)!.description,
74+
XCTAssertEqual(lazyTrace(y).description,
7575
"""
7676
lazyTrace_6() -> (%4, %5) {
7777
%0 = Const[dtype: float, value: 10.0]()
@@ -84,21 +84,46 @@ final class LazyTensorTraceTests: XCTestCase {
8484
""")
8585
}
8686

87+
func testMultipleTargets() {
88+
// This test checks that *only* the operations that correspond to `w`,
89+
// `y` and `z` are marked as outputs. Specifcally, the intermediate
90+
// operations in the trace are not marked as outputs.
91+
let a = Tensor<Float>(1.0)
92+
let b = Tensor<Float>(2.0)
93+
let c = Tensor<Float>(3.0)
94+
let d = Tensor<Float>(4.0)
95+
let w = a + b
96+
let x = c + d
97+
let lazyOps = [w, x].map { self.lazyTensorOperation($0)! }
98+
XCTAssertEqual(LazyTensorTrace(lazyOps).description,
99+
"""
100+
lazyTrace_6() -> (%2, %5) {
101+
%0 = Const[dtype: float, value: 1.0]()
102+
%1 = Const[dtype: float, value: 2.0]()
103+
%2 = Add[T: float](%0, %1)
104+
%3 = Const[dtype: float, value: 3.0]()
105+
%4 = Const[dtype: float, value: 4.0]()
106+
%5 = Add[T: float](%3, %4)
107+
}
108+
""")
109+
}
110+
111+
87112
func testSimpleControlFlow() {
88113
let a = Tensor<Float>(5.0)
89114
let addOrMul = { (useAdd: Bool, a: Tensor<Float>) in
90115
useAdd ? (a + a) : (a * a)
91116
}
92117
let add = addOrMul(/*useAdd:*/true, a)
93-
XCTAssertEqual(lazyTrace(add)!.description,
118+
XCTAssertEqual(lazyTrace(add).description,
94119
"""
95120
lazyTrace_2() -> (%1) {
96121
%0 = Const[dtype: float, value: 5.0]()
97122
%1 = Add[T: float](%0, %0)
98123
}
99124
""")
100125
let mul = addOrMul(/*useAdd:*/false, a)
101-
XCTAssertEqual(lazyTrace(mul)!.description,
126+
XCTAssertEqual(lazyTrace(mul).description,
102127
"""
103128
lazyTrace_2() -> (%1) {
104129
%0 = Const[dtype: float, value: 5.0]()
@@ -115,7 +140,7 @@ final class LazyTensorTraceTests: XCTestCase {
115140
// be burnt into the trace as a constant.
116141
let lazyA = a._concreteLazyTensor
117142
let w1 = lazyA * b
118-
let w1Trace = lazyTrace(w1)!
143+
let w1Trace = lazyTrace(w1)
119144
XCTAssertEqual(w1Trace.description,
120145
"""
121146
lazyTrace_3() -> (%2) {
@@ -130,7 +155,7 @@ final class LazyTensorTraceTests: XCTestCase {
130155
// be promoted to an input for the trace.
131156
let inputLazyA = a._concreteInputLazyTensor
132157
let w2 = inputLazyA * b
133-
let w2Trace = lazyTrace(w2)!
158+
let w2Trace = lazyTrace(w2)
134159
XCTAssertEqual(w2Trace.description,
135160
"""
136161
lazyTrace_3(%0: float) -> (%2) {
@@ -151,7 +176,7 @@ final class LazyTensorTraceTests: XCTestCase {
151176
let z = y * c
152177

153178
XCTAssertEqual(
154-
lazyTrace(y)!.description,
179+
lazyTrace(y).description,
155180
"""
156181
lazyTrace_3() -> (%2) {
157182
%0 = Const[dtype: float, value: 1.0]()
@@ -163,7 +188,7 @@ final class LazyTensorTraceTests: XCTestCase {
163188

164189
/// Now that `y` is materialized and a constant,
165190
/// the trace for `z` will use that as a constant.
166-
let zTrace = lazyTrace(z)!
191+
let zTrace = lazyTrace(z)
167192
XCTAssertEqual(
168193
zTrace.description,
169194
"""
@@ -178,9 +203,9 @@ final class LazyTensorTraceTests: XCTestCase {
178203
XCTAssertEqual(z.scalarized(), 9.0)
179204
}
180205

181-
private func lazyTrace<T: TensorFlowScalar>(
206+
private func lazyTensorOperation<T: TensorFlowScalar>(
182207
_ input: Tensor<T>
183-
) -> LazyTensorTrace? {
208+
) -> LazyTensorOperation? {
184209
let tensor = input.handle.handle
185210
guard let lazyTensor = tensor as? LazyTensorHandle else {
186211
XCTFail("Trying to get lazy trace for a non-lazy tensor.")
@@ -190,12 +215,17 @@ final class LazyTensorTraceTests: XCTestCase {
190215
XCTFail("Cannot get lazy trace for a concrete tensor.")
191216
return nil
192217
}
193-
return LazyTensorTrace(lazyOp)
218+
return lazyOp
219+
}
220+
221+
private func lazyTrace<T: TensorFlowScalar>(_ input: Tensor<T>) -> LazyTensorTrace {
222+
return LazyTensorTrace(lazyTensorOperation(input)!)
194223
}
195224

196225
static var allTests = [
197226
("testSingleLiveTensor", testSingleLiveTensor),
198227
("testMultipleLiveTensors", testMultipleLiveTensors),
228+
("testMultipleTargets", testMultipleTargets),
199229
("testSimpleControlFlow", testSimpleControlFlow),
200230
("testManualConstPromotion", testManualConstPromotion),
201231
("testConstPromotion", testConstPromotion)

0 commit comments

Comments
 (0)