Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 27f13d8

Browse files
committed
Putting LazyTensor components together.
Extracted traces are evaluated and used for materializing lazy tensors.
1 parent ebba1e5 commit 27f13d8

File tree

3 files changed

+169
-3
lines changed

3 files changed

+169
-3
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ class LazyTensor: _AnyTensorHandle {
3333
switch handle {
3434
case .concrete(let h, _):
3535
return h
36-
case .symbolic(_, _, _):
37-
fatalError("TODO: to be send out in a separate PR.")
38-
// return op.materialized(index: index)
36+
case .symbolic(let op, let index, _):
37+
return op.materialized(index: index)
3938
}
4039
}
4140

@@ -725,3 +724,69 @@ extension LazyTensorOperation: CustomStringConvertible {
725724
return desc
726725
}
727726
}
727+
728+
extension LazyTensorOperation {
729+
/// Returns the materialized value at the given output `index`.
730+
func materialized(index: Int) -> TFETensorHandle {
731+
return materialized()[index]
732+
}
733+
734+
/// Materializes all the outputs.
735+
func materialized() -> [TFETensorHandle] {
736+
// Return materialized outputs if any.
737+
if let outputs = outputs { return outputs }
738+
739+
LazyTensorOperation.materializeLiveTensors(self)
740+
741+
// Our outputs should have been updated by now. Otherwise,
742+
// something terrible happened!
743+
precondition(outputs != nil, "Materialization failed!")
744+
return outputs!
745+
}
746+
747+
/// Converts symbolic tensor inputs to concrete inputs if the
748+
/// associated `LazyTensorOperation` has been materialized.
749+
private func maybeMaterializeInputs() {
750+
func maybeMaterialized(lazyTensor: LazyTensor) -> LazyTensor {
751+
let handle = lazyTensor.handle
752+
if case let LazyTensor.Handle.symbolic(lazyOp, index, _) = handle {
753+
if let outputs = lazyOp.outputs {
754+
return LazyTensor(_materialized: outputs[index])
755+
}
756+
}
757+
return lazyTensor
758+
}
759+
760+
func maybeMaterialized(input: Input) -> Input {
761+
switch input {
762+
case .single(let h):
763+
return Input.single(maybeMaterialized(lazyTensor: h))
764+
case .list(let elements):
765+
return Input.list(elements.map { maybeMaterialized(lazyTensor: $0) })
766+
}
767+
}
768+
inputs = inputs.map { maybeMaterialized(input: $0) }
769+
}
770+
771+
private static func materializeLiveTensors(_ lazyOp: LazyTensorOperation) {
772+
let lazyTrace = LazyTensorTrace(lazyOp)
773+
debugLog("Extracted trace:\n\(lazyTrace)")
774+
775+
let function = TFFunction(trace: lazyTrace)
776+
debugLog("Generated TFFunction:\n\(function)")
777+
778+
let allOutputs = function.execute(lazyTrace.inputValues)
779+
780+
// Slice up the outputs to various lazy tensors
781+
var start: Int = 0
782+
for lazyOp in lazyTrace.originalOutputs {
783+
let end = start + lazyOp.outputCount
784+
lazyOp.outputs = Array(allOutputs[start..<end])
785+
start = end
786+
}
787+
788+
// On all the live operations rewrite the inputs so that we drop references
789+
// to the LazyTensorOperations.
790+
LazyTensor.forEachOperation { $0.maybeMaterializeInputs() }
791+
}
792+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import TensorFlow
18+
import CTensorFlow
19+
20+
final class LazyTensorEvaluationTests: XCTestCase {
21+
override class func setUp() {
22+
super.setUp()
23+
_RuntimeConfig.useLazyTensor = true
24+
}
25+
26+
func testSimpleOperations() {
27+
let a = Tensor<Float>(10.0)
28+
let b = Tensor<Float>(2.0)
29+
let c = Tensor<Float>(3.0)
30+
let w = a + b * c
31+
32+
XCTAssertFalse(isMaterialized(w))
33+
XCTAssertEqual(w.scalarized(), 16.0)
34+
XCTAssertTrue(isMaterialized(w))
35+
}
36+
37+
func testMultipleMaterializations() {
38+
let a = Tensor<Float>(10.0)
39+
let b = Tensor<Float>(2.0)
40+
let c = Tensor<Float>(3.0)
41+
let x = a + b + c
42+
let y = x * c
43+
let z = y / (x - c)
44+
45+
// Materialize y first
46+
XCTAssertFalse(isMaterialized(x))
47+
XCTAssertFalse(isMaterialized(y))
48+
XCTAssertFalse(isMaterialized(z))
49+
XCTAssertEqual(y.scalarized(), 45.0)
50+
51+
// x and y are materialized, but not z.
52+
XCTAssertTrue(isMaterialized(x))
53+
XCTAssertTrue(isMaterialized(y))
54+
XCTAssertFalse(isMaterialized(z))
55+
56+
XCTAssertEqual(z.scalarized(), 3.75)
57+
XCTAssertTrue(isMaterialized(z))
58+
}
59+
60+
func testSimpleControlFlow() {
61+
let a = Tensor<Float>(5.0)
62+
let addOrMul = { (useAdd: Bool, a: Tensor<Float>) in
63+
useAdd ? (a + a) : (a * a)
64+
}
65+
let add = addOrMul(/*useAdd:*/true, a)
66+
XCTAssertFalse(isMaterialized(add))
67+
XCTAssertEqual(add.scalarized(), 10.0);
68+
XCTAssertTrue(isMaterialized(add))
69+
70+
let mul = addOrMul(/*useAdd:*/false, a)
71+
XCTAssertFalse(isMaterialized(mul))
72+
XCTAssertEqual(mul.scalarized(), 25.0);
73+
XCTAssertTrue(isMaterialized(mul))
74+
}
75+
76+
func testSimpleLoop() {
77+
var sum = Tensor<Float>(0)
78+
for i in 1...10 { sum += Float(i) }
79+
XCTAssertFalse(isMaterialized(sum))
80+
XCTAssertEqual(sum.scalarized(), 55.0, accuracy: 0.00001)
81+
XCTAssertTrue(isMaterialized(sum))
82+
}
83+
84+
private func isMaterialized<T: TensorFlowScalar>(_ input: Tensor<T>) -> Bool {
85+
let tensor = input.handle.handle
86+
guard let lazyTensor = tensor as? LazyTensor else { return true }
87+
switch lazyTensor.handle {
88+
case .symbolic(let op, _, _): return op.outputs != nil
89+
default: return false
90+
}
91+
}
92+
93+
static var allTests = [
94+
("testSimpleOperations", testSimpleOperations),
95+
("testMultipleMaterializations", testMultipleMaterializations),
96+
("testSimpleControlFlow", testSimpleControlFlow),
97+
("testSimpleLoop", testSimpleLoop),
98+
]
99+
}
100+

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public func allTests() -> [XCTestCaseEntry] {
3232
testCase(LazyTensorTraceTests.allTests),
3333
testCase(LazyTensorOperationTests.allTests),
3434
testCase(LazyTensorTFFunctionBuilderTests.allTests),
35+
testCase(LazyTensorEvaluationTests.allTests),
3536
]
3637
}
3738
#endif

0 commit comments

Comments
 (0)