Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5c8b8ee

Browse files
saetaeaplatanios
authored andcommitted
Expose the available devices from the _ExecutionContext. (#398)
1 parent 1de468a commit 5c8b8ee

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ public final class _ExecutionContext {
550550
/// List of devices available to this execution context.
551551
/// Devices are represented by their names in TensorFlow notation.
552552
/// See documentation for `withDevice(named:perform:)` to learn about device names.
553-
private var deviceNames: [String] = []
553+
public private(set) var deviceNames: [String] = []
554554

555555
/// The buffer storing a serialized TensorFlow config proto.
556556
public let tensorFlowConfig: UnsafeMutablePointer<TF_Buffer>
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
import TensorFlow // Note: not imported as @testable in order to test the public API.
17+
18+
final class RuntimeTests: XCTestCase {
19+
func testDeviceNames() {
20+
let deviceNames = _ExecutionContext.global.deviceNames
21+
XCTAssert(deviceNames.count > 0, "Missing CPU device, got: \(deviceNames)")
22+
let cpu0 = deviceNames.filter { $0.hasSuffix("/device:CPU:0") }
23+
XCTAssertEqual(cpu0.count, 1, "All devices: \(deviceNames)")
24+
}
25+
}
26+
27+
extension RuntimeTests {
28+
static var allTests = [
29+
("testDeviceNames", testDeviceNames),
30+
]
31+
}

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,27 @@ import XCTest
1616

1717
#if !os(macOS)
1818
public func allTests() -> [XCTestCaseEntry] {
19+
// Please ensure the test cases remain alphabetized.
1920
return [
20-
testCase(UtilitiesTests.allTests),
21-
testCase(LossTests.allTests),
22-
testCase(PRNGTests.allTests),
23-
testCase(TrivialModelTests.allTests),
24-
testCase(SequentialTests.allTests),
25-
testCase(LayerTests.allTests),
26-
testCase(TensorTests.allTests),
27-
testCase(TensorGroupTests.allTests),
2821
testCase(BasicOperatorTests.allTests),
2922
testCase(ComparisonOperatorTests.allTests),
3023
testCase(DatasetTests.allTests),
31-
testCase(MathOperatorTests.allTests),
24+
testCase(LayerTests.allTests),
3225
testCase(LazyTensorTests.allTests),
3326
testCase(LazyTensorTraceTests.allTests),
3427
testCase(LazyTensorExplicitTraceTests.allTests),
3528
testCase(LazyTensorOperationTests.allTests),
3629
testCase(LazyTensorTFFunctionBuilderTests.allTests),
3730
testCase(LazyTensorEvaluationTests.allTests),
31+
testCase(LossTests.allTests),
32+
testCase(MathOperatorTests.allTests),
33+
testCase(PRNGTests.allTests),
34+
testCase(RuntimeTests.allTests),
35+
testCase(SequentialTests.allTests),
36+
testCase(TensorTests.allTests),
37+
testCase(TensorGroupTests.allTests),
38+
testCase(TrivialModelTests.allTests),
39+
testCase(UtilitiesTests.allTests),
3840
]
3941
}
4042
#endif

0 commit comments

Comments
 (0)