Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 76a69b7

Browse files
authored
Putting LazyTensor components together (#294)
Extracted traces are evaluated and used for materializing lazy tensors.
1 parent 94067ef commit 76a69b7

File tree

3 files changed

+181
-3
lines changed

3 files changed

+181
-3
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ class LazyTensor: _AnyTensorHandle {
3333
switch handle {
3434
case .concrete(let h, _):
3535
return h
36-
case .symbolic(_, _, _):
37-
fatalError("TODO: to be send out in a separate PR.")
38-
// return op.materialized(index: index)
36+
case .symbolic(let op, let index, _):
37+
return op.materialized(index: index)
3938
}
4039
}
4140

@@ -725,3 +724,76 @@ extension LazyTensorOperation: CustomStringConvertible {
725724
return desc
726725
}
727726
}
727+
728+
extension LazyTensorOperation {
729+
/// Returns the materialized value at the given output `index`.
730+
func materialized(index: Int) -> TFETensorHandle {
731+
precondition(index < outputCount)
732+
return materialized()[index]
733+
}
734+
735+
/// Materializes all the outputs.
736+
func materialized() -> [TFETensorHandle] {
737+
// Return materialized outputs if any.
738+
if let outputs = outputs { return outputs }
739+
740+
materializeLiveTensors()
741+
742+
// Our outputs should have been updated by now. Otherwise,
743+
// something terrible happened!
744+
precondition(outputs != nil, "Materialization failed!")
745+
return outputs!
746+
}
747+
748+
/// Converts symbolic tensor inputs to concrete inputs if the
749+
/// associated `LazyTensorOperation` has been materialized.
750+
private func maybeMaterializeInputs() {
751+
/// If `lazyTensor` is symbolic and the associated `LazyTensorOperation`
752+
/// has been materialized, return the corresponding concrete `LazyTensor`.
753+
/// Otherwise, return `lazyTensor` untouched.
754+
func materializedAsNeeded(lazyTensor: LazyTensor) -> LazyTensor {
755+
let handle = lazyTensor.handle
756+
if case let LazyTensor.Handle.symbolic(lazyOp, index, _) = handle,
757+
let outputs = lazyOp.outputs {
758+
return LazyTensor(_materialized: outputs[index])
759+
}
760+
return lazyTensor
761+
}
762+
763+
/// Returns an input that is rewritten such that all symbolic values
764+
/// that have been materialized have been replaced by the corresponding
765+
/// concerete inputs. If no symbolic values have been materialized or if
766+
/// there are no symbolic values, return the `input` untouched.
767+
func materializedAsNeeded(input: Input) -> Input {
768+
switch input {
769+
case .single(let h):
770+
return .single(materializedAsNeeded(lazyTensor: h))
771+
case .list(let elements):
772+
return .list(elements.map { materializedAsNeeded(lazyTensor: $0) })
773+
}
774+
}
775+
inputs = inputs.map { materializedAsNeeded(input: $0) }
776+
}
777+
778+
private func materializeLiveTensors() {
779+
let lazyTrace = LazyTensorTrace(self)
780+
debugLog("Extracted trace:\n\(lazyTrace)")
781+
782+
let function = TFFunction(trace: lazyTrace)
783+
debugLog("Generated TFFunction:\n\(function)")
784+
785+
let allOutputs = function.execute(lazyTrace.inputValues)
786+
787+
// Slice up the outputs to various lazy tensors
788+
var start = 0
789+
for lazyOp in lazyTrace.originalOutputs {
790+
let end = start + lazyOp.outputCount
791+
lazyOp.outputs = Array(allOutputs[start..<end])
792+
start = end
793+
}
794+
795+
// On all the live operations rewrite the inputs so that we drop references
796+
// to the LazyTensorOperations.
797+
LazyTensor.forEachOperation { $0.maybeMaterializeInputs() }
798+
}
799+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import TensorFlow
18+
import CTensorFlow
19+
20+
final class LazyTensorEvaluationTests: XCTestCase {
21+
override class func setUp() {
22+
super.setUp()
23+
_RuntimeConfig.useLazyTensor = true
24+
}
25+
26+
override class func tearDown() {
27+
super.tearDown()
28+
_RuntimeConfig.useLazyTensor = false
29+
}
30+
31+
func testSimpleOperations() {
32+
let a = Tensor<Float>(10.0)
33+
let b = Tensor<Float>(2.0)
34+
let c = Tensor<Float>(3.0)
35+
let w = a + b * c
36+
37+
XCTAssertFalse(isMaterialized(w))
38+
XCTAssertEqual(w.scalarized(), 16.0)
39+
XCTAssertTrue(isMaterialized(w))
40+
}
41+
42+
func testMultipleMaterializations() {
43+
let a = Tensor<Float>(10.0)
44+
let b = Tensor<Float>(2.0)
45+
let c = Tensor<Float>(3.0)
46+
let x = a + b + c
47+
let y = x * c
48+
let z = y / (x - c)
49+
50+
// Materialize y first
51+
XCTAssertFalse(isMaterialized(x))
52+
XCTAssertFalse(isMaterialized(y))
53+
XCTAssertFalse(isMaterialized(z))
54+
XCTAssertEqual(y.scalarized(), 45.0)
55+
56+
// x and y are materialized, but not z.
57+
XCTAssertTrue(isMaterialized(x))
58+
XCTAssertTrue(isMaterialized(y))
59+
XCTAssertFalse(isMaterialized(z))
60+
61+
XCTAssertEqual(z.scalarized(), 3.75)
62+
XCTAssertTrue(isMaterialized(z))
63+
}
64+
65+
func testSimpleControlFlow() {
66+
let a = Tensor<Float>(5.0)
67+
let addOrMul = { (useAdd: Bool, a: Tensor<Float>) in
68+
useAdd ? (a + a) : (a * a)
69+
}
70+
let add = addOrMul(/*useAdd:*/true, a)
71+
XCTAssertFalse(isMaterialized(add))
72+
XCTAssertEqual(add.scalarized(), 10.0);
73+
XCTAssertTrue(isMaterialized(add))
74+
75+
let mul = addOrMul(/*useAdd:*/false, a)
76+
XCTAssertFalse(isMaterialized(mul))
77+
XCTAssertEqual(mul.scalarized(), 25.0);
78+
XCTAssertTrue(isMaterialized(mul))
79+
}
80+
81+
func testSimpleLoop() {
82+
var sum = Tensor<Float>(0)
83+
for i in 1...10 { sum += Float(i) }
84+
XCTAssertFalse(isMaterialized(sum))
85+
XCTAssertEqual(sum.scalarized(), 55.0, accuracy: 0.00001)
86+
XCTAssertTrue(isMaterialized(sum))
87+
}
88+
89+
private func isMaterialized<T: TensorFlowScalar>(_ input: Tensor<T>) -> Bool {
90+
let tensor = input.handle.handle
91+
guard let lazyTensor = tensor as? LazyTensor else { return true }
92+
switch lazyTensor.handle {
93+
case .symbolic(let op, _, _): return op.outputs != nil
94+
default: return false
95+
}
96+
}
97+
98+
static var allTests = [
99+
("testSimpleOperations", testSimpleOperations),
100+
("testMultipleMaterializations", testMultipleMaterializations),
101+
("testSimpleControlFlow", testSimpleControlFlow),
102+
("testSimpleLoop", testSimpleLoop),
103+
]
104+
}
105+

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public func allTests() -> [XCTestCaseEntry] {
3232
testCase(LazyTensorTraceTests.allTests),
3333
testCase(LazyTensorOperationTests.allTests),
3434
testCase(LazyTensorTFFunctionBuilderTests.allTests),
35+
testCase(LazyTensorEvaluationTests.allTests),
3536
]
3637
}
3738
#endif

0 commit comments

Comments
 (0)