Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Rewrite tracing utilities in terms of LazyTensor. #385

Merged
merged 9 commits into from
Jul 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Sources/TensorFlow/Core/LazyTensorTFFunctionBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class TFFunction {
let cTFFunction: CTFFunction
let outputCount: Int
let outputGroupCounts: [Int]
var name: String { String(cString: TF_FunctionName(cTFFunction)!) }

init(trace: LazyTensorTrace, name: String? = nil) {
let status: CTFStatus = TF_NewStatus()
Expand Down Expand Up @@ -268,7 +269,7 @@ class TFFunction {
checkOk(status)
}

func execute(_ inputs: [TFETensorHandle]) -> [TFETensorHandle] {
func execute(_ inputs: [TFETensorHandle], usingXLA: Bool = false) -> [TFETensorHandle] {
let status: CTFStatus = TF_NewStatus()
defer { TF_DeleteStatus(status) }

Expand All @@ -285,6 +286,11 @@ class TFFunction {
checkOk(status)
}

if usingXLA {
debugLog("Enabling XLA compilation")
TFE_OpSetAttrBool(eagerOp, "_XlaCompile", 1)
}

for input in inputs {
TFE_OpAddInput(eagerOp, input._cTensorHandle, status)
checkOk(status)
Expand Down
62 changes: 15 additions & 47 deletions Sources/TensorFlow/Core/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -934,65 +934,33 @@ public func _tffunc<State: _TensorArrayProtocolEnhanced, Data: TensorGroup>(
}
}

internal func _trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> TFFunction {
let useLazyTensor = _ThreadLocalState.useLazyTensor
defer { _ThreadLocalState.useLazyTensor = useLazyTensor }
_ThreadLocalState.useLazyTensor = true
let trace = LazyTensorTraceBuilder.trace(fn)
return TFFunction(trace: trace)
}

// Trace the given function to generate a TF graph and return a closure that can be used to launch
// the graph.
public func _graph<In: TensorGroup, Out: TensorGroup>(
_ fn: (In) -> Out,
useXLA: Bool = false
) -> (In) -> Out {
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(capacity: Int(inputs.count))
var ptr = buffer
for input in inputs {
ptr.initialize(to: input)
ptr = ptr.advanced(by: 1)
}
let symbolicIn = In(_owning: buffer)
let symbolicOut = escapableFn(symbolicIn)
return symbolicOut.cTensorHandles
}
let dtypes = In._typeList.map { $0._cDataType }
return _trace(with: dtypes, in: wrappedFn)
}
// The result is a closure that captures and executes the trace graph function in the trace
// context.
return { (input: In) -> (Out) in
debugLog("Running trace function over input \(input).")

debugLog("Getting input state tensor handles.")
let inputStateTensorHandles = input.cTensorHandles
let inputTensors = inputStateTensorHandles.map {
TFETensorHandle(_owning: $0)
}
debugLog("Executing trace graph function.")
let returnValues = traceContext.execute(traceeInputs: inputTensors, useXLA: useXLA)

debugLog("Creating output model instance.")
return Out(_owning: returnValues)
let tffunc = _trace(fn)
return {input in
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
let outputHandles = tffunc.execute(inputHandles, usingXLA: useXLA)
return Out(_handles: outputHandles)
}
}

/// Trace the given function and return the name of the corresponding `TF_Function: In -> Out` that
/// was created.
public func _tffunc<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> String {
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(capacity: Int(inputs.count))
var ptr = buffer
for input in inputs {
ptr.initialize(to: input)
ptr = ptr.advanced(by: 1)
}
let symbolicIn = In(_owning: buffer)
let symbolicOut = escapableFn(symbolicIn)
return symbolicOut.cTensorHandles
}

let dtypes = In._typeList.map { $0._cDataType }
return _trace(with: dtypes, in: wrappedFn)
}
return traceContext.specializeTFFunction(with: [])
let tffunc = _trace(fn)
return tffunc.name
}

internal extension _ExecutionContext {
Expand Down
4 changes: 2 additions & 2 deletions Tests/TensorFlowTests/LazyTensorEvaluationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import CTensorFlow
final class LazyTensorEvaluationTests: XCTestCase {
override class func setUp() {
super.setUp()
_RuntimeConfig.useLazyTensor = true
_ThreadLocalState.useLazyTensor = true
}

override class func tearDown() {
super.tearDown()
_RuntimeConfig.useLazyTensor = false
_ThreadLocalState.useLazyTensor = false
}

func testSimpleOperations() {
Expand Down
16 changes: 13 additions & 3 deletions Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import CTensorFlow
final class LazyTensorExplicitTraceTests: XCTestCase {
override class func setUp() {
super.setUp()
_RuntimeConfig.useLazyTensor = true
_ThreadLocalState.useLazyTensor = true
}

override class func tearDown() {
super.tearDown()
_RuntimeConfig.useLazyTensor = false
_ThreadLocalState.useLazyTensor = false
}

func testSingleInput() {
Expand Down Expand Up @@ -135,6 +135,15 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
XCTAssertEqual(outputs[0].valueDescription, "13.0")
}

func testCallableTrace() {
func square(input: Tensor<Float>) -> Tensor<Float> {
return input * input
}
let tracedSquare = _graph(square)
XCTAssertEqual(tracedSquare(Tensor<Float>(10.0)).scalarized(), 100.0)
XCTAssertEqual(tracedSquare(Tensor<Float>(5.0)).scalarized(), 25.0)
}

private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
let tffunc = TFFunction(trace: trace)
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
Expand All @@ -147,6 +156,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
("testClosureCapturesOfTensors", testClosureCapturesOfTensors),
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors),
("testNestedTracing", testNestedTracing)
("testNestedTracing", testNestedTracing),
("testCallableTrace", testCallableTrace)
]
}
4 changes: 2 additions & 2 deletions Tests/TensorFlowTests/LazyTensorTFFunctionBuilderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ import CTensorFlow
final class LazyTensorTFFunctionBuilderTests : XCTestCase {
override class func setUp() {
super.setUp()
_RuntimeConfig.useLazyTensor = true
_ThreadLocalState.useLazyTensor = true
}

override class func tearDown() {
super.tearDown()
_RuntimeConfig.useLazyTensor = false
_ThreadLocalState.useLazyTensor = false
}

func testSingletonInputs() {
Expand Down
4 changes: 2 additions & 2 deletions Tests/TensorFlowTests/LazyTensorTraceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import CTensorFlow
final class LazyTensorTraceTests: XCTestCase {
override class func setUp() {
super.setUp()
_RuntimeConfig.useLazyTensor = true
_ThreadLocalState.useLazyTensor = true
}

override class func tearDown() {
super.tearDown()
_RuntimeConfig.useLazyTensor = false
_ThreadLocalState.useLazyTensor = false
}

func testSingleLiveTensor() {
Expand Down