Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ee249fa

Browse files
committed
Clean up.
There are two remaining test failures: ``` $ swift test --filter OptimizerTests Tests/TensorFlowTests/OptimizerTests.swift:117: error: -[TensorFlowTests.OptimizerTests testAdaMax] : XCTAssertTrue failed Tests/TensorFlowTests/OptimizerTests.swift:123: error: -[TensorFlowTests.OptimizerTests testAMSGrad] : XCTAssertTrue failed ```
1 parent 200ccd1 commit ee249fa

File tree

3 files changed

+51
-59
lines changed

3 files changed

+51
-59
lines changed

Sources/TensorFlow/Layers/Dense.swift

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,20 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
4040

4141
/// The bias vector.
4242
///
43-
/// - Note: returns `Tensor.zero` if the underlying `optionalBias` does not exist.
44-
//@differentiable
43+
/// - Note: Returns `Tensor.zero` if the underlying `optionalBias` does not exist.
44+
@differentiable
4545
public var bias: Tensor<Scalar> {
46-
get { optionalBias ?? .zero }
46+
get {
47+
if let bias = optionalBias {
48+
return bias
49+
}
50+
return .zero
51+
}
4752
set { optionalBias = newValue }
4853
}
4954

5055
/// Creates an instance from the given weight, optional bias, and activation function.
51-
///
52-
/// - Note: currently, `weight` is the only differentiability parameter. `bias` can be made a
53-
/// differentiability parameter after `Optional` conditionally conforms to `Differentiable`:
54-
/// TF-499.
55-
@differentiable(wrt: weight)
56+
@differentiable(wrt: (weight, bias))
5657
public init(
5758
weight: Tensor<Scalar>,
5859
bias: Tensor<Scalar>? = nil,
@@ -67,18 +68,6 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
6768
self.batched = weight.rank == 3
6869
}
6970

70-
// TODO(TF-433): Remove custom derivative after `try_apply` differentiation is supported.
71-
@derivative(of: init, wrt: weight)
72-
@usableFromInline
73-
static func vjpInit(
74-
weight: Tensor<Scalar>,
75-
bias: Tensor<Scalar>? = nil,
76-
activation: @escaping Activation
77-
) -> (value: Self, pullback: (TangentVector) -> Tensor<Scalar>) {
78-
let value = Dense(weight: weight, bias: bias, activation: activation)
79-
return (value, { v in v.weight })
80-
}
81-
8271
/// Returns the output obtained from applying the layer to the given input.
8372
///
8473
/// - Parameter input: The input to the layer.
@@ -118,7 +107,6 @@ extension Dense {
118107
weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
119108
biasInitializer: ParameterInitializer<Scalar> = zeros()
120109
) {
121-
print("Init OLD")
122110
self.init(
123111
weight: weightInitializer([inputSize, outputSize]),
124112
bias: useBias ? biasInitializer([outputSize]) : nil,
@@ -142,7 +130,6 @@ extension Dense {
142130
weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
143131
biasInitializer: ParameterInitializer<Scalar>? = nil
144132
) {
145-
print("Init NEW")
146133
self.init(
147134
weight: weightInitializer([inputSize, outputSize]),
148135
bias: biasInitializer?([outputSize]),
@@ -168,18 +155,13 @@ extension Dense.TangentVector {
168155
}
169156
}
170157

171-
/* extension Optional : KeyPathIterable {
158+
extension Optional: KeyPathIterable {
172159
public var allKeyPaths: [PartialKeyPath<Self>] {
173160
if self != nil {
174-
return [ \Optional.unsafelyUnwrapped ]
161+
return [\.!]
175162
}
176163
return []
177164
}
178165

179166
public typealias AllKeyPaths = [PartialKeyPath<Self>]
180167
}
181-
182-
extension Optional.TangentVector : KeyPathIterable
183-
{
184-
185-
}*/

Sources/TensorFlow/StdlibExtensions.swift

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -294,94 +294,103 @@ where Wrapped.TangentVector: ElementaryFunctions {
294294
///
295295
/// For real types, if `x` is negative the result is `.nan`. For complex
296296
/// types there is a branch cut on the negative real axis.
297-
public static func sqrt(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
297+
public static func sqrt(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sqrt)) }
298298

299299
/// The cosine of `x`, interpreted as an angle in radians.
300-
public static func cos(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
300+
public static func cos(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.cos)) }
301301

302302
/// The sine of `x`, interpreted as an angle in radians.
303-
public static func sin(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
303+
public static func sin(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sin)) }
304304

305305
/// The tangent of `x`, interpreted as an angle in radians.
306-
public static func tan(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
306+
public static func tan(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.tan)) }
307307

308308
/// The inverse cosine of `x` in radians.
309-
public static func acos(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
309+
public static func acos(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.acos)) }
310310

311311
/// The inverse sine of `x` in radians.
312-
public static func asin(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
312+
public static func asin(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.asin)) }
313313

314314
/// The inverse tangent of `x` in radians.
315-
public static func atan(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
315+
public static func atan(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.atan)) }
316316

317317
/// The hyperbolic cosine of `x`.
318-
public static func cosh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
318+
public static func cosh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.cosh)) }
319319

320320
/// The hyperbolic sine of `x`.
321-
public static func sinh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
321+
public static func sinh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sinh)) }
322322

323323
/// The hyperbolic tangent of `x`.
324-
public static func tanh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
324+
public static func tanh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.tanh)) }
325325

326326
/// The inverse hyperbolic cosine of `x`.
327-
public static func acosh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
327+
public static func acosh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.acosh)) }
328328

329329
/// The inverse hyperbolic sine of `x`.
330-
public static func asinh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
330+
public static func asinh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.asinh)) }
331331

332332
/// The inverse hyperbolic tangent of `x`.
333-
public static func atanh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
333+
public static func atanh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.atanh)) }
334334

335335
/// The exponential function applied to `x`, or `e**x`.
336-
public static func exp(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
336+
public static func exp(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp)) }
337337

338338
/// Two raised to the power `x`.
339-
public static func exp2(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
339+
public static func exp2(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp2)) }
340340

341341
/// Ten raised to the power `x`.
342-
public static func exp10(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
342+
public static func exp10(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp10)) }
343343

344344
/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
345-
public static func expm1(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
345+
public static func expm1(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.expm1)) }
346346

347347
/// The natural logarithm of `x`.
348-
public static func log(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
348+
public static func log(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log)) }
349349

350350
/// The base-two logarithm of `x`.
351-
public static func log2(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
351+
public static func log2(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log2)) }
352352

353353
/// The base-ten logarithm of `x`.
354-
public static func log10(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
354+
public static func log10(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log10)) }
355355

356356
/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
357-
public static func log1p(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
357+
public static func log1p(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log1p)) }
358358

359359
/// `exp(y log(x))` computed without loss of intermediate precision.
360360
///
361361
/// For real types, if `x` is negative the result is NaN, even if `y` has
362362
/// an integral value. For complex types, there is a branch cut on the
363363
/// negative real axis.
364-
public static func pow(_ x: Self, _ y: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
364+
public static func pow(_ x: Self, _ y: Self) -> Self {
365+
switch (x.value, y.value) {
366+
case let (x?, y?): return Self(Wrapped.TangentVector.pow(x, y))
367+
default: return Self(nil)
368+
}
369+
}
365370

366371
/// `x` raised to the `n`th power.
367372
///
368373
/// The product of `n` copies of `x`.
369-
public static func pow(_ x: Self, _ n: Int) -> Self { .init(x.value.map({ x in Wrapped.TangentVector.pow(x, n)})) }
374+
public static func pow(_ x: Self, _ n: Int) -> Self {
375+
Self(x.value.map({ x in Wrapped.TangentVector.pow(x, n) }))
376+
}
370377

371378
/// The `n`th root of `x`.
372379
///
373380
/// For real types, if `x` is negative and `n` is even, the result is NaN.
374381
/// For complex types, there is a branch cut along the negative real axis.
375-
public static func root(_ x: Self, _ n: Int) -> Self { .init(x.value.map({ x in Wrapped.TangentVector.root(x, n)})) }
382+
public static func root(_ x: Self, _ n: Int) -> Self {
383+
Self(x.value.map({ x in Wrapped.TangentVector.root(x, n) }))
384+
}
376385
}
377386

378387
extension Optional.TangentVector: PointwiseMultiplicative
379388
where Wrapped.TangentVector: PointwiseMultiplicative {
380389
public static var one: Self {
381-
.init(Wrapped.TangentVector.one)
390+
Self(Wrapped.TangentVector.one)
382391
}
383392

384-
public var reciprocal: Self { .init(value.map { $0.reciprocal }) }
393+
public var reciprocal: Self { Self(value.map { $0.reciprocal }) }
385394

386395
public static func .* (lhs: Self, rhs: Self) -> Self {
387396
switch (lhs.value, rhs.value) {
@@ -399,15 +408,17 @@ extension Optional.TangentVector: VectorProtocol
399408
where Wrapped.TangentVector: VectorProtocol {
400409
public typealias VectorSpaceScalar = Wrapped.TangentVector.VectorSpaceScalar
401410

402-
public func adding(_ x: VectorSpaceScalar) -> Self { .init(value.map { $0.adding(x) }) }
411+
public func adding(_ x: VectorSpaceScalar) -> Self { Self(value.map { $0.adding(x) }) }
403412

404413
public mutating func add(_ x: VectorSpaceScalar) { value?.add(x) }
405414

406-
public func subtracting(_ x: VectorSpaceScalar) -> Self { .init(value.map { $0.subtracting(x) }) }
415+
public func subtracting(_ x: VectorSpaceScalar) -> Self { Self(value.map { $0.subtracting(x) }) }
407416

408417
public mutating func subtract(_ x: VectorSpaceScalar) { value?.subtract(x) }
409418

410-
public func scaled(by scale: VectorSpaceScalar) -> Self { .init(value.map { $0.scaled(by: scale) }) }
419+
public func scaled(by scale: VectorSpaceScalar) -> Self {
420+
Self(value.map { $0.scaled(by: scale) })
421+
}
411422

412423
public mutating func scale(by scale: VectorSpaceScalar) {
413424
value?.scale(by: scale)

Tests/TensorFlowTests/TrivialModelTests.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ final class TrivialModelTests: XCTestCase {
5151
return meanSquaredError(predicted: ŷ, expected: y)
5252
}
5353
optimizer.update(&classifier, along: 𝛁model)
54-
dump(𝛁model)
5554
}
5655
}
5756
let ŷ = classifier.inferring(from: x)

0 commit comments

Comments
 (0)