Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 275fe88

Browse files
rickwierengasaeta
authored andcommitted
Update Dropout to be more concise (#560)
* Move Dense and Dropout layers to separate files * keep the ordering of the layers the same as before * remove empty lines * Make Dropout initializer doc clearer * Remove redundant methods from Dropout
1 parent 83a4ef9 commit 275fe88

File tree

1 file changed

+3
-13
lines changed

1 file changed

+3
-13
lines changed

Sources/TensorFlow/Layers/Dropout.swift

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,14 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
4646

4747
/// Creates a dropout layer.
4848
///
49-
/// - Parameter probability: The drop probability.
49+
/// - Parameter probability: The probability of a node dropping out.
5050
/// - Precondition: probability must be a value between 0 and 1 (inclusive).
5151
public init(probability: Double) {
5252
precondition(0...1 ~= probability,
5353
"Probability must be a value between 0 and 1 (inclusive) but is \(probability)")
5454
self.probability = probability
5555
}
5656

57-
@differentiable
58-
private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
59-
return input._droppingOut(probability: probability)
60-
}
61-
62-
@differentiable
63-
private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
64-
return input
65-
}
66-
6757
/// Returns the output obtained from applying the layer to the given input.
6858
///
6959
/// - Parameter input: The input to the layer.
@@ -72,9 +62,9 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
7262
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
7363
switch Context.local.learningPhase {
7464
case .training:
75-
return applyingTraining(to: input)
65+
return input._droppingOut(probability: probability)
7666
case .inference:
77-
return applyingInference(to: input)
67+
return input
7868
}
7969
}
8070
}

0 commit comments

Comments
 (0)