Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 7d1ab04

Browse files
authored
Minor code formatting pass on Gym models. (#132)
1 parent 9c9b460 commit 7d1ab04

File tree

2 files changed

+21
-27
lines changed

2 files changed

+21
-27
lines changed

Gym/CartPole/main.swift

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ import TensorFlow
1818
let np = Python.import("numpy")
1919
let gym = Python.import("gym")
2020

21-
/// Model parameters and hyper parameters.
21+
/// Model parameters and hyperparameters.
2222
let hiddenSize = 128
2323
let batchSize = 16
2424
/// Controls the amount of good/long episodes to retain for training.
2525
let percentile = 70
2626

27-
// Force unwrapping with ! does not provide source location when unwrapping
28-
// nil, so we instead make a util function for debuggability.
27+
// Force unwrapping with `!` does not provide source location when unwrapping `nil`, so we instead
28+
// make a utility function for debuggability.
2929
fileprivate extension Optional {
3030
func unwrapped(file: StaticString = #file, line: UInt = #line) -> Wrapped {
3131
guard let unwrapped = self else {
@@ -43,11 +43,8 @@ struct Net: Layer {
4343
var l1, l2: Dense<Float>
4444

4545
init(observationSize: Int, hiddenSize: Int, actionCount: Int) {
46-
self.l1 = Dense<Float>(
47-
inputSize: observationSize, outputSize: hiddenSize, activation: relu)
48-
49-
self.l2 = Dense<Float>(
50-
inputSize: hiddenSize, outputSize: actionCount, activation: { $0 })
46+
l1 = Dense<Float>(inputSize: observationSize, outputSize: hiddenSize, activation: relu)
47+
l2 = Dense<Float>(inputSize: hiddenSize, outputSize: actionCount)
5148
}
5249

5350
@differentiable
@@ -69,15 +66,11 @@ struct Episode {
6966
let reward: Float
7067
}
7168

72-
/// Filter out bad/short episodes before we feed them as neural net training
73-
/// data.
69+
/// Filtering out bad/short episodes before we feed them as neural net training data.
7470
func filteringBatch(
7571
episodes: [Episode],
7672
actionCount: Int
77-
) -> (input: Tensor<Float>,
78-
target: Tensor<Float>,
79-
episodeCount: Int,
80-
meanReward: Float) {
73+
) -> (input: Tensor<Float>, target: Tensor<Float>, episodeCount: Int, meanReward: Float) {
8174
let rewards = episodes.map { $0.reward }
8275
let rewardBound = Float(np.percentile(rewards, percentile))!
8376
print("rewardBound = \(rewardBound)")
@@ -174,7 +167,7 @@ var net = Net(observationSize: Int(observationSize), hiddenSize: hiddenSize, act
174167
// let optimizer = SGD<Net, Float>(learningRate: 0.1, momentum: 0.9)
175168
let optimizer = Adam(for: net, learningRate: 0.01)
176169
var batchIndex = 0
177-
Context.local.learningPhase = .training
170+
178171
while true {
179172
print("Processing mini batch \(batchIndex)")
180173
batchIndex += 1
@@ -183,11 +176,13 @@ while true {
183176
let (input, target, episodeCount, meanReward) = filteringBatch(
184177
episodes: episodes, actionCount: actionCount)
185178

186-
let gradients = gradient(at: net) { model -> Tensor<Float> in
187-
let logits = model(input)
188-
let loss = softmaxCrossEntropy(logits: logits, probabilities: target)
189-
print("loss is \(loss)")
190-
return loss
179+
let gradients = withLearningPhase(.training) {
180+
net.gradient { net -> Tensor<Float> in
181+
let logits = net(input)
182+
let loss = softmaxCrossEntropy(logits: logits, probabilities: target)
183+
print("loss is \(loss)")
184+
return loss
185+
}
191186
}
192187
optimizer.update(&net.allDifferentiableVariables, along: gradients)
193188

Gym/FrozenLake/main.swift

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ import TensorFlow
1818
let np = Python.import("numpy")
1919
let gym = Python.import("gym")
2020

21-
// Solves the FrozenLake RL problem via Q-learning. This model does not use a
22-
// neural net, and instead demonstrates Swift host-side numeric processing as
23-
// well as Python integration.
21+
// Solves the FrozenLake RL problem via Q-learning. This model does not use a neural net, and
22+
// instead demonstrates Swift host-side numeric processing as well as Python integration.
2423

2524
let discountRate: Float = 0.9
2625
let learningRate: Float = 0.2
@@ -29,8 +28,8 @@ let testEpisodeCount = 20
2928
typealias State = Int
3029
typealias Action = Int
3130

32-
// Force unwrapping with ! does not provide source location when unwrapping
33-
// nil, so we instead make a util function for debuggability.
31+
// Force unwrapping with `!` does not provide source location when unwrapping `nil`, so we instead
32+
// make a utility function for debuggability.
3433
fileprivate extension Optional {
3534
func unwrapped(file: StaticString = #file, line: UInt = #line) -> Wrapped {
3635
guard let unwrapped = self else {
@@ -40,8 +39,8 @@ fileprivate extension Optional {
4039
}
4140
}
4241

43-
// This struct is defined so that `StateAction` can be a dictionary key
44-
// type. Swift tuples cannot conform to `Hashable`.
42+
// This struct is defined so that `StateAction` can be a dictionary key type. Swift tuples cannot
43+
// conform to `Hashable`.
4544
struct StateAction: Equatable, Hashable {
4645
let state: State
4746
let action: Action

0 commit comments

Comments (0)