This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Change Dense.bias type to Optional #1062

Closed · wants to merge 3 commits
91 changes: 66 additions & 25 deletions Sources/TensorFlow/Layers/Dense.swift
@@ -28,24 +28,32 @@ import _Differentiation
public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
/// The weight matrix.
public var weight: Tensor<Scalar>
/// The bias vector.
public var bias: Tensor<Scalar>
/// The optional bias vector.
public var optionalBias: Tensor<Scalar>?
/// The element-wise activation function.
@noDerivative public let activation: Activation
/// Indicates whether this is a batched dense layer.
@noDerivative internal let batched: Bool
/// Workaround optionals not being handled by AD
@noDerivative private let useBias: Bool

/// The element-wise activation function type.
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

/// Creates an instance from the given weight, optional bias, and activation function.
/// The bias vector.
///
/// - Note: currently, `weight` is the only differentiability parameter. `bias` can be made a
/// differentiability parameter after `Optional` conditionally conforms to `Differentiable`:
/// TF-499.
@differentiable(wrt: weight)
/// - Note: Returns `Tensor.zero` if the underlying `optionalBias` does not exist.
@differentiable
public var bias: Tensor<Scalar> {
get {
if let bias = optionalBias {
return bias
}
return .zero
}
set { optionalBias = newValue }
}

/// Creates an instance from the given weight, optional bias, and activation function.
@differentiable(wrt: (weight, bias))
public init(
weight: Tensor<Scalar>,
bias: Tensor<Scalar>? = nil,
@@ -55,22 +63,9 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
precondition(
bias == nil || bias!.rank <= 2, "The rank of the 'bias' tensor must be less than 3.")
self.weight = weight
self.bias = bias ?? .zero
self.optionalBias = bias
self.activation = activation
self.batched = weight.rank == 3
useBias = (bias != nil)
}

// TODO(TF-433): Remove custom derivative after `try_apply` differentiation is supported.
@derivative(of: init, wrt: weight)
@usableFromInline
static func vjpInit(
weight: Tensor<Scalar>,
bias: Tensor<Scalar>? = nil,
activation: @escaping Activation
) -> (value: Self, pullback: (TangentVector) -> Tensor<Scalar>) {
let value = Dense(weight: weight, bias: bias, activation: activation)
return (value, { v in v.weight })
}

/// Returns the output obtained from applying the layer to the given input.
@@ -81,9 +76,15 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
if batched {
let hidden = matmul(input.expandingShape(at: 1), weight).squeezingShape(at: 1)
return activation(useBias ? hidden + bias : hidden)
if let bias = optionalBias {
return activation(hidden + bias)
}
return activation(hidden)
}
if let bias = optionalBias {
return activation(matmul(input, weight) + bias)
}
return activation(useBias ? (matmul(input, weight) + bias) : matmul(input, weight))
return activation(matmul(input, weight))
}
}

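Not part of the diff: a minimal usage sketch of the patched struct above, assuming this branch of swift-apis builds; the shapes and the `identity` activation are illustrative. A layer constructed without a bias keeps `optionalBias` empty, the forward pass skips the addition, and the computed `bias` getter surfaces `.zero`.

```swift
import TensorFlow

// Sketch only: exercises the initializer and computed `bias` from the hunks above.
let weight = Tensor<Float>(randomNormal: [4, 2])
let noBias = Dense<Float>(weight: weight, activation: identity)  // `bias` defaults to nil

let x = Tensor<Float>(randomNormal: [1, 4])
let y = noBias(x)   // activation(matmul(x, weight)); no bias term is added
print(y.shape)      // [1, 2]
print(noBias.bias)  // the getter returns `.zero` because `optionalBias` is nil

var withBias = noBias
withBias.bias = Tensor<Float>(ones: [2])  // the setter populates `optionalBias`
```
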
@@ -111,4 +112,44 @@ extension Dense {
bias: useBias ? biasInitializer([outputSize]) : nil,
activation: activation)
}

/// Creates a `Dense` layer with the specified input size, output size, and element-wise
/// activation function. The weight matrix is created with shape `[inputSize, outputSize]` and
/// the bias vector is created with shape `[outputSize]`.
///
/// - Parameters:
/// - inputSize: The dimensionality of the input space.
/// - outputSize: The dimensionality of the output space.
/// - activation: The activation function to use. The default value is `identity(_:)`.
/// - weightInitializer: Initializer to use for `weight`.
/// - biasInitializer: Initializer to use for `bias`.
public init(
inputSize: Int,
outputSize: Int,
activation: @escaping Activation = identity,
weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
biasInitializer: ParameterInitializer<Scalar>? = nil
) {
self.init(
weight: weightInitializer([inputSize, outputSize]),
bias: biasInitializer?([outputSize]),
activation: activation)
}
}

extension Dense.TangentVector {
public init(
weight: Tensor<Scalar>,
bias: Tensor<Scalar>
) {
self.init(weight: weight, optionalBias: .init(bias))
}

/// The bias vector.
///
/// - Note: returns `Tensor.zero` if the underlying `optionalBias` does not exist.
public var bias: Tensor<Scalar> {
get { optionalBias.value ?? .zero }
set { optionalBias.value = newValue }
}
}
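
Again not part of the diff, a hedged sketch of the additions in this last hunk, assuming the branch builds; the shapes and the stock `relu` / `glorotUniform()` choices are illustrative. The new overload takes an optional `biasInitializer`, and differentiating a bias-less layer yields a tangent vector whose `optionalBias` carries no value while its `bias` view reads as `.zero`.

```swift
import TensorFlow

// Sketch only: the optional-`biasInitializer` overload and the TangentVector `bias` view.
let dense = Dense<Float>(
  inputSize: 4,
  outputSize: 2,
  activation: relu,
  weightInitializer: glorotUniform(),
  biasInitializer: nil)  // no bias parameter is created

let x = Tensor<Float>(randomNormal: [8, 4])
let grad = gradient(at: dense) { model in model(x).sum() }
print(grad.weight.shape)  // [4, 2]
print(grad.bias)          // `.zero`, since `optionalBias` has no value to report
```
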
142 changes: 142 additions & 0 deletions Sources/TensorFlow/StdlibExtensions.swift
@@ -282,3 +282,145 @@ extension Collection {
/// Returns the `n`th position in `self`.
func index(atOffset n: Int) -> Index { index(startIndex, offsetBy: n) }
}

extension Optional: EuclideanDifferentiable
where Wrapped: EuclideanDifferentiable {
public var differentiableVectorView: TangentVector { .init(self?.differentiableVectorView) }
}

extension Optional.TangentVector: ElementaryFunctions
where Wrapped.TangentVector: ElementaryFunctions {
/// The square root of `x`.
///
/// For real types, if `x` is negative the result is `.nan`. For complex
/// types there is a branch cut on the negative real axis.
public static func sqrt(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sqrt)) }

/// The cosine of `x`, interpreted as an angle in radians.
public static func cos(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.cos)) }

/// The sine of `x`, interpreted as an angle in radians.
public static func sin(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sin)) }

/// The tangent of `x`, interpreted as an angle in radians.
public static func tan(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.tan)) }

/// The inverse cosine of `x` in radians.
public static func acos(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.acos)) }

/// The inverse sine of `x` in radians.
public static func asin(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.asin)) }

/// The inverse tangent of `x` in radians.
public static func atan(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.atan)) }

/// The hyperbolic cosine of `x`.
public static func cosh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.cosh)) }

/// The hyperbolic sine of `x`.
public static func sinh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sinh)) }

/// The hyperbolic tangent of `x`.
public static func tanh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.tanh)) }

/// The inverse hyperbolic cosine of `x`.
public static func acosh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.acosh)) }

/// The inverse hyperbolic sine of `x`.
public static func asinh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.asinh)) }

/// The inverse hyperbolic tangent of `x`.
public static func atanh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.atanh)) }

/// The exponential function applied to `x`, or `e**x`.
public static func exp(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp)) }

/// Two raised to the power `x`.
public static func exp2(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp2)) }

/// Ten raised to the power `x`.
public static func exp10(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp10)) }

/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
public static func expm1(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.expm1)) }

/// The natural logarithm of `x`.
public static func log(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log)) }

/// The base-two logarithm of `x`.
public static func log2(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log2)) }

/// The base-ten logarithm of `x`.
public static func log10(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log10)) }

/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
public static func log1p(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log1p)) }

/// `exp(y log(x))` computed without loss of intermediate precision.
///
/// For real types, if `x` is negative the result is NaN, even if `y` has
/// an integral value. For complex types, there is a branch cut on the
/// negative real axis.
public static func pow(_ x: Self, _ y: Self) -> Self {
switch (x.value, y.value) {
case let (x?, y?): return Self(Wrapped.TangentVector.pow(x, y))
default: return Self(nil)
}
}

/// `x` raised to the `n`th power.
///
/// The product of `n` copies of `x`.
public static func pow(_ x: Self, _ n: Int) -> Self {
Self(x.value.map({ x in Wrapped.TangentVector.pow(x, n) }))
}

/// The `n`th root of `x`.
///
/// For real types, if `x` is negative and `n` is even, the result is NaN.
/// For complex types, there is a branch cut along the negative real axis.
public static func root(_ x: Self, _ n: Int) -> Self {
Self(x.value.map({ x in Wrapped.TangentVector.root(x, n) }))
}
}

extension Optional.TangentVector: PointwiseMultiplicative
where Wrapped.TangentVector: PointwiseMultiplicative {
public static var one: Self {
Self(Wrapped.TangentVector.one)
}

public var reciprocal: Self { Self(value.map { $0.reciprocal }) }

public static func .* (lhs: Self, rhs: Self) -> Self {
switch (lhs.value, rhs.value) {
case let (x?, y?): return Self(x .* y)
default: return Self(nil)
}
}

public static func .*= (lhs: inout Self, rhs: Self) {
lhs = lhs .* rhs
}
}

extension Optional.TangentVector: VectorProtocol
where Wrapped.TangentVector: VectorProtocol {
public typealias VectorSpaceScalar = Wrapped.TangentVector.VectorSpaceScalar

public func adding(_ x: VectorSpaceScalar) -> Self { Self(value.map { $0.adding(x) }) }

public mutating func add(_ x: VectorSpaceScalar) { value?.add(x) }

public func subtracting(_ x: VectorSpaceScalar) -> Self { Self(value.map { $0.subtracting(x) }) }

public mutating func subtract(_ x: VectorSpaceScalar) { value?.subtract(x) }

public func scaled(by scale: VectorSpaceScalar) -> Self {
Self(value.map { $0.scaled(by: scale) })
}

public mutating func scale(by scale: VectorSpaceScalar) {
value?.scale(by: scale)
}
}
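
Finally, a small sketch of the `Optional.TangentVector` algebra introduced above (again assuming this branch builds; the example values are illustrative): each operation maps over the wrapped tangent when it is present, and a missing operand propagates `nil` rather than trapping, which is what lets generic optimizer arithmetic run over layers whose bias is absent.

```swift
import TensorFlow

// Sketch only: Optional<Tensor<Float>>.TangentVector under the new conformances.
typealias OptTangent = Optional<Tensor<Float>>.TangentVector

let present = OptTangent(Tensor<Float>([1, 4, 9]))
let absent = OptTangent(nil)

let roots = OptTangent.sqrt(present)  // wraps [1, 2, 3]
let halved = present.scaled(by: 0.5)  // wraps [0.5, 2, 4.5]
let mixed = present .* absent         // nil: a missing operand propagates
let unit = OptTangent.one             // wraps the scalar one tensor
```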