@@ -81,7 +81,7 @@ class LazyTensorHandle: _AnyTensorHandle {
81
81
case . concrete( _) : return nil
82
82
}
83
83
}
84
-
84
+
85
85
// Liveness tracking for LazyTensorOperations
86
86
//
87
87
static func isLive( _ op: LazyTensorOperation ) -> Bool {
@@ -374,9 +374,26 @@ extension LazyTensorOperation: TFTensorOperation {
374
374
// If we want to stage this, we will need to add control dependencies.
375
375
// For the time-being, just build a TFE_Op and run it.
376
376
//
377
+ // Collect all the unmaterialized inputs.
378
+ var unmaterializedInputs = Array < LazyTensorOperation > ( )
379
+ unmaterializedInputs. reserveCapacity ( inputs. count)
380
+ for input in inputs {
381
+ switch input {
382
+ case . single( let v) :
383
+ if let lazyOperation = v. lazyTensorOperation {
384
+ unmaterializedInputs. append ( lazyOperation)
385
+ }
386
+ case . list( let values) :
387
+ unmaterializedInputs. append (
388
+ contentsOf: values. lazy. compactMap { $0. lazyTensorOperation }
389
+ )
390
+ }
391
+ }
392
+ // Materialize the inputs now.
393
+ LazyTensorOperation . materialize ( targets: unmaterializedInputs)
394
+
395
+ // Build the TFEOp and execute.
377
396
let op = TFE_Op ( name, outputCount)
378
- // TODO(https://bugs.swift.org/browse/TF-604):
379
- // Materialize inputs en masse and not one-by-one.
380
397
for input in inputs {
381
398
switch input {
382
399
case . single( let v) :
@@ -800,7 +817,7 @@ extension LazyTensorOperation {
800
817
// Return materialized outputs if any.
801
818
if let outputs = outputs { return outputs }
802
819
803
- materializeLiveTensors ( )
820
+ LazyTensorOperation . materialize ( targets : [ self ] )
804
821
805
822
// Our outputs should have been updated by now. Otherwise,
806
823
// something terrible happened!
@@ -838,8 +855,8 @@ extension LazyTensorOperation {
838
855
inputs = inputs. map { materializedAsNeeded ( input: $0) }
839
856
}
840
857
841
- private func materializeLiveTensors ( ) {
842
- let traceInfo = LazyTensorTraceBuilder . materializationTraceInfo ( self )
858
+ static func materialize ( targets : [ LazyTensorOperation ] ) {
859
+ let traceInfo = LazyTensorTraceBuilder . materializationTraceInfo ( targets )
843
860
debugLog ( " Extracted trace: \n \( traceInfo. trace) " )
844
861
845
862
let function = TFFunction ( trace: traceInfo. trace)
0 commit comments