Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0e60587

Browse files
authored
Trigger execution of operations with no outputs. (#296)
1 parent 60f63e1 commit 0e60587

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,45 @@ extension LazyTensorOperation: TFTensorOperation {
345345
fatalError("Unimplemented [TFFunction] attribute.")
346346
}
347347

348-
func execute() {}
348+
func execute() {
349+
// If we want to stage this, we will need to add control dependencies.
350+
// For the time-being, just build a TFE_Op and run it.
351+
//
352+
let op = TFE_Op(name, outputCount)
353+
// TODO(https://bugs.swift.org/browse/TF-604):
354+
// Materialize inputs en masse and not one-by-one.
355+
for input in inputs {
356+
switch input {
357+
case .single(let v):
358+
op.addInput(v._tfeTensorHandle)
359+
case .list(let values):
360+
for v in values {
361+
op.addInput(v._tfeTensorHandle)
362+
}
363+
}
364+
}
365+
for (name, value) in attributes {
366+
switch value {
367+
case .boolValue(let v): op.updateAttribute(name, v)
368+
case .intValue(let v): op.updateAttribute(name, v)
369+
case .floatValue(let v): op.updateAttribute(name, v)
370+
case .doubleValue(let v): op.updateAttribute(name, v)
371+
case .stringValue(let v): op.updateAttribute(name, v)
372+
case .boolArray(let v): op.updateAttribute(name, v)
373+
case .intArray(let v): op.updateAttribute(name, v)
374+
case .floatArray(let v): op.updateAttribute(name, v)
375+
case .doubleArray(let v): op.updateAttribute(name, v)
376+
case .stringArray(let v): op.updateAttribute(name, v)
377+
case .constTensor(_): fatalError("Const Tensor cannot be eager attribute.")
378+
case .tensorDataTypeValue(let v): op.updateAttribute(name, v)
379+
case .tensorDataTypeArray(let v): op.updateAttribute(name, v)
380+
case .optionalTensorShape(let v): op.updateAttribute(name, v)
381+
case .optionalTensorShapeArray(let v): op.updateAttribute(name, v)
382+
case .tensorFunctionPointer(_): fatalError("tensorFunctionPointer Unimplemented!")
383+
}
384+
}
385+
op.execute()
386+
}
349387

350388
func execute<T0: TensorArrayProtocol>(
351389
_ count0: Int

Tests/TensorFlowTests/LazyTensorEvaluationTests.swift

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,44 @@ final class LazyTensorEvaluationTests: XCTestCase {
8686
XCTAssertTrue(isMaterialized(sum))
8787
}
8888

89+
struct SimpleOutput: TensorGroup {
90+
let a: TensorHandle<Int32>
91+
let b: TensorHandle<Int32>
92+
}
93+
94+
func testNoOutputOperations() {
95+
let elements1: Tensor<Int32> = [0, 1, 2]
96+
let elements2: Tensor<Int32> = [10, 11, 12]
97+
let outputTypes = [Int32.tensorFlowDataType, Int32.tensorFlowDataType]
98+
let outputShapes: [TensorShape?] = [nil, nil]
99+
let dataset: VariantHandle = Raw.tensorSliceDataset(
100+
components: [elements1, elements2],
101+
outputShapes: outputShapes
102+
)
103+
let iterator: ResourceHandle = Raw.iteratorV2(sharedName: "blah",
104+
container: "earth", outputTypes: outputTypes, outputShapes: outputShapes
105+
)
106+
// `dataset` and `iterator` should not be materialized yet.
107+
XCTAssertFalse(isMaterialized(dataset.handle))
108+
XCTAssertFalse(isMaterialized(iterator.handle))
109+
Raw.makeIterator(dataset: dataset, iterator: iterator)
110+
111+
// `dataset` and `iterator` should be materialized now as
112+
// makeIterator executes.
113+
XCTAssertTrue(isMaterialized(dataset.handle))
114+
XCTAssertTrue(isMaterialized(iterator.handle))
115+
let next: SimpleOutput = Raw.iteratorGetNext(
116+
iterator: iterator, outputShapes: outputShapes
117+
)
118+
XCTAssertEqual(Tensor(handle: next.a).scalarized(), 0)
119+
XCTAssertEqual(Tensor(handle: next.b).scalarized(), 10)
120+
}
121+
89122
private func isMaterialized<T: TensorFlowScalar>(_ input: Tensor<T>) -> Bool {
90-
let tensor = input.handle.handle
123+
return isMaterialized(input.handle.handle)
124+
}
125+
126+
private func isMaterialized(_ tensor: _AnyTensorHandle) -> Bool {
91127
guard let lazyTensor = tensor as? LazyTensor else { return true }
92128
switch lazyTensor.handle {
93129
case .symbolic(let op, _, _): return op.outputs != nil
@@ -100,6 +136,7 @@ final class LazyTensorEvaluationTests: XCTestCase {
100136
("testMultipleMaterializations", testMultipleMaterializations),
101137
("testSimpleControlFlow", testSimpleControlFlow),
102138
("testSimpleLoop", testSimpleLoop),
139+
("testNoOutputOperations", testNoOutputOperations)
103140
]
104141
}
105142

0 commit comments

Comments
 (0)