@@ -934,17 +934,21 @@ public func _tffunc<State: _TensorArrayProtocolEnhanced, Data: TensorGroup>(
934
934
}
935
935
}
936
936
937
+ internal func _trace< In: TensorGroup , Out: TensorGroup > ( _ fn: ( In ) -> Out ) -> TFFunction {
938
+ let useLazyTensor = _RuntimeConfig. useLazyTensor
939
+ defer { _RuntimeConfig. useLazyTensor = useLazyTensor }
940
+ _RuntimeConfig. useLazyTensor = true
941
+ let trace = LazyTensorTraceBuilder . trace ( fn)
942
+ return TFFunction ( trace: trace)
943
+ }
944
+
937
945
// Trace the given function to generate a TF graph and return a closure that can be used to launch
938
946
// the graph.
939
947
public func _graph< In: TensorGroup , Out: TensorGroup > (
940
948
_ fn: ( In ) -> Out ,
941
949
useXLA: Bool = false
942
950
) -> ( In ) -> Out {
943
- let useLazyTensor = _RuntimeConfig. useLazyTensor
944
- defer { _RuntimeConfig. useLazyTensor = useLazyTensor }
945
- _RuntimeConfig. useLazyTensor = true
946
- let trace = LazyTensorTraceBuilder . trace ( fn)
947
- let tffunc = TFFunction ( trace: trace)
951
+ let tffunc = _trace ( fn)
948
952
return { ( input: In ) -> Out in
949
953
let inputHandles = input. _tensorHandles. map { $0. _tfeTensorHandle }
950
954
let outputHandles = tffunc. execute ( inputHandles, useXLA: useXLA)
@@ -955,11 +959,7 @@ public func _graph<In: TensorGroup, Out: TensorGroup>(
955
959
/// Trace the given function and return the name of the corresponding `TF_Function: In -> Out` that
956
960
/// was created.
957
961
public func _tffunc< In: TensorGroup , Out: TensorGroup > ( _ fn: ( In ) -> Out ) -> String {
958
- let useLazyTensor = _RuntimeConfig. useLazyTensor
959
- defer { _RuntimeConfig. useLazyTensor = useLazyTensor }
960
- _RuntimeConfig. useLazyTensor = true
961
- let trace = LazyTensorTraceBuilder . trace ( fn)
962
- let tffunc = TFFunction ( trace: trace)
962
+ let tffunc = _trace ( fn)
963
963
return tffunc. name
964
964
}
965
965
0 commit comments