Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 50fbabf

Browse files
committed
Added tests for explicit tracing.
1 parent fdc0c45 commit 50fbabf

File tree

3 files changed

+81
-14
lines changed

3 files changed

+81
-14
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import TensorFlow
18+
import CTensorFlow
19+
20+
final class LazyTensorExplicitTraceTests: XCTestCase {
21+
override class func setUp() {
22+
super.setUp()
23+
_RuntimeConfig.useLazyTensor = true
24+
}
25+
26+
override class func tearDown() {
27+
super.tearDown()
28+
_RuntimeConfig.useLazyTensor = false
29+
}
30+
31+
func testSingleInput() {
32+
func fn(x: Tensor<Float>) -> Tensor<Float> { return x + x }
33+
let trace = LazyTensorTraceBuilder.trace(fn)
34+
XCTAssertEqual(trace.description,
35+
"""
36+
lazyTrace_2(%0: float) -> (%1) {
37+
%1 = Add[T: float](%0, %0)
38+
}
39+
""")
40+
let outputs = runTrace(trace: trace, input: Tensor<Float>(10.0))
41+
XCTAssertEqual(outputs.count, 1)
42+
XCTAssertEqual(outputs[0].valueDescription, "20.0")
43+
}
44+
45+
func testTensorGroupInputOutputs() {
46+
typealias TensorFloatInt32Pair = Zip2TensorGroup<Tensor<Float>, Tensor<Int32>>
47+
typealias TensorInt32FloatPair = Zip2TensorGroup<Tensor<Int32>, Tensor<Float>>
48+
func fn(input: TensorFloatInt32Pair) -> TensorInt32FloatPair {
49+
return TensorInt32FloatPair(input.second * 4, input.first + 3.0)
50+
}
51+
let trace = LazyTensorTraceBuilder.trace(fn)
52+
XCTAssertEqual(trace.description,
53+
"""
54+
lazyTrace_6(%0: float, %1: int32) -> (%3, %5) {
55+
%2 = Const[dtype: int32, value: 4]()
56+
%3 = Mul[T: int32](%1, %2)
57+
%4 = Const[dtype: float, value: 3.0]()
58+
%5 = Add[T: float](%0, %4)
59+
}
60+
""")
61+
let outputs = runTrace(
62+
trace: trace,
63+
input: TensorFloatInt32Pair(Tensor<Float>(10.0), Tensor<Int32>(5)))
64+
XCTAssertEqual(outputs.count, 2)
65+
XCTAssertEqual(outputs[0].valueDescription, "20")
66+
XCTAssertEqual(outputs[1].valueDescription, "13.0")
67+
}
68+
69+
private func runTrace(trace: LazyTensorTrace, input: TensorGroup) -> [TFETensorHandle] {
70+
let tffunc = TFFunction(trace: trace)
71+
let inputHandles = input._tensorHandles.map { $0._tfeTensorHandle }
72+
let outputHandles = tffunc.execute(inputHandles)
73+
return outputHandles
74+
}
75+
76+
static var allTests = [
77+
("testSingleInput", testSingleInput),
78+
("testTensorGroupInputOutputs", testTensorGroupInputOutputs)
79+
]
80+
}

Tests/TensorFlowTests/LazyTensorTraceTests.swift

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -206,19 +206,6 @@ final class LazyTensorTraceTests: XCTestCase {
206206
XCTAssertEqual(z.scalarized(), 9.0)
207207
}
208208

209-
func testTracing() {
210-
func fn(x: Tensor<Float>) -> Tensor<Float>{
211-
return x + x
212-
}
213-
let trace = LazyTensorTraceBuilder.trace(fn)
214-
XCTAssertEqual(trace.description,
215-
"""
216-
lazyTrace_2(%0: float) -> (%1) {
217-
%1 = Add[T: float](%0, %0)
218-
}
219-
""")
220-
}
221-
222209
private func lazyTensorOperation<T: TensorFlowScalar>(
223210
_ input: Tensor<T>
224211
) -> LazyTensorOperation? {
@@ -250,6 +237,5 @@ final class LazyTensorTraceTests: XCTestCase {
250237
("testSimpleControlFlow", testSimpleControlFlow),
251238
("testManualConstPromotion", testManualConstPromotion),
252239
("testConstPromotion", testConstPromotion),
253-
("testTracing", testTracing)
254240
]
255241
}

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public func allTests() -> [XCTestCaseEntry] {
3131
testCase(MathOperatorTests.allTests),
3232
testCase(LazyTensorTests.allTests),
3333
testCase(LazyTensorTraceTests.allTests),
34+
testCase(LazyTensorExplicitTraceTests.allTests),
3435
testCase(LazyTensorOperationTests.allTests),
3536
testCase(LazyTensorTFFunctionBuilderTests.allTests),
3637
testCase(LazyTensorEvaluationTests.allTests),

0 commit comments

Comments
 (0)