Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 4291d11

Browse files
authored
Materialize inputs en masse when executing a 0-output lazy operation. (#406)
* Materialize inputs en-masse when executing a 0-output lazy operation. Related to https://bugs.swift.org/browse/TF-604
1 parent 5ffe1c3 commit 4291d11

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class LazyTensorHandle: _AnyTensorHandle {
8181
case .concrete(_): return nil
8282
}
8383
}
84-
84+
8585
// Liveness tracking for LazyTensorOperations
8686
//
8787
static func isLive(_ op: LazyTensorOperation) -> Bool {
@@ -374,9 +374,26 @@ extension LazyTensorOperation: TFTensorOperation {
374374
// If we want to stage this, we will need to add control dependencies.
375375
// For the time-being, just build a TFE_Op and run it.
376376
//
377+
// Collect all the unmaterialized inputs.
378+
var unmaterializedInputs = Array<LazyTensorOperation>()
379+
unmaterializedInputs.reserveCapacity(inputs.count)
380+
for input in inputs {
381+
switch input {
382+
case .single(let v):
383+
if let lazyOperation = v.lazyTensorOperation {
384+
unmaterializedInputs.append(lazyOperation)
385+
}
386+
case .list(let values):
387+
unmaterializedInputs.append(
388+
contentsOf: values.lazy.compactMap { $0.lazyTensorOperation }
389+
)
390+
}
391+
}
392+
// Materialize the inputs now.
393+
LazyTensorOperation.materialize(targets: unmaterializedInputs)
394+
395+
// Build the TFEOp and execute.
377396
let op = TFE_Op(name, outputCount)
378-
// TODO(https://bugs.swift.org/browse/TF-604):
379-
// Materialize inputs en masse and not one-by-one.
380397
for input in inputs {
381398
switch input {
382399
case .single(let v):
@@ -800,7 +817,7 @@ extension LazyTensorOperation {
800817
// Return materialized outputs if any.
801818
if let outputs = outputs { return outputs }
802819

803-
materializeLiveTensors()
820+
LazyTensorOperation.materialize(targets: [self])
804821

805822
// Our outputs should have been updated by now. Otherwise,
806823
// something terrible happened!
@@ -838,8 +855,8 @@ extension LazyTensorOperation {
838855
inputs = inputs.map { materializedAsNeeded(input: $0) }
839856
}
840857

841-
private func materializeLiveTensors() {
842-
let traceInfo = LazyTensorTraceBuilder.materializationTraceInfo(self)
858+
static func materialize(targets: [LazyTensorOperation]) {
859+
let traceInfo = LazyTensorTraceBuilder.materializationTraceInfo(targets)
843860
debugLog("Extracted trace:\n\(traceInfo.trace)")
844861

845862
let function = TFFunction(trace: traceInfo.trace)

0 commit comments

Comments
 (0)