Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Allow lazy traces to be extracted from multiple targets. #373

Merged
merged 1 commit into from
Jul 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Sources/TensorFlow/Core/LazyTensorTrace.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class LazyTensorTrace {
var outputs: [LazyTensorOperation] = []
var originalOutputs: [LazyTensorOperation] = []

init(_ lazyOp: LazyTensorOperation) {
init(_ lazyOperations: [LazyTensorOperation]) {
// TODO: We only pick operations on which `lazyOp` depends on. Note that
// there may be other live tensors that could also be materialized at
// this time. e.g.,
Expand All @@ -34,10 +34,16 @@ class LazyTensorTrace {
// `y = x + c` into the trace so that we don't have the overhead of creating
// another trace when we need to materialize `y`.
//
_ = collectLazyOperation(lazyOp)
for lazyOp in lazyOperations {
_ = collectLazyOperation(lazyOp)
}
lazyOpsCache.removeAll()
}

convenience init(_ lazyOp: LazyTensorOperation) {
self.init([lazyOp])
}

var signature: String {
let inputsDesc: [String] = inputs.map { input in
let dtypeAttr = input.attributes["dtype"]!
Expand Down
51 changes: 39 additions & 12 deletions Tests/TensorFlowTests/LazyTensorTraceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ final class LazyTensorTraceTests: XCTestCase {
let b = Tensor<Float>(2.0)
let c = Tensor<Float>(3.0)
let w = a + b * c
XCTAssertEqual(lazyTrace(w)!.description,
XCTAssertEqual(lazyTrace(w).description,
"""
lazyTrace_5() -> (%4) {
%0 = Const[dtype: float, value: 10.0]()
Expand All @@ -55,7 +55,7 @@ final class LazyTensorTraceTests: XCTestCase {
let w = a + b + c
let y = w * c
let z = y / (w - c)
XCTAssertEqual(lazyTrace(z)!.description,
XCTAssertEqual(lazyTrace(z).description,
"""
lazyTrace_8() -> (%4, %5, %7) {
%0 = Const[dtype: float, value: 10.0]()
Expand All @@ -71,7 +71,7 @@ final class LazyTensorTraceTests: XCTestCase {

// Note that we only pick operations on which the lazy tensor in
// question depends on.
XCTAssertEqual(lazyTrace(y)!.description,
XCTAssertEqual(lazyTrace(y).description,
"""
lazyTrace_6() -> (%4, %5) {
%0 = Const[dtype: float, value: 10.0]()
Expand All @@ -84,21 +84,43 @@ final class LazyTensorTraceTests: XCTestCase {
""")
}

func testMultipleTargets() {
let a = Tensor<Float>(1.0)
let b = Tensor<Float>(2.0)
let c = Tensor<Float>(3.0)
let d = Tensor<Float>(4.0)
let w = a + b
let x = c + d
let lazyOps = [w, x].map { self.lazyTensorOperation($0)! }
XCTAssertEqual(LazyTensorTrace(lazyOps).description,
"""
lazyTrace_6() -> (%2, %5) {
%0 = Const[dtype: float, value: 1.0]()
%1 = Const[dtype: float, value: 2.0]()
%2 = Add[T: float](%0, %1)
%3 = Const[dtype: float, value: 3.0]()
%4 = Const[dtype: float, value: 4.0]()
%5 = Add[T: float](%3, %4)
}
""")
}


func testSimpleControlFlow() {
let a = Tensor<Float>(5.0)
let addOrMul = { (useAdd: Bool, a: Tensor<Float>) in
useAdd ? (a + a) : (a * a)
}
let add = addOrMul(/*useAdd:*/true, a)
XCTAssertEqual(lazyTrace(add)!.description,
XCTAssertEqual(lazyTrace(add).description,
"""
lazyTrace_2() -> (%1) {
%0 = Const[dtype: float, value: 5.0]()
%1 = Add[T: float](%0, %0)
}
""")
let mul = addOrMul(/*useAdd:*/false, a)
XCTAssertEqual(lazyTrace(mul)!.description,
XCTAssertEqual(lazyTrace(mul).description,
"""
lazyTrace_2() -> (%1) {
%0 = Const[dtype: float, value: 5.0]()
Expand All @@ -115,7 +137,7 @@ final class LazyTensorTraceTests: XCTestCase {
// be burnt into the trace as a constant.
let lazyA = a._concreteLazyTensor
let w1 = lazyA * b
let w1Trace = lazyTrace(w1)!
let w1Trace = lazyTrace(w1)
XCTAssertEqual(w1Trace.description,
"""
lazyTrace_3() -> (%2) {
Expand All @@ -130,7 +152,7 @@ final class LazyTensorTraceTests: XCTestCase {
// be promoted to an input for the trace.
let inputLazyA = a._concreteInputLazyTensor
let w2 = inputLazyA * b
let w2Trace = lazyTrace(w2)!
let w2Trace = lazyTrace(w2)
XCTAssertEqual(w2Trace.description,
"""
lazyTrace_3(%0: float) -> (%2) {
Expand All @@ -151,7 +173,7 @@ final class LazyTensorTraceTests: XCTestCase {
let z = y * c

XCTAssertEqual(
lazyTrace(y)!.description,
lazyTrace(y).description,
"""
lazyTrace_3() -> (%2) {
%0 = Const[dtype: float, value: 1.0]()
Expand All @@ -163,7 +185,7 @@ final class LazyTensorTraceTests: XCTestCase {

/// Now that `y` is materialized and a constant,
/// the trace for `z` will use that as a constant.
let zTrace = lazyTrace(z)!
let zTrace = lazyTrace(z)
XCTAssertEqual(
zTrace.description,
"""
Expand All @@ -178,9 +200,9 @@ final class LazyTensorTraceTests: XCTestCase {
XCTAssertEqual(z.scalarized(), 9.0)
}

private func lazyTrace<T: TensorFlowScalar>(
private func lazyTensorOperation<T: TensorFlowScalar>(
_ input: Tensor<T>
) -> LazyTensorTrace? {
) -> LazyTensorOperation? {
let tensor = input.handle.handle
guard let lazyTensor = tensor as? LazyTensorHandle else {
XCTFail("Trying to get lazy trace for a non-lazy tensor.")
Expand All @@ -190,12 +212,17 @@ final class LazyTensorTraceTests: XCTestCase {
XCTFail("Cannot get lazy trace for a concrete tensor.")
return nil
}
return LazyTensorTrace(lazyOp)
return lazyOp
}

private func lazyTrace<T: TensorFlowScalar>(_ input: Tensor<T>) -> LazyTensorTrace {
return LazyTensorTrace(lazyTensorOperation(input)!)
}

static var allTests = [
("testSingleLiveTensor", testSingleLiveTensor),
("testMultipleLiveTensors", testMultipleLiveTensors),
("testMultipleTargets", testMultipleTargets),
("testSimpleControlFlow", testSimpleControlFlow),
("testManualConstPromotion", testManualConstPromotion),
("testConstPromotion", testConstPromotion)
Expand Down