This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 9e3ac12

brettkoonce authored and rxwei committed
convert catch/gym examples to new layers api (#112)
1 parent 643b182

2 files changed, +13 -7 lines changed

Catch/catch.swift

Lines changed: 6 additions & 3 deletions

@@ -44,13 +44,16 @@ protocol Agent: AnyObject {
 }
 
 struct Model: Layer {
+    typealias Input = Tensor<Float>
+    typealias Output = Tensor<Float>
+
     var layer1 = Dense<Float>(inputSize: 3, outputSize: 50, activation: sigmoid,
                               generator: &rng)
     var layer2 = Dense<Float>(inputSize: 50, outputSize: 3, activation: sigmoid,
                               generator: &rng)
 
     @differentiable
-    func applied(to input: Tensor<Float>) -> Tensor<Float> {
+    func call(_ input: Input) -> Output {
         return input.sequenced(through: layer1, layer2)
     }
 }
@@ -59,11 +62,11 @@ class CatchAgent: Agent {
     typealias Action = CatchAction
 
     var model: Model = Model()
-    let optimizer: Adam<Model, Float>
+    let optimizer: Adam<Model>
     var previousReward: Reward
 
     init(initialReward: Reward, learningRate: Float) {
-        optimizer = Adam(learningRate: learningRate)
+        optimizer = Adam(for: model, learningRate: learningRate)
         previousReward = initialReward
     }
 }
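
Under the new layers API, a Layer declares its Input and Output types and implements call(_:), which makes the layer instance directly callable, and an optimizer is created for a specific model so the separate scalar type parameter disappears. A minimal standalone sketch of the pattern, assuming a contemporaneous Swift for TensorFlow toolchain; the TinyModel name and layer sizes are illustrative, not from this repository:

import TensorFlow

// Sketch of the new layers API adopted by this commit. The `TinyModel`
// name and the layer shapes are illustrative only.
struct TinyModel: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Tensor<Float>

    var layer1 = Dense<Float>(inputSize: 3, outputSize: 50, activation: sigmoid)
    var layer2 = Dense<Float>(inputSize: 50, outputSize: 3, activation: sigmoid)

    @differentiable
    func call(_ input: Input) -> Output {
        // Feed the input through layer1, then layer2, in order.
        return input.sequenced(through: layer1, layer2)
    }
}

var model = TinyModel()
// The optimizer is now created for a model instance, so Adam<TinyModel>
// is inferred; no second generic parameter for the scalar type.
let optimizer = Adam(for: model, learningRate: 0.01)
// call(_:) lets the layer be applied with plain function syntax.
let output = model(Tensor<Float>(randomNormal: [1, 3]))

This is also why the call sites in the next file change from net.applied(to: x) to net(x).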

Gym/CartPole.swift

Lines changed: 7 additions & 4 deletions

@@ -37,6 +37,9 @@ fileprivate extension Optional {
 
 /// A simple two layer dense net.
 struct Net: Layer {
+    typealias Input = Tensor<Float>
+    typealias Output = Tensor<Float>
+
     var l1, l2: Dense<Float>
 
     init(observationSize: Int, hiddenSize: Int, actionCount: Int) {
@@ -48,7 +51,7 @@ struct Net: Layer {
     }
 
     @differentiable
-    func applied(to input: Tensor<Float>) -> Tensor<Float> {
+    func call(_ input: Input) -> Output {
         return input.sequenced(through: l1, l2)
     }
 }
@@ -132,7 +135,7 @@ func nextBatch(
     while true {
         let observationPython = Tensor<Double>(numpy: observationNumpy).unwrapped()
         let actionProbabilities =
-            softmax(net.applied(to: Tensor(observationPython).reshaped(to: [1, 4])))
+            softmax(net(Tensor(observationPython).reshaped(to: [1, 4])))
         let actionProbabilitiesPython = actionProbabilities[0].makeNumpyArray()
         let len = Python.len(actionProbabilitiesPython)
         assert(actionCount == Int(Python.len(actionProbabilitiesPython)))
@@ -169,7 +172,7 @@ let actionCount = Int(env.action_space.n).unwrapped()
 var net = Net(observationSize: Int(observationSize), hiddenSize: hiddenSize, actionCount: actionCount)
 // SGD optimizer reaches convergence with ~125 mini batches, while Adam uses ~25.
 // let optimizer = SGD<Net, Float>(learningRate: 0.1, momentum: 0.9)
-let optimizer = Adam<Net, Float>(learningRate: 0.01)
+let optimizer = Adam(for: net, learningRate: 0.01)
 var batchIndex = 0
 Context.local.learningPhase = .training
 while true {
@@ -181,7 +184,7 @@ while true {
         episodes: episodes, actionCount: actionCount)
 
     let gradients = gradient(at: net) { model -> Tensor<Float> in
-        let logits = model.applied(to: input)
+        let logits = model(input)
         let loss = softmaxCrossEntropy(logits: logits, probabilities: target)
         print("loss is \(loss)")
         return loss
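
The same change carries into the training loop: the model is applied directly inside the gradient closure, and the Adam instance built with Adam(for:) performs the update. A sketch of one complete step under the same API vintage; the update(_:along:) call through allDifferentiableVariables is an assumption based on Swift for TensorFlow of that era, and input/target stand for a prepared mini-batch:

// One training step, assuming `net`, `optimizer`, `input`, and `target`
// are set up as in the diff above.
let gradients = gradient(at: net) { model -> Tensor<Float> in
    let logits = model(input)  // direct application instead of applied(to:)
    return softmaxCrossEntropy(logits: logits, probabilities: target)
}
// Assumed contemporaneous optimizer API: move the model's parameters
// along the computed gradients.
optimizer.update(&net.allDifferentiableVariables, along: gradients)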
