@@ -219,7 +219,7 @@ private class TraceContext {
219
219
/// Execute the trace graph function, and return the list of output tensors
220
220
/// from the trace execution. These output tensors are owned by the caller.
221
221
func execute(
222
- traceeInputs: [ _AnyTensorHandle ] , useXla : Bool = false ) -> [ CTensorHandle ] {
222
+ traceeInputs: [ _AnyTensorHandle ] , useXLA : Bool = false ) -> [ CTensorHandle ] {
223
223
// We must be in the `notTracing` enum mode.
224
224
internalConsistencyCheck ( _RuntimeConfig. traceState. context == nil )
225
225
internalConsistencyCheck ( traceGraphFn != nil )
@@ -237,7 +237,7 @@ private class TraceContext {
237
237
checkOk ( status)
238
238
}
239
239
240
- if useXla {
240
+ if useXLA {
241
241
debugLog ( " Enabling XLA compilation " )
242
242
TFE_OpSetAttrBool ( op, " _XlaCompile " , 1 )
243
243
}
@@ -1002,11 +1002,10 @@ public func _tffunc<State : _TensorArrayProtocolEnhanced,
1002
1002
// Trace the given function to generate a TF graph and return a closure
1003
1003
// that can be used to launch the graph.
1004
1004
public func _graph< In : TensorGroup , Out : TensorGroup > (
1005
- _ fn: ( In ) -> Out , useXla : Bool = false
1005
+ _ fn: ( In ) -> Out , useXLA : Bool = false
1006
1006
) -> ( In ) -> Out {
1007
1007
let traceContext : TraceContext = withoutActuallyEscaping ( fn) { escapableFn in
1008
- let wrappedFn = {
1009
- ( inputs: [ CTensorHandle ] ) -> [ CTensorHandle ] in
1008
+ let wrappedFn = { ( inputs: [ CTensorHandle ] ) -> [ CTensorHandle ] in
1010
1009
let buffer = UnsafeMutablePointer< CTensorHandle> . allocate(
1011
1010
capacity: Int ( inputs. count) )
1012
1011
var ptr = buffer
@@ -1033,7 +1032,7 @@ public func _graph<In : TensorGroup, Out : TensorGroup>(
1033
1032
}
1034
1033
debugLog ( " Executing trace graph function. " )
1035
1034
let returnValues = traceContext. execute (
1036
- traceeInputs: inputTensors, useXla : useXla )
1035
+ traceeInputs: inputTensors, useXLA : useXLA )
1037
1036
1038
1037
debugLog ( " Creating output model instance. " )
1039
1038
return Out ( _copying: returnValues)
0 commit comments