Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9832310

Browse files
authored
Explicitly qualify LearningPhase.inference to prevent ambiguity errors. (#92)
Fixes `swift build` and `swift test`.
1 parent 6f11c8e commit 9832310

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public extension Layer {
4646
/// - Returns: The inference output.
4747
@differentiable
4848
func inferring(from input: Input) -> Output {
49-
return withLearningPhase(.inference) {
49+
return withLearningPhase(LearningPhase.inference) {
5050
applied(to: input)
5151
}
5252
}
@@ -57,7 +57,7 @@ public extension Layer {
5757
internal func _vjpInferring(from input: Input)
5858
-> (value: Output, pullback: (Output.CotangentVector)
5959
-> (CotangentVector, Input.CotangentVector)) {
60-
return withLearningPhase(.inference) {
60+
return withLearningPhase(LearningPhase.inference) {
6161
let (output, pullback) = appliedForBackpropagation(to: input)
6262
return (output, { v in pullback(v) })
6363
}

Tests/DeepLearningTests/ContextTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ final class ContextTests: XCTestCase {
2222
let dropout = Dropout<Float>(probability: 0.5)
2323
let x = Tensor<Float>(repeating: 1.0, shape: [5, 5])
2424
XCTAssertEqual(dropout.applied(to: x), x)
25-
withLearningPhase(.inference) {
25+
withLearningPhase(LearningPhase.inference) {
2626
XCTAssertEqual(dropout.applied(to: x), x)
27-
withLearningPhase(.training) {
27+
withLearningPhase(LearningPhase.training) {
2828
XCTAssertNotEqual(dropout.applied(to: x), x)
2929
}
3030
XCTAssertEqual(dropout.applied(to: x), x)
@@ -39,7 +39,7 @@ final class ContextTests: XCTestCase {
3939
DispatchQueue.concurrentPerform(iterations: 10) { i in
4040
if i.isMultiple(of: 2) {
4141
XCTAssertEqual(dropout.applied(to: x), x)
42-
withLearningPhase(.training) {
42+
withLearningPhase(LearningPhase.training) {
4343
XCTAssertNotEqual(dropout.applied(to: x), x)
4444
}
4545
XCTAssertEqual(dropout.applied(to: x), x)

0 commit comments

Comments
 (0)