
Add an API to trace a function and run it with XLA compilation. #23868


Merged · 1 commit · Apr 8, 2019
49 changes: 48 additions & 1 deletion stdlib/public/TensorFlow/CompilerRuntime.swift
@@ -218,7 +218,8 @@ private class TraceContext {

/// Execute the trace graph function, and return the list of output tensors
/// from the trace execution. These output tensors are owned by the caller.
func execute(traceeInputs: [_AnyTensorHandle]) -> [CTensorHandle] {
func execute(
traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] {
Contributor comment: Xla -> XLA

// We must be in the `notTracing` enum mode.
internalConsistencyCheck(_RuntimeConfig.traceState.context == nil)
internalConsistencyCheck(traceGraphFn != nil)
@@ -236,6 +237,11 @@ private class TraceContext {
checkOk(status)
}

if useXla {
debugLog("Enabling XLA compilation")
TFE_OpSetAttrBool(op, "_XlaCompile", 1)
}

debugLog("Adding \(traceeInputs.count) tracee input tensors.")
internalConsistencyCheck(symbolicInputs.count == traceeInputs.count
+ Int(additionalInputTensorCount))
@@ -993,6 +999,47 @@ public func _tffunc<State : _TensorArrayProtocolEnhanced,
}
}

// Trace the given function to generate a TF graph and return a closure
// that can be used to launch the graph.
public func _graph<In : TensorGroup, Out : TensorGroup>(
_ fn: (In) -> Out, useXla: Bool = false
Contributor comment: Xla -> XLA
) -> (In) -> Out {
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
let wrappedFn = {
(inputs: [CTensorHandle]) -> [CTensorHandle] in
Contributor comment: I think this could be moved to the previous line without going over the column limit.
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(
capacity: Int(inputs.count))
var ptr = buffer
for input in inputs {
ptr.initialize(to: input)
ptr = ptr.advanced(by: 1)
}
let symbolicIn = In(_owning: buffer)
let symbolicOut = escapableFn(symbolicIn)
return symbolicOut.cTensorHandles
}
let dtypes = In._typeList.map { $0._cDataType }
return _trace(with: dtypes, in: wrappedFn)
}
// The result is a closure that captures and executes the trace graph
// function in the trace context.
return { (input: In) -> (Out) in
debugLog("Running trace function over input \(input).")

debugLog("Getting input state tensor handles.")
let inputStateTensorHandles = input.cTensorHandles
let inputTensors = inputStateTensorHandles.map {
_TFCCreateTensorHandleFromC($0)
}
debugLog("Executing trace graph function.")
let returnValues = traceContext.execute(
traceeInputs: inputTensors, useXla: useXla)

debugLog("Creating output model instance.")
return Out(_copying: returnValues)
}
}

/// Trace the given function and return the name of the corresponding
/// `TF_Function: In -> Out` that was created.
public func _tffunc<In : TensorGroup, Out : TensorGroup>(
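
For context, a minimal usage sketch of the new `_graph(_:useXla:)` entry point (not part of this diff): the function is traced once into a TF graph, and the returned closure replays that graph on each call, with the `_XlaCompile` attribute set when `useXla` is true. The `double` helper below is a hypothetical example function.

import TensorFlow

// Hypothetical example function; any (In: TensorGroup) -> (Out: TensorGroup)
// function can be traced the same way.
func double(_ x: Tensor<Float>) -> Tensor<Float> {
  return x * 2
}

// Trace `double` once; the closure captures and re-executes the trace graph.
// Passing `useXla: true` makes execute() set `_XlaCompile` on the graph op.
let doubleGraph = _graph(double, useXla: true)
let y = doubleGraph(Tensor<Float>(3.0))  // expected: Tensor<Float>(6.0)
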
10 changes: 10 additions & 0 deletions test/TensorFlowRuntime/tracer.swift
@@ -129,6 +129,16 @@ TracerTests.testAllBackends("TraceWithNoResult") {
expectNearlyEqualWithScalarTensor(8.0, tracedAdd(Tensor<Float>(5.0), three))
}

TracerTests.testAllBackends("TracerWithInOut") {
func addOne(state: Tensor<Int32>) -> (Tensor<Int32>) {
return state + 1
}
let addOneGraph = _graph(addOne)
expectEqual(addOneGraph(Tensor<Int32>(5)), Tensor<Int32>(6))
expectEqual(addOneGraph(Tensor<Int32>(0)), Tensor<Int32>(1))
expectEqual(addOneGraph(Tensor<Int32>(-1)), Tensor<Int32>(0))
}

TracerTests.testAllBackends("Basic_IntermediateTensors") {
func tracee(state: Tensor<Float>, data: Data) -> (Tensor<Float>, Result) {
// Create an intermediate tensor value, which the tracing infra needs to
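
A hypothetical variant of the new test above (not included in this diff) showing how the same tracing path could be exercised with the `useXla` flag enabled; the test name is illustrative only.

TracerTests.testAllBackends("TracerWithInOut_XLA") {
  func addOne(state: Tensor<Int32>) -> (Tensor<Int32>) {
    return state + 1
  }
  // Same trace as above, but with the `_XlaCompile` attribute set at execution.
  let addOneGraph = _graph(addOne, useXla: true)
  expectEqual(addOneGraph(Tensor<Int32>(5)), Tensor<Int32>(6))
}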