Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Expose the available devices from the _ExecutionContext. #398

Merged
merged 2 commits into from
Jul 29, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/TensorFlow/Core/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ public final class _ExecutionContext {
/// List of devices available to this execution context.
/// Devices are represented by their names in TensorFlow notation.
/// See documentation for `withDevice(named:perform:)` to learn about device names.
private var deviceNames: [String] = []
public private(set) var deviceNames: [String] = []

/// The buffer storing a serialized TensorFlow config proto.
public let tensorFlowConfig: UnsafeMutablePointer<TF_Buffer>
Expand Down
31 changes: 31 additions & 0 deletions Tests/TensorFlowTests/RuntimeTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
import TensorFlow // Note: not imported as @testable in order to test the public API

final class RuntimeTests: XCTestCase {
func testDeviceNames() {
let deviceNames = _ExecutionContext.global.deviceNames
XCTAssert(deviceNames.count > 0, "Missing CPU device, got: \(deviceNames)")
let cpu0 = deviceNames.filter { $0.hasSuffix("/device:CPU:0") }
XCTAssertEqual(cpu0.count, 1, "All devices: \(deviceNames)")
}
}

extension RuntimeTests {
static var allTests = [
("testDeviceNames", testDeviceNames),
]
}
20 changes: 11 additions & 9 deletions Tests/TensorFlowTests/XCTestManifests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,27 @@ import XCTest

#if !os(macOS)
public func allTests() -> [XCTestCaseEntry] {
// Please ensure the test cases remain alphabetized.
return [
testCase(UtilitiesTests.allTests),
testCase(LossTests.allTests),
testCase(PRNGTests.allTests),
testCase(TrivialModelTests.allTests),
testCase(SequentialTests.allTests),
testCase(LayerTests.allTests),
testCase(TensorTests.allTests),
testCase(TensorGroupTests.allTests),
testCase(BasicOperatorTests.allTests),
testCase(ComparisonOperatorTests.allTests),
testCase(DatasetTests.allTests),
testCase(MathOperatorTests.allTests),
testCase(LayerTests.allTests),
testCase(LazyTensorTests.allTests),
testCase(LazyTensorTraceTests.allTests),
testCase(LazyTensorExplicitTraceTests.allTests),
testCase(LazyTensorOperationTests.allTests),
testCase(LazyTensorTFFunctionBuilderTests.allTests),
testCase(LazyTensorEvaluationTests.allTests),
testCase(LossTests.allTests),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually good idea to keep unrelated changes in a separate PR?

testCase(MathOperatorTests.allTests),
testCase(PRNGTests.allTests),
testCase(RuntimeTests.allTests),
testCase(SequentialTests.allTests),
testCase(TensorTests.allTests),
testCase(TensorGroupTests.allTests),
testCase(TrivialModelTests.allTests),
testCase(UtilitiesTests.allTests),
]
}
#endif