Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 506d89d

Browse files
committed
Rewrite _tffunc using LazyTensor.
1 parent ef48ae9 commit 506d89d

File tree

2 files changed

+7
-17
lines changed

2 files changed

+7
-17
lines changed

Sources/TensorFlow/Core/LazyTensorTFFunctionBuilder.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ class TFFunction {
221221
let cTFFunction: CTFFunction
222222
let outputCount: Int
223223
let outputGroupCounts: [Int]
224+
var name: String { String(cString: TF_FunctionName(cTFFunction)!) }
224225

225226
init(trace: LazyTensorTrace, name: String? = nil) {
226227
let status: CTFStatus = TF_NewStatus()

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -976,23 +976,12 @@ public func _graph<In: TensorGroup, Out: TensorGroup>(
976976
/// Trace the given function and return the name of the corresponding `TF_Function: In -> Out` that
977977
/// was created.
978978
public func _tffunc<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> String {
979-
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
980-
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
981-
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(capacity: Int(inputs.count))
982-
var ptr = buffer
983-
for input in inputs {
984-
ptr.initialize(to: input)
985-
ptr = ptr.advanced(by: 1)
986-
}
987-
let symbolicIn = In(_owning: buffer)
988-
let symbolicOut = escapableFn(symbolicIn)
989-
return symbolicOut.cTensorHandles
990-
}
991-
992-
let dtypes = In._typeList.map { $0._cDataType }
993-
return _trace(with: dtypes, in: wrappedFn)
994-
}
995-
return traceContext.specializeTFFunction(with: [])
979+
let useLazyTensor = _RuntimeConfig.useLazyTensor
980+
defer { _RuntimeConfig.useLazyTensor = useLazyTensor }
981+
_RuntimeConfig.useLazyTensor = true
982+
let trace = LazyTensorTraceBuilder.trace(fn)
983+
let tffunc = TFFunction(trace: trace)
984+
return tffunc.name
996985
}
997986

998987
internal extension _ExecutionContext {

0 commit comments

Comments
 (0)