Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit acdc7af

Browse files
committed
Add test that catpures non-tensors.
1 parent 65b1a5e commit acdc7af

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
6666
XCTAssertEqual(outputs[1].valueDescription, "13.0")
6767
}
6868

69-
func testClosureCaptures() {
69+
func testClosureCapturesOfTensors() {
7070
let x = Tensor<Float>(10.0)
7171
let y = x + x
7272
func fn(input: Tensor<Float>) -> Tensor<Float> {
@@ -89,6 +89,28 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
8989
XCTAssertEqual(outputs[0].valueDescription, "100.0")
9090
}
9191

92+
func testClosureCapturesOfNonTensors() {
93+
let x: Float = 5.0
94+
func fn(input: Tensor<Float>) -> Tensor<Float> {
95+
return input * Tensor<Float>(x)
96+
}
97+
let trace = LazyTensorTraceBuilder.trace(fn)
98+
/// Note that the computation x + x is encoded in the trace.
99+
XCTAssertEqual(trace.description,
100+
"""
101+
lazyTrace_3(%0: float) -> (%2) {
102+
%1 = Const[dtype: float, value: 5.0]()
103+
%2 = Mul[T: float](%0, %1)
104+
}
105+
""")
106+
let outputs = runTrace(
107+
trace: trace,
108+
input: Tensor<Float>(23.0))
109+
XCTAssertEqual(outputs.count, 1)
110+
XCTAssertEqual(outputs[0].valueDescription, "115.0")
111+
}
112+
113+
92114
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
93115
let tffunc = TFFunction(trace: trace)
94116
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
@@ -99,6 +121,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
99121
static var allTests = [
100122
("testSingleInput", testSingleInput),
101123
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
102-
("testClosureCaptures", testClosureCaptures)
124+
("testClosureCapturesOfTensors", testClosureCapturesOfTensors),
125+
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors)
103126
]
104127
}

0 commit comments

Comments
 (0)