Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ef48ae9

Browse files
authored
Adds a mechanism to extract traces explicitly using LazyTensor. (#381)
1 parent 97cb096 commit ef48ae9

File tree

3 files changed

+219
-11
lines changed

3 files changed

+219
-11
lines changed

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,55 @@ class LazyTensorTraceBuilder {
8989
return materializationTraceInfo([lazyOperation])
9090
}
9191

92+
/// Returns a trace obtained by tracing the given function.
93+
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
94+
precondition(_RuntimeConfig.useLazyTensor, "Lazy tensor is not enabled for tracing.")
95+
96+
// Set up inputs for running `fn`.
97+
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }
98+
let inputHandles = inputOps.map { LazyTensorHandle(_lazy: $0, index: 0) }
99+
let input = In(_handles: inputHandles)
100+
101+
// Run the function.
102+
let output: TensorArrayProtocol = fn(input)
103+
104+
// Set up the closure that determines if a `LazyTensorOperation` should be an output.
105+
let outputLazyOperations = output._tensorHandles.map {
106+
(handle: _AnyTensorHandle) -> LazyTensorOperation in
107+
let lazyOp = lazyTensorOperation(handle)
108+
precondition(lazyOp != nil, "Found a non-lazy tensor in output when tracing.")
109+
return lazyOp!
110+
}
111+
let outputIDs = Set<ObjectIdentifier>(
112+
outputLazyOperations.lazy.map { ObjectIdentifier($0) })
113+
114+
// Create the builder and get the trace.
115+
let builder = LazyTensorTraceBuilder()
116+
builder.neverPromoteConstants = true
117+
builder.isOutput = { outputIDs.contains(ObjectIdentifier($0)) }
118+
// Set up the inputs for the builder as we need to have them in a specific order.
119+
for inputOp in inputOps {
120+
builder.updateOperationAndCache(ObjectIdentifier(inputOp), inputOp)
121+
}
122+
builder.inputs = inputOps
123+
for lazyOp in outputLazyOperations { _ = builder.collectLazyOperation(lazyOp) }
124+
return LazyTensorTrace(
125+
inputs: builder.inputs,
126+
operations: builder.operations,
127+
outputs: builder.outputs)
128+
}
129+
92130
// inputs will be "placeholder" nodes.
93131
private var inputs: [LazyTensorOperation] = []
94132
private var inputValues: [TFETensorHandle] = []
95133
private var operations: [LazyTensorOperation] = []
96134
private var outputs: [LazyTensorOperation] = []
97135
private var originalOutputs: [LazyTensorOperation] = []
98136
private var lazyOpsCache: [ObjectIdentifier: LazyTensorOperation] = [:]
137+
/// A flag that controls promotion of constants to inputs.
138+
private var neverPromoteConstants: Bool = false
139+
/// A closure that determines whether a `LazyTensorOperation` is an output.
140+
private var isOutput: (LazyTensorOperation) -> Bool = LazyTensorHandle.isLive
99141

100142
private func updateOperationAndCache(
101143
_ id: ObjectIdentifier, _ node: LazyTensorOperation
@@ -117,14 +159,25 @@ class LazyTensorTraceBuilder {
117159
return LazyTensorHandle(_lazy: result, index: 0)
118160
}
119161

120-
private func makePlaceholderTensor(
121-
with handle: TFETensorHandle
122-
) -> LazyTensorHandle {
123-
let cTensorHandle = handle._cTensorHandle
124-
let dtype = TensorDataType(TFE_TensorHandleDataType(cTensorHandle))
125-
let dtypeAttr = LazyTensorOperation.Attribute.tensorDataTypeValue(dtype)
162+
/// Returns the `LazyTensorOperation`, if any, for this handle.
163+
private static func lazyTensorOperation(_ handle: _AnyTensorHandle) -> LazyTensorOperation? {
164+
guard case let .symbolic(lazyOp, _, _)? = (handle as? LazyTensorHandle)?.handle else {
165+
return nil
166+
}
167+
return lazyOp
168+
}
169+
170+
private static func makePlaceholder(dataType: TensorDataType) -> LazyTensorOperation {
126171
let placeholder = LazyTensorOperation("Placeholder", 1)
172+
let dtypeAttr = LazyTensorOperation.Attribute.tensorDataTypeValue(dataType)
127173
placeholder.attributes = ["dtype": dtypeAttr]
174+
return placeholder
175+
}
176+
177+
private func makePlaceholderTensor(handle: TFETensorHandle) -> LazyTensorHandle {
178+
let cTensorHandle = handle._cTensorHandle
179+
let dtype = TensorDataType(TFE_TensorHandleDataType(cTensorHandle))
180+
let placeholder = Self.makePlaceholder(dataType: dtype)
128181
updateOperationAndCache(ObjectIdentifier(handle), placeholder)
129182
inputs.append(placeholder)
130183
inputValues.append(handle)
@@ -138,12 +191,12 @@ class LazyTensorTraceBuilder {
138191
if let lazyOp = lazyOpsCache[id] {
139192
return LazyTensorHandle(_lazy: lazyOp, index: 0)
140193
}
141-
return asConst
194+
return asConst || neverPromoteConstants
142195
? makeConstTensor(with: handle)
143-
: makePlaceholderTensor(with: handle)
196+
: makePlaceholderTensor(handle: handle)
144197
}
145198

146-
/// Return the original tensor or a concrete tensor that is promoted to a
199+
/// Returns the original tensor or a concrete tensor that is promoted to a
147200
/// placeholder input.
148201
private func maybePromotedTensor(_ lazyHandle: LazyTensorHandle) -> LazyTensorHandle {
149202
switch lazyHandle.handle {
@@ -178,13 +231,15 @@ class LazyTensorTraceBuilder {
178231
if let cachedLazyOp = lazyOpsCache[id] {
179232
return cachedLazyOp
180233
}
181-
234+
precondition(
235+
lazyOp.name != "Placeholder",
236+
"The operation cannot already be a placeholder.")
182237
let newLazyOp = LazyTensorOperation(lazyOp.name, lazyOp.outputCount)
183238
newLazyOp.attributes = lazyOp.attributes
184239
newLazyOp.inputs = lazyOp.inputs.map { maybePromotedInput($0) }
185240
updateOperationAndCache(id, newLazyOp)
186241

187-
if LazyTensorHandle.isLive(lazyOp) {
242+
if isOutput(lazyOp) {
188243
outputs.append(newLazyOp)
189244
originalOutputs.append(lazyOp)
190245
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import TensorFlow
18+
import CTensorFlow
19+
20+
final class LazyTensorExplicitTraceTests: XCTestCase {
21+
override class func setUp() {
22+
super.setUp()
23+
_RuntimeConfig.useLazyTensor = true
24+
}
25+
26+
override class func tearDown() {
27+
super.tearDown()
28+
_RuntimeConfig.useLazyTensor = false
29+
}
30+
31+
func testSingleInput() {
32+
func fn(x: Tensor<Float>) -> Tensor<Float> { return x + x }
33+
let trace = LazyTensorTraceBuilder.trace(fn)
34+
XCTAssertEqual(trace.description,
35+
"""
36+
lazyTrace_2(%0: float) -> (%1) {
37+
%1 = Add[T: float](%0, %0)
38+
}
39+
""")
40+
let outputs = runTrace(trace: trace, input: Tensor<Float>(10.0))
41+
XCTAssertEqual(outputs.count, 1)
42+
XCTAssertEqual(outputs[0].valueDescription, "20.0")
43+
}
44+
45+
func testTensorGroupInputOutputs() {
46+
typealias TensorFloatInt32Pair = Zip2TensorGroup<Tensor<Float>, Tensor<Int32>>
47+
typealias TensorInt32FloatPair = Zip2TensorGroup<Tensor<Int32>, Tensor<Float>>
48+
func fn(input: TensorFloatInt32Pair) -> TensorInt32FloatPair {
49+
return TensorInt32FloatPair(input.second * 4, input.first + 3.0)
50+
}
51+
let trace = LazyTensorTraceBuilder.trace(fn)
52+
XCTAssertEqual(trace.description,
53+
"""
54+
lazyTrace_6(%0: float, %1: int32) -> (%3, %5) {
55+
%2 = Const[dtype: int32, value: 4]()
56+
%3 = Mul[T: int32](%1, %2)
57+
%4 = Const[dtype: float, value: 3.0]()
58+
%5 = Add[T: float](%0, %4)
59+
}
60+
""")
61+
let outputs = runTrace(
62+
trace: trace,
63+
input: TensorFloatInt32Pair(Tensor<Float>(10.0), Tensor<Int32>(5)))
64+
XCTAssertEqual(outputs.count, 2)
65+
XCTAssertEqual(outputs[0].valueDescription, "20")
66+
XCTAssertEqual(outputs[1].valueDescription, "13.0")
67+
}
68+
69+
func testClosureCapturesOfTensors() {
70+
let x = Tensor<Float>(10.0)
71+
let y = x + x
72+
func fn(input: Tensor<Float>) -> Tensor<Float> {
73+
return input * y
74+
}
75+
let trace = LazyTensorTraceBuilder.trace(fn)
76+
/// Note that the computation x + x is encoded in the trace.
77+
XCTAssertEqual(trace.description,
78+
"""
79+
lazyTrace_4(%0: float) -> (%3) {
80+
%1 = Const[dtype: float, value: 10.0]()
81+
%2 = Add[T: float](%1, %1)
82+
%3 = Mul[T: float](%0, %2)
83+
}
84+
""")
85+
let outputs = runTrace(
86+
trace: trace,
87+
input: Tensor<Float>(5.0))
88+
XCTAssertEqual(outputs.count, 1)
89+
XCTAssertEqual(outputs[0].valueDescription, "100.0")
90+
}
91+
92+
func testClosureCapturesOfNonTensors() {
93+
let x: Float = 5.0
94+
func fn(input: Tensor<Float>) -> Tensor<Float> {
95+
return input * Tensor<Float>(x)
96+
}
97+
let trace = LazyTensorTraceBuilder.trace(fn)
98+
/// Note that the computation x + x is encoded in the trace.
99+
XCTAssertEqual(trace.description,
100+
"""
101+
lazyTrace_3(%0: float) -> (%2) {
102+
%1 = Const[dtype: float, value: 5.0]()
103+
%2 = Mul[T: float](%0, %1)
104+
}
105+
""")
106+
let outputs = runTrace(trace: trace, input: Tensor<Float>(23.0))
107+
XCTAssertEqual(outputs.count, 1)
108+
XCTAssertEqual(outputs[0].valueDescription, "115.0")
109+
}
110+
111+
func testNestedTracing() {
112+
func square(input: Tensor<Float>) -> Tensor<Float> {
113+
return input * input
114+
}
115+
116+
func nestedTrace(input: Tensor<Float>) -> Tensor<Float> {
117+
let trace = LazyTensorTraceBuilder.trace(square)
118+
let outputs = runTrace(trace: trace, input: Tensor<Float>(3.0))
119+
XCTAssertEqual(outputs.count, 1)
120+
let handle = TensorHandle<Float>(handle: outputs[0])
121+
let y = Tensor<Float>(handle: handle)
122+
return y + input
123+
}
124+
125+
let trace = LazyTensorTraceBuilder.trace(nestedTrace)
126+
XCTAssertEqual(trace.description,
127+
"""
128+
lazyTrace_3(%0: float) -> (%2) {
129+
%1 = Const[dtype: float, value: 9.0]()
130+
%2 = Add[T: float](%1, %0)
131+
}
132+
""")
133+
let outputs = runTrace(trace: trace, input: Tensor<Float>(4.0))
134+
XCTAssertEqual(outputs.count, 1)
135+
XCTAssertEqual(outputs[0].valueDescription, "13.0")
136+
}
137+
138+
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
139+
let tffunc = TFFunction(trace: trace)
140+
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
141+
let outputHandles = tffunc.execute(inputHandles)
142+
return outputHandles
143+
}
144+
145+
static var allTests = [
146+
("testSingleInput", testSingleInput),
147+
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
148+
("testClosureCapturesOfTensors", testClosureCapturesOfTensors),
149+
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors),
150+
("testNestedTracing", testNestedTracing)
151+
]
152+
}

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public func allTests() -> [XCTestCaseEntry] {
3131
testCase(MathOperatorTests.allTests),
3232
testCase(LazyTensorTests.allTests),
3333
testCase(LazyTensorTraceTests.allTests),
34+
testCase(LazyTensorExplicitTraceTests.allTests),
3435
testCase(LazyTensorOperationTests.allTests),
3536
testCase(LazyTensorTFFunctionBuilderTests.allTests),
3637
testCase(LazyTensorEvaluationTests.allTests),

0 commit comments

Comments
 (0)