Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 253898f

Browse files
committed
no-output operations to be eagerly executed in lazy tensor operation.
operation attrs -> attributes execute no output camelcase.
1 parent 76a69b7 commit 253898f

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,43 @@ extension LazyTensorOperation: TFTensorOperation {
345345
fatalError("Unimplemented [TFFunction] attribute.")
346346
}
347347

348-
func execute() {}
348+
func execute() {
349+
// Just run it now.
350+
let op = TFE_Op(name, outputCount)
351+
// TODO: Materialize en masse and not one-by-one.
352+
for input in inputs {
353+
switch input {
354+
case .single(let v):
355+
op.addInput(v._tfeTensorHandle)
356+
case .list(let values): do {
357+
for v in values {
358+
op.addInput(v._tfeTensorHandle)
359+
}
360+
}
361+
}
362+
}
363+
for (name, value) in attributes {
364+
switch value {
365+
case .boolValue(let v): op.updateAttribute(name, v)
366+
case .intValue(let v): op.updateAttribute(name, v)
367+
case .floatValue(let v): op.updateAttribute(name, v)
368+
case .doubleValue(let v): op.updateAttribute(name, v)
369+
case .stringValue(let v): op.updateAttribute(name, v)
370+
case .boolArray(let v): op.updateAttribute(name, v)
371+
case .intArray(let v): op.updateAttribute(name, v)
372+
case .floatArray(let v): op.updateAttribute(name, v)
373+
case .doubleArray(let v): op.updateAttribute(name, v)
374+
case .stringArray(let v): op.updateAttribute(name, v)
375+
case .constTensor(_): assert(false, "Const Tensor cannot be eager attribute.")
376+
case .tensorDataTypeValue(let v): op.updateAttribute(name, v)
377+
case .tensorDataTypeArray(let v): op.updateAttribute(name, v)
378+
case .optionalTensorShape(let v): op.updateAttribute(name, v)
379+
case .optionalTensorShapeArray(let v): op.updateAttribute(name, v)
380+
case .tensorFunctionPointer(_): assert(false, "Unimplemented")
381+
}
382+
}
383+
op.execute()
384+
}
349385

350386
func execute<T0: TensorArrayProtocol>(
351387
_ count0: Int

0 commit comments

Comments
 (0)