
Commit 200ccd1

[WIP] Change Dense.bias type to Optional
1 parent 5d98153 · commit 200ccd1

File tree: 3 files changed, +211 −8 lines


Sources/TensorFlow/Layers/Dense.swift

Lines changed: 79 additions & 8 deletions
@@ -28,18 +28,25 @@ import _Differentiation
 public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
   /// The weight matrix.
   public var weight: Tensor<Scalar>
-  /// The bias vector.
-  public var bias: Tensor<Scalar>
+  /// The optional bias vector.
+  public var optionalBias: Tensor<Scalar>?
   /// The element-wise activation function.
   @noDerivative public let activation: Activation
   /// Indicates whether this is a batched dense layer.
   @noDerivative internal let batched: Bool
-  /// Workaround optionals not being handled by AD
-  @noDerivative private let useBias: Bool
 
   /// The element-wise activation function type.
   public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
 
+  /// The bias vector.
+  ///
+  /// - Note: returns `Tensor.zero` if the underlying `optionalBias` does not exist.
+  //@differentiable
+  public var bias: Tensor<Scalar> {
+    get { optionalBias ?? .zero }
+    set { optionalBias = newValue }
+  }
+
   /// Creates an instance from the given weight, optional bias, and activation function.
   ///
   /// - Note: currently, `weight` is the only differentiability parameter. `bias` can be made a
@@ -55,10 +62,9 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     precondition(
       bias == nil || bias!.rank <= 2, "The rank of the 'bias' tensor must be less than 3.")
     self.weight = weight
-    self.bias = bias ?? .zero
+    self.optionalBias = bias
     self.activation = activation
     self.batched = weight.rank == 3
-    useBias = (bias != nil)
   }
 
   // TODO(TF-433): Remove custom derivative after `try_apply` differentiation is supported.
@@ -81,9 +87,15 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
   public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
     if batched {
       let hidden = matmul(input.expandingShape(at: 1), weight).squeezingShape(at: 1)
-      return activation(useBias ? hidden + bias : hidden)
+      if let bias = optionalBias {
+        return activation(hidden + bias)
+      }
+      return activation(hidden)
+    }
+    if let bias = optionalBias {
+      return activation(matmul(input, weight) + bias)
     }
-    return activation(useBias ? (matmul(input, weight) + bias) : matmul(input, weight))
+    return activation(matmul(input, weight))
   }
 }
 
@@ -106,9 +118,68 @@ extension Dense {
     weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
     biasInitializer: ParameterInitializer<Scalar> = zeros()
   ) {
+    print("Init OLD")
     self.init(
       weight: weightInitializer([inputSize, outputSize]),
       bias: useBias ? biasInitializer([outputSize]) : nil,
       activation: activation)
   }
+
+  /// Creates a `Dense` layer with the specified input size, output size, and element-wise
+  /// activation function. The weight matrix is created with shape `[inputSize, outputSize]` and
+  /// the bias vector is created with shape `[outputSize]`.
+  ///
+  /// - Parameters:
+  ///   - inputSize: The dimensionality of the input space.
+  ///   - outputSize: The dimensionality of the output space.
+  ///   - activation: The activation function to use. The default value is `identity(_:)`.
+  ///   - weightInitializer: Initializer to use for `weight`.
+  ///   - biasInitializer: Initializer to use for `bias`.
+  public init(
+    inputSize: Int,
+    outputSize: Int,
+    activation: @escaping Activation = identity,
+    weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
+    biasInitializer: ParameterInitializer<Scalar>? = nil
+  ) {
+    print("Init NEW")
+    self.init(
+      weight: weightInitializer([inputSize, outputSize]),
+      bias: biasInitializer?([outputSize]),
+      activation: activation)
+  }
 }
+
+extension Dense.TangentVector {
+  public init(
+    weight: Tensor<Scalar>,
+    bias: Tensor<Scalar>
+  ) {
+    self.init(weight: weight, optionalBias: .init(bias))
+  }
+
+  /// The bias vector.
+  ///
+  /// - Note: returns `Tensor.zero` if the underlying `optionalBias` does not exist.
+  //@differentiable
+  public var bias: Tensor<Scalar> {
+    get { optionalBias.value ?? .zero }
+    set { optionalBias.value = newValue }
+  }
+}
+
+/* extension Optional : KeyPathIterable {
+  public var allKeyPaths: [PartialKeyPath<Self>] {
+    if self != nil {
+      return [ \Optional.unsafelyUnwrapped ]
+    }
+    return []
+  }
+
+  public typealias AllKeyPaths = [PartialKeyPath<Self>]
+}
+
+extension Optional.TangentVector : KeyPathIterable
+{
+
+} */
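
For orientation, a minimal usage sketch of the changed API (hedged: the shapes and the `relu` activation here are illustrative; the entry point is the `init(weight:bias:activation:)` shown in the hunks above):

import TensorFlow

let weight = Tensor<Float>(randomNormal: [4, 2])

// Bias present: stored as `optionalBias` and added in the forward pass.
let biased = Dense<Float>(weight: weight, bias: Tensor(zeros: [2]), activation: relu)

// Bias absent: `optionalBias` is nil; the computed `bias` property yields
// `.zero` instead of trapping, and `callAsFunction` skips the add entirely.
let unbiased = Dense<Float>(weight: weight, bias: nil, activation: relu)

let x = Tensor<Float>(randomNormal: [1, 4])
print(biased(x))    // activation(matmul(x, weight) + bias)
print(unbiased(x))  // activation(matmul(x, weight))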

Sources/TensorFlow/StdlibExtensions.swift

Lines changed: 131 additions & 0 deletions
@@ -282,3 +282,134 @@ extension Collection {
   /// Returns the `n`th position in `self`.
   func index(atOffset n: Int) -> Index { index(startIndex, offsetBy: n) }
 }
+
+extension Optional: EuclideanDifferentiable
+where Wrapped: EuclideanDifferentiable {
+  public var differentiableVectorView: TangentVector { .init(self?.differentiableVectorView) }
+}
+
+extension Optional.TangentVector: ElementaryFunctions
+where Wrapped.TangentVector: ElementaryFunctions {
+  /// The square root of `x`.
+  ///
+  /// For real types, if `x` is negative the result is `.nan`. For complex
+  /// types there is a branch cut on the negative real axis.
+  public static func sqrt(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
+
+  /// The cosine of `x`, interpreted as an angle in radians.
+  public static func cos(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.cos)) }
+
+  /// The sine of `x`, interpreted as an angle in radians.
+  public static func sin(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sin)) }
+
+  /// The tangent of `x`, interpreted as an angle in radians.
+  public static func tan(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.tan)) }
+
+  /// The inverse cosine of `x` in radians.
+  public static func acos(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.acos)) }
+
+  /// The inverse sine of `x` in radians.
+  public static func asin(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.asin)) }
+
+  /// The inverse tangent of `x` in radians.
+  public static func atan(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.atan)) }
+
+  /// The hyperbolic cosine of `x`.
+  public static func cosh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.cosh)) }
+
+  /// The hyperbolic sine of `x`.
+  public static func sinh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sinh)) }
+
+  /// The hyperbolic tangent of `x`.
+  public static func tanh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.tanh)) }
+
+  /// The inverse hyperbolic cosine of `x`.
+  public static func acosh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.acosh)) }
+
+  /// The inverse hyperbolic sine of `x`.
+  public static func asinh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.asinh)) }
+
+  /// The inverse hyperbolic tangent of `x`.
+  public static func atanh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.atanh)) }
+
+  /// The exponential function applied to `x`, or `e**x`.
+  public static func exp(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.exp)) }
+
+  /// Two raised to the power `x`.
+  public static func exp2(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.exp2)) }
+
+  /// Ten raised to the power `x`.
+  public static func exp10(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.exp10)) }
+
+  /// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
+  public static func expm1(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.expm1)) }
+
+  /// The natural logarithm of `x`.
+  public static func log(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.log)) }
+
+  /// The base-two logarithm of `x`.
+  public static func log2(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.log2)) }
+
+  /// The base-ten logarithm of `x`.
+  public static func log10(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.log10)) }
+
+  /// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
+  public static func log1p(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.log1p)) }
+
+  /// `exp(y log(x))` computed without loss of intermediate precision.
+  ///
+  /// For real types, if `x` is negative the result is NaN, even if `y` has
+  /// an integral value. For complex types, there is a branch cut on the
+  /// negative real axis.
+  public static func pow(_ x: Self, _ y: Self) -> Self {
+    .init(x.value.flatMap { x in y.value.map { y in Wrapped.TangentVector.pow(x, y) } })
+  }
+
+  /// `x` raised to the `n`th power.
+  ///
+  /// The product of `n` copies of `x`.
+  public static func pow(_ x: Self, _ n: Int) -> Self { .init(x.value.map { x in Wrapped.TangentVector.pow(x, n) }) }
+
+  /// The `n`th root of `x`.
+  ///
+  /// For real types, if `x` is negative and `n` is even, the result is NaN.
+  /// For complex types, there is a branch cut along the negative real axis.
+  public static func root(_ x: Self, _ n: Int) -> Self { .init(x.value.map { x in Wrapped.TangentVector.root(x, n) }) }
+}
+
+extension Optional.TangentVector: PointwiseMultiplicative
+where Wrapped.TangentVector: PointwiseMultiplicative {
+  public static var one: Self {
+    .init(Wrapped.TangentVector.one)
+  }
+
+  public var reciprocal: Self { .init(value.map { $0.reciprocal }) }
+
+  public static func .* (lhs: Self, rhs: Self) -> Self {
+    switch (lhs.value, rhs.value) {
+    case let (x?, y?): return Self(x .* y)
+    default: return Self(nil)
+    }
+  }
+
+  public static func .*= (lhs: inout Self, rhs: Self) {
+    lhs = lhs .* rhs
+  }
+}
+
+extension Optional.TangentVector: VectorProtocol
+where Wrapped.TangentVector: VectorProtocol {
+  public typealias VectorSpaceScalar = Wrapped.TangentVector.VectorSpaceScalar
+
+  public func adding(_ x: VectorSpaceScalar) -> Self { .init(value.map { $0.adding(x) }) }
+
+  public mutating func add(_ x: VectorSpaceScalar) { value?.add(x) }
+
+  public func subtracting(_ x: VectorSpaceScalar) -> Self { .init(value.map { $0.subtracting(x) }) }
+
+  public mutating func subtract(_ x: VectorSpaceScalar) { value?.subtract(x) }
+
+  public func scaled(by scale: VectorSpaceScalar) -> Self { .init(value.map { $0.scaled(by: scale) }) }
+
+  public mutating func scale(by scale: VectorSpaceScalar) {
+    value?.scale(by: scale)
+  }
+}
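
These conformances matter because generic optimizers constrain `Model.TangentVector` to protocols such as `VectorProtocol`, `PointwiseMultiplicative`, and `ElementaryFunctions`; once `Optional`'s tangent satisfies them, a tangent containing an optional bias flows through unchanged. A hedged sketch of the pattern (the helper below is illustrative, not library API):

import TensorFlow

// Mirrors the kind of update term an optimizer such as Adam computes:
// elementwise square, square root, and scaling must all be defined on the
// tangent type, including tangents with optional components.
func rmsTerm<T: VectorProtocol & PointwiseMultiplicative & ElementaryFunctions>(
  _ gradient: T, scale: T.VectorSpaceScalar
) -> T {
  // For Optional.TangentVector, each operation maps over the wrapped value
  // and is a no-op (nil) when the component is absent.
  T.sqrt(gradient .* gradient).scaled(by: scale)
}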

Tests/TensorFlowTests/TrivialModelTests.swift

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ final class TrivialModelTests: XCTestCase {
           return meanSquaredError(predicted: ŷ, expected: y)
         }
         optimizer.update(&classifier, along: 𝛁model)
+        dump(𝛁model)
       }
     }
     let ŷ = classifier.inferring(from: x)