@@ -218,7 +218,8 @@ private class TraceContext {
218
218
219
219
/// Execute the trace graph function, and return the list of output tensors
220
220
/// from the trace execution. These output tensors are owned by the caller.
221
- func execute( traceeInputs: [ _AnyTensorHandle ] ) -> [ CTensorHandle ] {
221
+ func execute(
222
+ traceeInputs: [ _AnyTensorHandle ] , useXla: Bool = false ) -> [ CTensorHandle ] {
222
223
// We must be in the `notTracing` enum mode.
223
224
internalConsistencyCheck ( _RuntimeConfig. traceState. context == nil )
224
225
internalConsistencyCheck ( traceGraphFn != nil )
@@ -236,6 +237,11 @@ private class TraceContext {
236
237
checkOk ( status)
237
238
}
238
239
240
+ if useXla {
241
+ debugLog ( " Enabling XLA compilation " )
242
+ TFE_OpSetAttrBool ( op, " _XlaCompile " , 1 )
243
+ }
244
+
239
245
debugLog ( " Adding \( traceeInputs. count) tracee input tensors. " )
240
246
internalConsistencyCheck ( symbolicInputs. count == traceeInputs. count
241
247
+ Int( additionalInputTensorCount) )
@@ -993,6 +999,47 @@ public func _tffunc<State : _TensorArrayProtocolEnhanced,
993
999
}
994
1000
}
995
1001
1002
+ // Trace the given function to generate a TF graph and return a closure
1003
+ // that can be used to launch the graph.
1004
+ public func _graph< In : TensorGroup , Out : TensorGroup > (
1005
+ _ fn: ( In ) -> Out , useXla: Bool = false
1006
+ ) -> ( In ) -> Out {
1007
+ let traceContext : TraceContext = withoutActuallyEscaping ( fn) { escapableFn in
1008
+ let wrappedFn = {
1009
+ ( inputs: [ CTensorHandle ] ) -> [ CTensorHandle ] in
1010
+ let buffer = UnsafeMutablePointer< CTensorHandle> . allocate(
1011
+ capacity: Int ( inputs. count) )
1012
+ var ptr = buffer
1013
+ for input in inputs {
1014
+ ptr. initialize ( to: input)
1015
+ ptr = ptr. advanced ( by: 1 )
1016
+ }
1017
+ let symbolicIn = In ( _owning: buffer)
1018
+ let symbolicOut = escapableFn ( symbolicIn)
1019
+ return symbolicOut. cTensorHandles
1020
+ }
1021
+ let dtypes = In . _typeList. map { $0. _cDataType }
1022
+ return _trace ( with: dtypes, in: wrappedFn)
1023
+ }
1024
+ // The result is a closure that captures and executes the trace graph
1025
+ // function in the trace context.
1026
+ return { ( input: In ) -> ( Out ) in
1027
+ debugLog ( " Running trace function over input \( input) . " )
1028
+
1029
+ debugLog ( " Getting input state tensor handles. " )
1030
+ let inputStateTensorHandles = input. cTensorHandles
1031
+ let inputTensors = inputStateTensorHandles. map {
1032
+ _TFCCreateTensorHandleFromC ( $0)
1033
+ }
1034
+ debugLog ( " Executing trace graph function. " )
1035
+ let returnValues = traceContext. execute (
1036
+ traceeInputs: inputTensors, useXla: useXla)
1037
+
1038
+ debugLog ( " Creating output model instance. " )
1039
+ return Out ( _copying: returnValues)
1040
+ }
1041
+ }
1042
+
996
1043
/// Trace the given function and return the name of the corresponding
997
1044
/// `TF_Function: In -> Out` that was created.
998
1045
public func _tffunc< In : TensorGroup , Out : TensorGroup > (
0 commit comments