Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1d87344

Browse files
committed
Refactor common code between _tffunc and _graph.
1 parent 0bffffc commit 1d87344

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -934,17 +934,21 @@ public func _tffunc<State: _TensorArrayProtocolEnhanced, Data: TensorGroup>(
934934
}
935935
}
936936

937+
internal func _trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> TFFunction {
938+
let useLazyTensor = _RuntimeConfig.useLazyTensor
939+
defer { _RuntimeConfig.useLazyTensor = useLazyTensor }
940+
_RuntimeConfig.useLazyTensor = true
941+
let trace = LazyTensorTraceBuilder.trace(fn)
942+
return TFFunction(trace: trace)
943+
}
944+
937945
// Trace the given function to generate a TF graph and return a closure that can be used to launch
938946
// the graph.
939947
public func _graph<In: TensorGroup, Out: TensorGroup>(
940948
_ fn: (In) -> Out,
941949
useXLA: Bool = false
942950
) -> (In) -> Out {
943-
let useLazyTensor = _RuntimeConfig.useLazyTensor
944-
defer { _RuntimeConfig.useLazyTensor = useLazyTensor }
945-
_RuntimeConfig.useLazyTensor = true
946-
let trace = LazyTensorTraceBuilder.trace(fn)
947-
let tffunc = TFFunction(trace: trace)
951+
let tffunc = _trace(fn)
948952
return {(input: In) -> Out in
949953
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
950954
let outputHandles = tffunc.execute(inputHandles, useXLA: useXLA)
@@ -955,11 +959,7 @@ public func _graph<In: TensorGroup, Out: TensorGroup>(
955959
/// Trace the given function and return the name of the corresponding `TF_Function: In -> Out` that
956960
/// was created.
957961
public func _tffunc<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> String {
958-
let useLazyTensor = _RuntimeConfig.useLazyTensor
959-
defer { _RuntimeConfig.useLazyTensor = useLazyTensor }
960-
_RuntimeConfig.useLazyTensor = true
961-
let trace = LazyTensorTraceBuilder.trace(fn)
962-
let tffunc = TFFunction(trace: trace)
962+
let tffunc = _trace(fn)
963963
return tffunc.name
964964
}
965965

0 commit comments

Comments
 (0)