Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0bffffc

Browse files
committed
Add testCallableTrace test.
1 parent cc8aca8 commit 0bffffc

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
135135
XCTAssertEqual(outputs[0].valueDescription, "13.0")
136136
}
137137

138+
func testCallableTrace() {
139+
func square(input: Tensor<Float>) -> Tensor<Float> {
140+
return input * input
141+
}
142+
let tracedSquare = _graph(square)
143+
XCTAssertEqual(tracedSquare(Tensor<Float>(10.0)).scalarized(), 100.0)
144+
XCTAssertEqual(tracedSquare(Tensor<Float>(5.0)).scalarized(), 25.0)
145+
}
146+
138147
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
139148
let tffunc = TFFunction(trace: trace)
140149
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
@@ -147,6 +156,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
147156
("testTensorGroupInputOutputs", testTensorGroupInputOutputs),
148157
("testClosureCapturesOfTensors", testClosureCapturesOfTensors),
149158
("testClosureCapturesOfNonTensors", testClosureCapturesOfNonTensors),
150-
("testNestedTracing", testNestedTracing)
159+
("testNestedTracing", testNestedTracing),
160+
("testCallableTrace", testCallableTrace)
151161
]
152162
}

0 commit comments

Comments
 (0)