Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 4961c03

Browse files
committed
Some simple tests and cleanup.
1 parent 253898f commit 4961c03

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,17 +346,19 @@ extension LazyTensorOperation: TFTensorOperation {
346346
}
347347

348348
func execute() {
349-
// Just run it now.
349+
// If we want to stage this, we will need to add control dependencies.
350+
// For the time-being, just build a TFE_Op and run it.
351+
//
350352
let op = TFE_Op(name, outputCount)
351-
// TODO: Materialize en masse and not one-by-one.
353+
// TODO(https://bugs.swift.org/browse/TF-604):
354+
// Materialize inputs en masse and not one-by-one.
352355
for input in inputs {
353356
switch input {
354357
case .single(let v):
355358
op.addInput(v._tfeTensorHandle)
356-
case .list(let values): do {
357-
for v in values {
358-
op.addInput(v._tfeTensorHandle)
359-
}
359+
case .list(let values):
360+
for v in values {
361+
op.addInput(v._tfeTensorHandle)
360362
}
361363
}
362364
}
@@ -372,12 +374,12 @@ extension LazyTensorOperation: TFTensorOperation {
372374
case .floatArray(let v): op.updateAttribute(name, v)
373375
case .doubleArray(let v): op.updateAttribute(name, v)
374376
case .stringArray(let v): op.updateAttribute(name, v)
375-
case .constTensor(_): assert(false, "Const Tensor cannot be eager attribute.")
377+
case .constTensor(_): fatalError("Const Tensor cannot be eager attribute.")
376378
case .tensorDataTypeValue(let v): op.updateAttribute(name, v)
377379
case .tensorDataTypeArray(let v): op.updateAttribute(name, v)
378380
case .optionalTensorShape(let v): op.updateAttribute(name, v)
379381
case .optionalTensorShapeArray(let v): op.updateAttribute(name, v)
380-
case .tensorFunctionPointer(_): assert(false, "Unimplemented")
382+
case .tensorFunctionPointer(_): fatalError("tensorFunctionPointer Unimplemented!")
381383
}
382384
}
383385
op.execute()

Tests/TensorFlowTests/LazyTensorEvaluationTests.swift

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,22 @@ final class LazyTensorEvaluationTests: XCTestCase {
8686
XCTAssertTrue(isMaterialized(sum))
8787
}
8888

89+
func testNoOutputOperations() {
90+
let a = StringTensor("Hello ")
91+
let b = StringTensor("World")
92+
let c = Raw.add(a, b)
93+
// c is not materialized yet.
94+
XCTAssertFalse(isMaterialized(c.handle.handle))
95+
Raw.printV2(c)
96+
// c is materialized now as printV2 would be executed.
97+
XCTAssertTrue(isMaterialized(c.handle.handle))
98+
}
99+
89100
private func isMaterialized<T: TensorFlowScalar>(_ input: Tensor<T>) -> Bool {
90-
let tensor = input.handle.handle
101+
return isMaterialized(input.handle.handle)
102+
}
103+
104+
private func isMaterialized(_ tensor: _AnyTensorHandle) -> Bool {
91105
guard let lazyTensor = tensor as? LazyTensor else { return true }
92106
switch lazyTensor.handle {
93107
case .symbolic(let op, _, _): return op.outputs != nil
@@ -100,6 +114,7 @@ final class LazyTensorEvaluationTests: XCTestCase {
100114
("testMultipleMaterializations", testMultipleMaterializations),
101115
("testSimpleControlFlow", testSimpleControlFlow),
102116
("testSimpleLoop", testSimpleLoop),
117+
("testNoOutputOperations", testNoOutputOperations)
103118
]
104119
}
105120

0 commit comments

Comments
 (0)