Skip to content

Commit b71fb98

Browse files
authored
Add an api to trace an function and run it with XLA compilation. (#23868)
1 parent 6deffd0 commit b71fb98

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ private class TraceContext {
218218

219219
/// Execute the trace graph function, and return the list of output tensors
220220
/// from the trace execution. These output tensors are owned by the caller.
221-
func execute(traceeInputs: [_AnyTensorHandle]) -> [CTensorHandle] {
221+
func execute(
222+
traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] {
222223
// We must be in the `notTracing` enum mode.
223224
internalConsistencyCheck(_RuntimeConfig.traceState.context == nil)
224225
internalConsistencyCheck(traceGraphFn != nil)
@@ -236,6 +237,11 @@ private class TraceContext {
236237
checkOk(status)
237238
}
238239

240+
if useXla {
241+
debugLog("Enabling XLA compilation")
242+
TFE_OpSetAttrBool(op, "_XlaCompile", 1)
243+
}
244+
239245
debugLog("Adding \(traceeInputs.count) tracee input tensors.")
240246
internalConsistencyCheck(symbolicInputs.count == traceeInputs.count
241247
+ Int(additionalInputTensorCount))
@@ -993,6 +999,47 @@ public func _tffunc<State : _TensorArrayProtocolEnhanced,
993999
}
9941000
}
9951001

1002+
// Trace the given function to generate a TF graph and return a closure
1003+
// that can be used to launch the graph.
1004+
public func _graph<In : TensorGroup, Out : TensorGroup>(
1005+
_ fn: (In) -> Out, useXla: Bool = false
1006+
) -> (In) -> Out {
1007+
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
1008+
let wrappedFn = {
1009+
(inputs: [CTensorHandle]) -> [CTensorHandle] in
1010+
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(
1011+
capacity: Int(inputs.count))
1012+
var ptr = buffer
1013+
for input in inputs {
1014+
ptr.initialize(to: input)
1015+
ptr = ptr.advanced(by: 1)
1016+
}
1017+
let symbolicIn = In(_owning: buffer)
1018+
let symbolicOut = escapableFn(symbolicIn)
1019+
return symbolicOut.cTensorHandles
1020+
}
1021+
let dtypes = In._typeList.map { $0._cDataType }
1022+
return _trace(with: dtypes, in: wrappedFn)
1023+
}
1024+
// The result is a closure that captures and executes the trace graph
1025+
// function in the trace context.
1026+
return { (input: In) -> (Out) in
1027+
debugLog("Running trace function over input \(input).")
1028+
1029+
debugLog("Getting input state tensor handles.")
1030+
let inputStateTensorHandles = input.cTensorHandles
1031+
let inputTensors = inputStateTensorHandles.map {
1032+
_TFCCreateTensorHandleFromC($0)
1033+
}
1034+
debugLog("Executing trace graph function.")
1035+
let returnValues = traceContext.execute(
1036+
traceeInputs: inputTensors, useXla: useXla)
1037+
1038+
debugLog("Creating output model instance.")
1039+
return Out(_copying: returnValues)
1040+
}
1041+
}
1042+
9961043
/// Trace the given function and return the name of the corresponding
9971044
/// `TF_Function: In -> Out` that was created.
9981045
public func _tffunc<In : TensorGroup, Out : TensorGroup>(

test/TensorFlowRuntime/tracer.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ TracerTests.testAllBackends("TraceWithNoResult") {
129129
expectNearlyEqualWithScalarTensor(8.0, tracedAdd(Tensor<Float>(5.0), three))
130130
}
131131

132+
TracerTests.testAllBackends("TracerWithInOut") {
133+
func addOne(state: Tensor<Int32>) -> (Tensor<Int32>) {
134+
return state + 1
135+
}
136+
let addOneGraph = _graph(addOne)
137+
expectEqual(addOneGraph(Tensor<Int32>(5)), Tensor<Int32>(6))
138+
expectEqual(addOneGraph(Tensor<Int32>(0)), Tensor<Int32>(1))
139+
expectEqual(addOneGraph(Tensor<Int32>(-1)), Tensor<Int32>(0))
140+
}
141+
132142
TracerTests.testAllBackends("Basic_IntermediateTensors") {
133143
func tracee(state: Tensor<Float>, data: Data) -> (Tensor<Float>, Result) {
134144
// Create an intermediate tensor value, which the tracing infra needs to

0 commit comments

Comments
 (0)