Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1de468a

Browse files
authored
Rewrite tracing utilities in terms of LazyTensor. (#385)
1 parent 719d7fa commit 1de468a

File tree

6 files changed

+41
-57
lines changed

6 files changed

+41
-57
lines changed

Sources/TensorFlow/Core/LazyTensorTFFunctionBuilder.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ class TFFunction {
221221
let cTFFunction: CTFFunction
222222
let outputCount: Int
223223
let outputGroupCounts: [Int]
224+
var name: String { String(cString: TF_FunctionName(cTFFunction)!) }
224225

225226
init(trace: LazyTensorTrace, name: String? = nil) {
226227
let status: CTFStatus = TF_NewStatus()
@@ -268,7 +269,7 @@ class TFFunction {
268269
checkOk(status)
269270
}
270271

271-
func execute(_ inputs: [TFETensorHandle]) -> [TFETensorHandle] {
272+
func execute(_ inputs: [TFETensorHandle], usingXLA: Bool = false) -> [TFETensorHandle] {
272273
let status: CTFStatus = TF_NewStatus()
273274
defer { TF_DeleteStatus(status) }
274275

@@ -285,6 +286,11 @@ class TFFunction {
285286
checkOk(status)
286287
}
287288

289+
if usingXLA {
290+
debugLog("Enabling XLA compilation")
291+
TFE_OpSetAttrBool(eagerOp, "_XlaCompile", 1)
292+
}
293+
288294
for input in inputs {
289295
TFE_OpAddInput(eagerOp, input._cTensorHandle, status)
290296
checkOk(status)

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 15 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -934,65 +934,33 @@ public func _tffunc<State: _TensorArrayProtocolEnhanced, Data: TensorGroup>(
934934
}
935935
}
936936

937+
internal func _trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> TFFunction {
938+
let useLazyTensor = _ThreadLocalState.useLazyTensor
939+
defer { _ThreadLocalState.useLazyTensor = useLazyTensor }
940+
_ThreadLocalState.useLazyTensor = true
941+
let trace = LazyTensorTraceBuilder.trace(fn)
942+
return TFFunction(trace: trace)
943+
}
944+
937945
// Trace the given function to generate a TF graph and return a closure that can be used to launch
938946
// the graph.
939947
public func _graph<In: TensorGroup, Out: TensorGroup>(
940948
_ fn: (In) -> Out,
941949
useXLA: Bool = false
942950
) -> (In) -> Out {
943-
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
944-
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
945-
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(capacity: Int(inputs.count))
946-
var ptr = buffer
947-
for input in inputs {
948-
ptr.initialize(to: input)
949-
ptr = ptr.advanced(by: 1)
950-
}
951-
let symbolicIn = In(_owning: buffer)
952-
let symbolicOut = escapableFn(symbolicIn)
953-
return symbolicOut.cTensorHandles
954-
}
955-
let dtypes = In._typeList.map { $0._cDataType }
956-
return _trace(with: dtypes, in: wrappedFn)
957-
}
958-
// The result is a closure that captures and executes the trace graph function in the trace
959-
// context.
960-
return { (input: In) -> (Out) in
961-
debugLog("Running trace function over input \(input).")
962-
963-
debugLog("Getting input state tensor handles.")
964-
let inputStateTensorHandles = input.cTensorHandles
965-
let inputTensors = inputStateTensorHandles.map {
966-
TFETensorHandle(_owning: $0)
967-
}
968-
debugLog("Executing trace graph function.")
969-
let returnValues = traceContext.execute(traceeInputs: inputTensors, useXLA: useXLA)
970-
971-
debugLog("Creating output model instance.")
972-
return Out(_owning: returnValues)
951+
let tffunc = _trace(fn)
952+
return {input in
953+
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
954+
let outputHandles = tffunc.execute(inputHandles, usingXLA: useXLA)
955+
return Out(_handles: outputHandles)
973956
}
974957
}
975958

976959
/// Trace the given function and return the name of the corresponding `TF_Function: In -> Out` that
977960
/// was created.
978961
public func _tffunc<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> String {
979-
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
980-
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
981-
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(capacity: Int(inputs.count))
982-
var ptr = buffer
983-
for input in inputs {
984-
ptr.initialize(to: input)
985-
ptr = ptr.advanced(by: 1)
986-
}
987-
let symbolicIn = In(_owning: buffer)
988-
let symbolicOut = escapableFn(symbolicIn)
989-
return symbolicOut.cTensorHandles
990-
}
991-
992-
let dtypes = In._typeList.map { $0._cDataType }
993-
return _trace(with: dtypes, in: wrappedFn)
994-
}
995-
return traceContext.specializeTFFunction(with: [])
962+
let tffunc = _trace(fn)
963+
return tffunc.name
996964
}
997965

998966
internal extension _ExecutionContext {

Tests/TensorFlowTests/LazyTensorEvaluationTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ import CTensorFlow
2020
final class LazyTensorEvaluationTests: XCTestCase {
2121
override class func setUp() {
2222
super.setUp()
23-
_RuntimeConfig.useLazyTensor = true
23+
_ThreadLocalState.useLazyTensor = true
2424
}
2525

2626
override class func tearDown() {
2727
super.tearDown()
28-
_RuntimeConfig.useLazyTensor = false
28+
_ThreadLocalState.useLazyTensor = false
2929
}
3030

3131
func testSimpleOperations() {

Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ import CTensorFlow
2020
final class LazyTensorExplicitTraceTests: XCTestCase {
2121
override class func setUp() {
2222
super.setUp()
23-
_RuntimeConfig.useLazyTensor = true
23+
_ThreadLocalState.useLazyTensor = true
2424
}
2525

2626
override class func tearDown() {
2727
super.tearDown()
28-
_RuntimeConfig.useLazyTensor = false
28+
_ThreadLocalState.useLazyTensor = false
2929
}
3030

3131
func testSingleInput() {
@@ -135,6 +135,15 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
135135
XCTAssertEqual(outputs[0].valueDescription, "13.0")
136136
}
137137

138+
func testCallableTrace() {
139+
func square(input: Tensor<Float>) -> Tensor<Float> {
140+
return input * input
141+
}
142+
let tracedSquare = _graph(square)
143+
XCTAssertEqual(tracedSquare(Tensor<Float>(10.0)).scalarized(), 100.0)
144+
XCTAssertEqual(tracedSquare(Tensor<Float>(5.0)).scalarized(), 25.0)
145+
}
146+
138147
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
139148
let tffunc = TFFunction(trace: trace)
140149
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
@@ -147,6 +156,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
147156
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
148157
("testClosureCapturesOfTensors", testClosureCapturesOfTensors),
149158
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors),
150-
("testNestedTracing", testNestedTracing)
159+
("testNestedTracing", testNestedTracing),
160+
("testCallableTrace", testCallableTrace)
151161
]
152162
}

Tests/TensorFlowTests/LazyTensorTFFunctionBuilderTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ import CTensorFlow
1919
final class LazyTensorTFFunctionBuilderTests : XCTestCase {
2020
override class func setUp() {
2121
super.setUp()
22-
_RuntimeConfig.useLazyTensor = true
22+
_ThreadLocalState.useLazyTensor = true
2323
}
2424

2525
override class func tearDown() {
2626
super.tearDown()
27-
_RuntimeConfig.useLazyTensor = false
27+
_ThreadLocalState.useLazyTensor = false
2828
}
2929

3030
func testSingletonInputs() {

Tests/TensorFlowTests/LazyTensorTraceTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ import CTensorFlow
2020
final class LazyTensorTraceTests: XCTestCase {
2121
override class func setUp() {
2222
super.setUp()
23-
_RuntimeConfig.useLazyTensor = true
23+
_ThreadLocalState.useLazyTensor = true
2424
}
2525

2626
override class func tearDown() {
2727
super.tearDown()
28-
_RuntimeConfig.useLazyTensor = false
28+
_ThreadLocalState.useLazyTensor = false
2929
}
3030

3131
func testSingleLiveTensor() {

0 commit comments

Comments
 (0)