Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6d1a850

Browse files
committed
Clean up.
There are two remaining test failures: ``` $ swift test --filter OptimizerTests Tests/TensorFlowTests/OptimizerTests.swift:117: error: -[TensorFlowTests.OptimizerTests testAdaMax] : XCTAssertTrue failed Tests/TensorFlowTests/OptimizerTests.swift:123: error: -[TensorFlowTests.OptimizerTests testAMSGrad] : XCTAssertTrue failed ```
1 parent 200ccd1 commit 6d1a850

File tree

3 files changed

+80
-88
lines changed

3 files changed

+80
-88
lines changed

Sources/TensorFlow/Layers/Dense.swift

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,20 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
4040

4141
/// The bias vector.
4242
///
43-
/// - Note: returns `Tensor.zero` if the underlying `optionalBias` does not exist.
44-
//@differentiable
43+
/// - Note: Returns `Tensor.zero` if the underlying `optionalBias` does not exist.
44+
@differentiable
4545
public var bias: Tensor<Scalar> {
46-
get { optionalBias ?? .zero }
46+
get {
47+
if let bias = optionalBias {
48+
return bias
49+
}
50+
return .zero
51+
}
4752
set { optionalBias = newValue }
4853
}
4954

5055
/// Creates an instance from the given weight, optional bias, and activation function.
51-
///
52-
/// - Note: currently, `weight` is the only differentiability parameter. `bias` can be made a
53-
/// differentiability parameter after `Optional` conditionally conforms to `Differentiable`:
54-
/// TF-499.
55-
@differentiable(wrt: weight)
56+
@differentiable(wrt: (weight, bias))
5657
public init(
5758
weight: Tensor<Scalar>,
5859
bias: Tensor<Scalar>? = nil,
@@ -67,18 +68,6 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
6768
self.batched = weight.rank == 3
6869
}
6970

70-
// TODO(TF-433): Remove custom derivative after `try_apply` differentiation is supported.
71-
@derivative(of: init, wrt: weight)
72-
@usableFromInline
73-
static func vjpInit(
74-
weight: Tensor<Scalar>,
75-
bias: Tensor<Scalar>? = nil,
76-
activation: @escaping Activation
77-
) -> (value: Self, pullback: (TangentVector) -> Tensor<Scalar>) {
78-
let value = Dense(weight: weight, bias: bias, activation: activation)
79-
return (value, { v in v.weight })
80-
}
81-
8271
/// Returns the output obtained from applying the layer to the given input.
8372
///
8473
/// - Parameter input: The input to the layer.
@@ -118,7 +107,6 @@ extension Dense {
118107
weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
119108
biasInitializer: ParameterInitializer<Scalar> = zeros()
120109
) {
121-
print("Init OLD")
122110
self.init(
123111
weight: weightInitializer([inputSize, outputSize]),
124112
bias: useBias ? biasInitializer([outputSize]) : nil,
@@ -142,7 +130,6 @@ extension Dense {
142130
weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
143131
biasInitializer: ParameterInitializer<Scalar>? = nil
144132
) {
145-
print("Init NEW")
146133
self.init(
147134
weight: weightInitializer([inputSize, outputSize]),
148135
bias: biasInitializer?([outputSize]),
@@ -168,18 +155,13 @@ extension Dense.TangentVector {
168155
}
169156
}
170157

171-
/* extension Optional : KeyPathIterable {
158+
extension Optional: KeyPathIterable {
172159
public var allKeyPaths: [PartialKeyPath<Self>] {
173160
if self != nil {
174-
return [ \Optional.unsafelyUnwrapped ]
161+
return [\.!]
175162
}
176163
return []
177164
}
178165

179166
public typealias AllKeyPaths = [PartialKeyPath<Self>]
180167
}
181-
182-
extension Optional.TangentVector : KeyPathIterable
183-
{
184-
185-
}*/

Sources/TensorFlow/StdlibExtensions.swift

Lines changed: 69 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -111,85 +111,85 @@ where Element: Differentiable & ElementaryFunctions {
111111
///
112112
/// For real types, if `x` is negative the result is `.nan`. For complex
113113
/// types there is a branch cut on the negative real axis.
114-
public static func sqrt(_ x: Self) -> Self { .init(Array.sqrt(x.base)) }
114+
public static func sqrt(_ x: Self) -> Self { Self(Array.sqrt(x.base)) }
115115

116116
/// The cosine of `x`, interpreted as an angle in radians.
117-
public static func cos(_ x: Self) -> Self { .init(Array.cos(x.base)) }
117+
public static func cos(_ x: Self) -> Self { Self(Array.cos(x.base)) }
118118

119119
/// The sine of `x`, interpreted as an angle in radians.
120-
public static func sin(_ x: Self) -> Self { .init(Array.sin(x.base)) }
120+
public static func sin(_ x: Self) -> Self { Self(Array.sin(x.base)) }
121121

122122
/// The tangent of `x`, interpreted as an angle in radians.
123-
public static func tan(_ x: Self) -> Self { .init(Array.tan(x.base)) }
123+
public static func tan(_ x: Self) -> Self { Self(Array.tan(x.base)) }
124124

125125
/// The inverse cosine of `x` in radians.
126-
public static func acos(_ x: Self) -> Self { .init(Array.acos(x.base)) }
126+
public static func acos(_ x: Self) -> Self { Self(Array.acos(x.base)) }
127127

128128
/// The inverse sine of `x` in radians.
129-
public static func asin(_ x: Self) -> Self { .init(Array.asin(x.base)) }
129+
public static func asin(_ x: Self) -> Self { Self(Array.asin(x.base)) }
130130

131131
/// The inverse tangent of `x` in radians.
132-
public static func atan(_ x: Self) -> Self { .init(Array.atan(x.base)) }
132+
public static func atan(_ x: Self) -> Self { Self(Array.atan(x.base)) }
133133

134134
/// The hyperbolic cosine of `x`.
135-
public static func cosh(_ x: Self) -> Self { .init(Array.cosh(x.base)) }
135+
public static func cosh(_ x: Self) -> Self { Self(Array.cosh(x.base)) }
136136

137137
/// The hyperbolic sine of `x`.
138-
public static func sinh(_ x: Self) -> Self { .init(Array.sinh(x.base)) }
138+
public static func sinh(_ x: Self) -> Self { Self(Array.sinh(x.base)) }
139139

140140
/// The hyperbolic tangent of `x`.
141-
public static func tanh(_ x: Self) -> Self { .init(Array.tanh(x.base)) }
141+
public static func tanh(_ x: Self) -> Self { Self(Array.tanh(x.base)) }
142142

143143
/// The inverse hyperbolic cosine of `x`.
144-
public static func acosh(_ x: Self) -> Self { .init(Array.acosh(x.base)) }
144+
public static func acosh(_ x: Self) -> Self { Self(Array.acosh(x.base)) }
145145

146146
/// The inverse hyperbolic sine of `x`.
147-
public static func asinh(_ x: Self) -> Self { .init(Array.asinh(x.base)) }
147+
public static func asinh(_ x: Self) -> Self { Self(Array.asinh(x.base)) }
148148

149149
/// The inverse hyperbolic tangent of `x`.
150-
public static func atanh(_ x: Self) -> Self { .init(Array.atanh(x.base)) }
150+
public static func atanh(_ x: Self) -> Self { Self(Array.atanh(x.base)) }
151151

152152
/// The exponential function applied to `x`, or `e**x`.
153-
public static func exp(_ x: Self) -> Self { .init(Array.exp(x.base)) }
153+
public static func exp(_ x: Self) -> Self { Self(Array.exp(x.base)) }
154154

155155
/// Two raised to to power `x`.
156-
public static func exp2(_ x: Self) -> Self { .init(Array.exp2(x.base)) }
156+
public static func exp2(_ x: Self) -> Self { Self(Array.exp2(x.base)) }
157157

158158
/// Ten raised to to power `x`.
159-
public static func exp10(_ x: Self) -> Self { .init(Array.exp10(x.base)) }
159+
public static func exp10(_ x: Self) -> Self { Self(Array.exp10(x.base)) }
160160

161161
/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
162-
public static func expm1(_ x: Self) -> Self { .init(Array.expm1(x.base)) }
162+
public static func expm1(_ x: Self) -> Self { Self(Array.expm1(x.base)) }
163163

164164
/// The natural logarithm of `x`.
165-
public static func log(_ x: Self) -> Self { .init(Array.log(x.base)) }
165+
public static func log(_ x: Self) -> Self { Self(Array.log(x.base)) }
166166

167167
/// The base-two logarithm of `x`.
168-
public static func log2(_ x: Self) -> Self { .init(Array.log2(x.base)) }
168+
public static func log2(_ x: Self) -> Self { Self(Array.log2(x.base)) }
169169

170170
/// The base-ten logarithm of `x`.
171-
public static func log10(_ x: Self) -> Self { .init(Array.log10(x.base)) }
171+
public static func log10(_ x: Self) -> Self { Self(Array.log10(x.base)) }
172172

173173
/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
174-
public static func log1p(_ x: Self) -> Self { .init(Array.log1p(x.base)) }
174+
public static func log1p(_ x: Self) -> Self { Self(Array.log1p(x.base)) }
175175

176176
/// `exp(y log(x))` computed without loss of intermediate precision.
177177
///
178178
/// For real types, if `x` is negative the result is NaN, even if `y` has
179179
/// an integral value. For complex types, there is a branch cut on the
180180
/// negative real axis.
181-
public static func pow(_ x: Self, _ y: Self) -> Self { .init(Array.pow(x.base, y.base)) }
181+
public static func pow(_ x: Self, _ y: Self) -> Self { Self(Array.pow(x.base, y.base)) }
182182

183183
/// `x` raised to the `n`th power.
184184
///
185185
/// The product of `n` copies of `x`.
186-
public static func pow(_ x: Self, _ n: Int) -> Self { .init(Array.pow(x.base, n)) }
186+
public static func pow(_ x: Self, _ n: Int) -> Self { Self(Array.pow(x.base, n)) }
187187

188188
/// The `n`th root of `x`.
189189
///
190190
/// For real types, if `x` is negative and `n` is even, the result is NaN.
191191
/// For complex types, there is a branch cut along the negative real axis.
192-
public static func root(_ x: Self, _ n: Int) -> Self { .init(Array.root(x.base, n)) }
192+
public static func root(_ x: Self, _ n: Int) -> Self { Self(Array.root(x.base, n)) }
193193
}
194194

195195
extension Array.DifferentiableView:
@@ -226,7 +226,7 @@ where Element: Differentiable & VectorProtocol {
226226
public typealias VectorSpaceScalar = Element.VectorSpaceScalar
227227

228228
public func adding(_ x: Element.VectorSpaceScalar) -> Array<Element>.DifferentiableView {
229-
.init(map { $0.adding(x) })
229+
Self(map { $0.adding(x) })
230230
}
231231

232232
public mutating func add(_ x: Element.VectorSpaceScalar) {
@@ -236,7 +236,7 @@ where Element: Differentiable & VectorProtocol {
236236
}
237237

238238
public func subtracting(_ x: Element.VectorSpaceScalar) -> Array<Element>.DifferentiableView {
239-
.init(map { $0.subtracting(x) })
239+
Self(map { $0.subtracting(x) })
240240
}
241241

242242
public mutating func subtract(_ x: Element.VectorSpaceScalar) {
@@ -246,7 +246,7 @@ where Element: Differentiable & VectorProtocol {
246246
}
247247

248248
public func scaled(by scale: Element.VectorSpaceScalar) -> Self {
249-
.init(map { $0.scaled(by: scale) })
249+
Self(map { $0.scaled(by: scale) })
250250
}
251251

252252
public mutating func scale(by scale: Element.VectorSpaceScalar) {
@@ -263,11 +263,11 @@ where Element: Differentiable & PointwiseMultiplicative {
263263
fatalError("One is not array-representable")
264264
}
265265

266-
public var reciprocal: Self { .init(map { $0.reciprocal }) }
266+
public var reciprocal: Self { Self(map { $0.reciprocal }) }
267267

268268
public static func .* (lhs: Self, rhs: Self) -> Self {
269269
precondition(lhs.count == rhs.count, "Count mismatch: \(lhs.count) and \(rhs.count)")
270-
return .init(zip(lhs, rhs).map(.*))
270+
return Self(zip(lhs, rhs).map(.*))
271271
}
272272

273273
public static func .*= (lhs: inout Self, rhs: Self) {
@@ -294,94 +294,103 @@ where Wrapped.TangentVector: ElementaryFunctions {
294294
///
295295
/// For real types, if `x` is negative the result is `.nan`. For complex
296296
/// types there is a branch cut on the negative real axis.
297-
public static func sqrt(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
297+
public static func sqrt(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sqrt)) }
298298

299299
/// The cosine of `x`, interpreted as an angle in radians.
300-
public static func cos(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
300+
public static func cos(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.cos)) }
301301

302302
/// The sine of `x`, interpreted as an angle in radians.
303-
public static func sin(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
303+
public static func sin(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sin)) }
304304

305305
/// The tangent of `x`, interpreted as an angle in radians.
306-
public static func tan(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
306+
public static func tan(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.tan)) }
307307

308308
/// The inverse cosine of `x` in radians.
309-
public static func acos(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
309+
public static func acos(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.acos)) }
310310

311311
/// The inverse sine of `x` in radians.
312-
public static func asin(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
312+
public static func asin(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.asin)) }
313313

314314
/// The inverse tangent of `x` in radians.
315-
public static func atan(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
315+
public static func atan(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.atan)) }
316316

317317
/// The hyperbolic cosine of `x`.
318-
public static func cosh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
318+
public static func cosh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.cosh)) }
319319

320320
/// The hyperbolic sine of `x`.
321-
public static func sinh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
321+
public static func sinh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.sinh)) }
322322

323323
/// The hyperbolic tangent of `x`.
324-
public static func tanh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
324+
public static func tanh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.tanh)) }
325325

326326
/// The inverse hyperbolic cosine of `x`.
327-
public static func acosh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
327+
public static func acosh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.acosh)) }
328328

329329
/// The inverse hyperbolic sine of `x`.
330-
public static func asinh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
330+
public static func asinh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.asinh)) }
331331

332332
/// The inverse hyperbolic tangent of `x`.
333-
public static func atanh(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
333+
public static func atanh(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.atanh)) }
334334

335335
/// The exponential function applied to `x`, or `e**x`.
336-
public static func exp(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
336+
public static func exp(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp)) }
337337

338338
/// Two raised to to power `x`.
339-
public static func exp2(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
339+
public static func exp2(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp2)) }
340340

341341
/// Ten raised to to power `x`.
342-
public static func exp10(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
342+
public static func exp10(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.exp10)) }
343343

344344
/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
345-
public static func expm1(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
345+
public static func expm1(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.expm1)) }
346346

347347
/// The natural logarithm of `x`.
348-
public static func log(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
348+
public static func log(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log)) }
349349

350350
/// The base-two logarithm of `x`.
351-
public static func log2(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
351+
public static func log2(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log2)) }
352352

353353
/// The base-ten logarithm of `x`.
354-
public static func log10(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
354+
public static func log10(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log10)) }
355355

356356
/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
357-
public static func log1p(_ x: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
357+
public static func log1p(_ x: Self) -> Self { Self(x.value.map(Wrapped.TangentVector.log1p)) }
358358

359359
/// `exp(y log(x))` computed without loss of intermediate precision.
360360
///
361361
/// For real types, if `x` is negative the result is NaN, even if `y` has
362362
/// an integral value. For complex types, there is a branch cut on the
363363
/// negative real axis.
364-
public static func pow(_ x: Self, _ y: Self) -> Self { .init(x.value.map(Wrapped.TangentVector.sqrt)) }
364+
public static func pow(_ x: Self, _ y: Self) -> Self {
365+
switch (x.value, y.value) {
366+
case let (x?, y?): return Self(Wrapped.TangentVector.pow(x, y))
367+
default: return Self(nil)
368+
}
369+
}
365370

366371
/// `x` raised to the `n`th power.
367372
///
368373
/// The product of `n` copies of `x`.
369-
public static func pow(_ x: Self, _ n: Int) -> Self { .init(x.value.map({ x in Wrapped.TangentVector.pow(x, n)})) }
374+
public static func pow(_ x: Self, _ n: Int) -> Self {
375+
Self(x.value.map({ x in Wrapped.TangentVector.pow(x, n) }))
376+
}
370377

371378
/// The `n`th root of `x`.
372379
///
373380
/// For real types, if `x` is negative and `n` is even, the result is NaN.
374381
/// For complex types, there is a branch cut along the negative real axis.
375-
public static func root(_ x: Self, _ n: Int) -> Self { .init(x.value.map({ x in Wrapped.TangentVector.root(x, n)})) }
382+
public static func root(_ x: Self, _ n: Int) -> Self {
383+
Self(x.value.map({ x in Wrapped.TangentVector.root(x, n) }))
384+
}
376385
}
377386

378387
extension Optional.TangentVector: PointwiseMultiplicative
379388
where Wrapped.TangentVector: PointwiseMultiplicative {
380389
public static var one: Self {
381-
.init(Wrapped.TangentVector.one)
390+
Self(Wrapped.TangentVector.one)
382391
}
383392

384-
public var reciprocal: Self { .init(value.map { $0.reciprocal }) }
393+
public var reciprocal: Self { Self(value.map { $0.reciprocal }) }
385394

386395
public static func .* (lhs: Self, rhs: Self) -> Self {
387396
switch (lhs.value, rhs.value) {
@@ -399,15 +408,17 @@ extension Optional.TangentVector: VectorProtocol
399408
where Wrapped.TangentVector: VectorProtocol {
400409
public typealias VectorSpaceScalar = Wrapped.TangentVector.VectorSpaceScalar
401410

402-
public func adding(_ x: VectorSpaceScalar) -> Self { .init(value.map { $0.adding(x) }) }
411+
public func adding(_ x: VectorSpaceScalar) -> Self { Self(value.map { $0.adding(x) }) }
403412

404413
public mutating func add(_ x: VectorSpaceScalar) { value?.add(x) }
405414

406-
public func subtracting(_ x: VectorSpaceScalar) -> Self { .init(value.map { $0.subtracting(x) }) }
415+
public func subtracting(_ x: VectorSpaceScalar) -> Self { Self(value.map { $0.subtracting(x) }) }
407416

408417
public mutating func subtract(_ x: VectorSpaceScalar) { value?.subtract(x) }
409418

410-
public func scaled(by scale: VectorSpaceScalar) -> Self { .init(value.map { $0.scaled(by: scale) }) }
419+
public func scaled(by scale: VectorSpaceScalar) -> Self {
420+
Self(value.map { $0.scaled(by: scale) })
421+
}
411422

412423
public mutating func scale(by scale: VectorSpaceScalar) {
413424
value?.scale(by: scale)

Tests/TensorFlowTests/TrivialModelTests.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ final class TrivialModelTests: XCTestCase {
5151
return meanSquaredError(predicted: ŷ, expected: y)
5252
}
5353
optimizer.update(&classifier, along: 𝛁model)
54-
dump(𝛁model)
5554
}
5655
}
5756
let ŷ = classifier.inferring(from: x)

0 commit comments

Comments
 (0)