Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f4d32f7

Browse files
committed
Add a test for nested tracing.
1 parent ad758dc commit f4d32f7

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,37 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
103103
%2 = Mul[T: float](%0, %1)
104104
}
105105
""")
106-
let outputs = runTrace(
107-
trace: trace,
108-
input: Tensor<Float>(23.0))
106+
let outputs = runTrace(trace: trace, input: Tensor<Float>(23.0))
109107
XCTAssertEqual(outputs.count, 1)
110108
XCTAssertEqual(outputs[0].valueDescription, "115.0")
111109
}
112110

111+
func testNestedTracing() {
112+
func square(input: Tensor<Float>) -> Tensor<Float> {
113+
return input * input
114+
}
115+
116+
func nestedTrace(input: Tensor<Float>) -> Tensor<Float> {
117+
let trace = LazyTensorTraceBuilder.trace(square)
118+
let outputs = runTrace(trace: trace, input: Tensor<Float>(3.0))
119+
XCTAssertEqual(outputs.count, 1)
120+
let handle = TensorHandle<Float>(handle: outputs[0])
121+
let y = Tensor<Float>(handle: handle)
122+
return y + input
123+
}
124+
125+
let trace = LazyTensorTraceBuilder.trace(nestedTrace)
126+
XCTAssertEqual(trace.description,
127+
"""
128+
lazyTrace_3(%0: float) -> (%2) {
129+
%1 = Const[dtype: float, value: 9.0]()
130+
%2 = Add[T: float](%1, %0)
131+
}
132+
""")
133+
let outputs = runTrace(trace: trace, input: Tensor<Float>(4.0))
134+
XCTAssertEqual(outputs.count, 1)
135+
XCTAssertEqual(outputs[0].valueDescription, "13.0")
136+
}
113137

114138
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
115139
let tffunc = TFFunction(trace: trace)
@@ -122,6 +146,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
122146
("testSingleInput", testSingleInput),
123147
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
124148
("testClosureCapturesOfTensors", testClosureCapturesOfTensors),
125-
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors)
149+
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors),
150+
("testNestedTracing", testNestedTracing)
126151
]
127152
}

0 commit comments

Comments
 (0)