Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 21c694d

Browse files
Shashi456rxwei
authored andcommitted
Convert Autoencoder and Catch to use Sequential (#203)
Partially fixes #202.
1 parent 036f014 commit 21c694d

File tree

2 files changed

+21
-45
lines changed

2 files changed

+21
-45
lines changed

Autoencoder/main.swift

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,34 +23,20 @@ let imageHeight = 28
2323
let imageWidth = 28
2424

2525
let outputFolder = "./output/"
26-
27-
/// An autoencoder.
28-
struct Autoencoder: Layer {
29-
var encoder1 = Dense<Float>(
30-
inputSize: imageHeight * imageWidth, outputSize: 128,
31-
activation: relu)
32-
33-
var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
34-
var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
35-
var encoder4 = Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
36-
37-
var decoder1 = Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
38-
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
39-
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
40-
41-
var decoder4 = Dense<Float>(
42-
inputSize: 128, outputSize: imageHeight * imageWidth,
43-
activation: tanh)
44-
45-
@differentiable
46-
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
47-
let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
48-
return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
49-
}
50-
}
51-
5226
let dataset = MNIST(batchSize: batchSize, flattening: true)
53-
var autoencoder = Autoencoder()
27+
// An autoencoder.
28+
var autoencoder = Sequential {
29+
// The encoder.
30+
Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128, activation: relu)
31+
Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
32+
Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
33+
Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
34+
// The decoder.
35+
Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
36+
Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
37+
Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
38+
Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth, activation: tanh)
39+
}
5440
let optimizer = RMSProp(for: autoencoder)
5541

5642
// Training loop

Catch/main.swift

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,30 +43,20 @@ protocol Agent: AnyObject {
4343
func step(observation: Observation, reward: Reward) -> Action
4444
}
4545

46-
struct Model: Layer {
47-
typealias Input = Tensor<Float>
48-
typealias Output = Tensor<Float>
49-
50-
var layer1 = Dense<Float>(inputSize: 3, outputSize: 50, activation: sigmoid,
51-
generator: &rng)
52-
var layer2 = Dense<Float>(inputSize: 50, outputSize: 3, activation: sigmoid,
53-
generator: &rng)
54-
55-
@differentiable
56-
func callAsFunction(_ input: Input) -> Output {
57-
return input.sequenced(through: layer1, layer2)
58-
}
59-
}
60-
6146
class CatchAgent: Agent {
6247
typealias Action = CatchAction
6348

64-
var model: Model = Model()
65-
let optimizer: Adam<Model>
49+
var model = Sequential {
50+
Dense<Float>(inputSize: 3, outputSize: 50, activation: sigmoid, generator: &rng)
51+
Dense<Float>(inputSize: 50, outputSize: 3, activation: sigmoid, generator: &rng)
52+
}
53+
54+
var learningRate: Float
55+
lazy var optimizer = Adam(for: self.model, learningRate: self.learningRate)
6656
var previousReward: Reward
6757

6858
init(initialReward: Reward, learningRate: Float) {
69-
optimizer = Adam(for: model, learningRate: learningRate)
59+
self.learningRate = learningRate
7060
previousReward = initialReward
7161
}
7262
}

0 commit comments

Comments
 (0)