Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3d816c9

Browse files
committed
Change _graph to lazy tensor.
1 parent a4d2336 commit 3d816c9

File tree

1 file changed

+9
-30
lines changed

1 file changed

+9
-30
lines changed

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -940,36 +940,15 @@ public func _graph<In: TensorGroup, Out: TensorGroup>(
940940
_ fn: (In) -> Out,
941941
useXLA: Bool = false
942942
) -> (In) -> Out {
943-
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
944-
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
945-
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(capacity: Int(inputs.count))
946-
var ptr = buffer
947-
for input in inputs {
948-
ptr.initialize(to: input)
949-
ptr = ptr.advanced(by: 1)
950-
}
951-
let symbolicIn = In(_owning: buffer)
952-
let symbolicOut = escapableFn(symbolicIn)
953-
return symbolicOut.cTensorHandles
954-
}
955-
let dtypes = In._typeList.map { $0._cDataType }
956-
return _trace(with: dtypes, in: wrappedFn)
957-
}
958-
// The result is a closure that captures and executes the trace graph function in the trace
959-
// context.
960-
return { (input: In) -> (Out) in
961-
debugLog("Running trace function over input \(input).")
962-
963-
debugLog("Getting input state tensor handles.")
964-
let inputStateTensorHandles = input.cTensorHandles
965-
let inputTensors = inputStateTensorHandles.map {
966-
TFETensorHandle(_owning: $0)
967-
}
968-
debugLog("Executing trace graph function.")
969-
let returnValues = traceContext.execute(traceeInputs: inputTensors, useXLA: useXLA)
970-
971-
debugLog("Creating output model instance.")
972-
return Out(_owning: returnValues)
943+
let useLazyTensor = _RuntimeConfig.useLazyTensor
944+
defer { _RuntimeConfig.useLazyTensor = useLazyTensor }
945+
_RuntimeConfig.useLazyTensor = true
946+
let trace = LazyTensorTraceBuilder.trace(fn)
947+
let tffunc = TFFunction(trace: trace)
948+
return {(input: In) -> Out in
949+
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
950+
let outputHandles = tffunc.execute(inputHandles, useXLA: useXLA)
951+
return Out(_handles: outputHandles)
973952
}
974953
}
975954

0 commit comments

Comments
 (0)