Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 65b1a5e

Browse files
committed
Add test for closure captures.
1 parent 50fbabf commit 65b1a5e

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,29 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
6666
XCTAssertEqual(outputs[1].valueDescription, "13.0")
6767
}
6868

69+
func testClosureCaptures() {
70+
let x = Tensor<Float>(10.0)
71+
let y = x + x
72+
func fn(input: Tensor<Float>) -> Tensor<Float> {
73+
return input * y
74+
}
75+
let trace = LazyTensorTraceBuilder.trace(fn)
76+
/// Note that the computation x + x is encoded in the trace.
77+
XCTAssertEqual(trace.description,
78+
"""
79+
lazyTrace_4(%0: float) -> (%3) {
80+
%1 = Const[dtype: float, value: 10.0]()
81+
%2 = Add[T: float](%1, %1)
82+
%3 = Mul[T: float](%0, %2)
83+
}
84+
""")
85+
let outputs = runTrace(
86+
trace: trace,
87+
input: Tensor<Float>(5.0))
88+
XCTAssertEqual(outputs.count, 1)
89+
XCTAssertEqual(outputs[0].valueDescription, "100.0")
90+
}
91+
6992
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
7093
let tffunc = TFFunction(trace: trace)
7194
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
@@ -75,6 +98,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
7598

7699
static var allTests = [
77100
("testSingleInput", testSingleInput),
78-
("testTensorGroupInputOutputs", testTensorGroupInputOutputs)
101+
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
102+
("testClosureCaptures", testClosureCaptures)
79103
]
80104
}

0 commit comments

Comments
 (0)