Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit bd42b53

Browse files
committed
Add a simple test.
1 parent 9af0d99 commit bd42b53

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import TensorFlow
17+
import CTensorFlow
18+
19+
extension LazyTensorOperation {
20+
/// Returns true if the outputs have been materialized.
21+
var isMaterialized: Bool { outputs != nil }
22+
}
23+
24+
final class LazyTensorShapeInferenceTests: XCTestCase {
25+
override class func setUp() {
26+
super.setUp()
27+
_ThreadLocalState.useLazyTensor = true
28+
}
29+
30+
override class func tearDown() {
31+
super.tearDown()
32+
_ThreadLocalState.useLazyTensor = false
33+
}
34+
35+
func testSimpleShapeComputations() {
36+
let a = Tensor<Float>(shape: [3, 1], scalars: [1.0, 2.0, 3.0])
37+
let b = Tensor<Float>(shape: [1, 3], scalars: [1.0, 2.0, 3.0])
38+
let c = Tensor<Float>(shape: [1, 3], scalars: [4.0, 5.0, 6.0])
39+
let w = a * b
40+
let wLazyTensorOperation = w._lazyTensor!.lazyTensorOperation!
41+
let x = w * c
42+
let xLazyTensorOperation = x._lazyTensor!.lazyTensorOperation!
43+
44+
// Make sure that `w` and `x` are not materialized.
45+
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
46+
XCTAssertFalse(xLazyTensorOperation.isMaterialized)
47+
48+
// Examine shape of w and confirm no materialization has happened.
49+
let wShape = w.shape
50+
XCTAssertEqual(wShape.rank, 2)
51+
XCTAssertEqual(wShape.dimensions, [3, 3])
52+
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
53+
XCTAssertFalse(xLazyTensorOperation.isMaterialized)
54+
55+
let xShape = x.shape
56+
XCTAssertEqual(xShape.rank, 2)
57+
XCTAssertEqual(xShape.dimensions, [3, 3])
58+
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
59+
XCTAssertFalse(xLazyTensorOperation.isMaterialized)
60+
61+
// Trigger materialization.
62+
let _ = x._rawTensorHandle
63+
XCTAssertTrue(wLazyTensorOperation.isMaterialized)
64+
XCTAssertTrue(xLazyTensorOperation.isMaterialized)
65+
}
66+
67+
static var allTests = [
68+
("testSimpleShapeComputations", testSimpleShapeComputations)
69+
]
70+
}
71+

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public func allTests() -> [XCTestCaseEntry] {
2626
testCase(LazyTensorTraceTests.allTests),
2727
testCase(LazyTensorExplicitTraceTests.allTests),
2828
testCase(LazyTensorOperationTests.allTests),
29+
testCase(LazyTensorShapeInferenceTests.allTests),
2930
testCase(LazyTensorTFFunctionBuilderTests.allTests),
3031
testCase(LazyTensorEvaluationTests.allTests),
3132
testCase(LossTests.allTests),

0 commit comments

Comments
 (0)