@@ -37,6 +37,9 @@ fileprivate extension Optional {
37
37
38
38
/// A simple two layer dense net.
39
39
struct Net : Layer {
40
+ typealias Input = Tensor < Float >
41
+ typealias Output = Tensor < Float >
42
+
40
43
var l1 , l2 : Dense < Float >
41
44
42
45
init ( observationSize: Int , hiddenSize: Int , actionCount: Int ) {
@@ -48,7 +51,7 @@ struct Net: Layer {
48
51
}
49
52
50
53
@differentiable
51
- func applied ( to input: Tensor < Float > ) -> Tensor < Float > {
54
+ func call ( _ input: Input ) -> Output {
52
55
return input. sequenced ( through: l1, l2)
53
56
}
54
57
}
@@ -132,7 +135,7 @@ func nextBatch(
132
135
while true {
133
136
let observationPython = Tensor < Double > ( numpy: observationNumpy) . unwrapped ( )
134
137
let actionProbabilities =
135
- softmax ( net. applied ( to : Tensor ( observationPython) . reshaped ( to: [ 1 , 4 ] ) ) )
138
+ softmax ( net ( Tensor ( observationPython) . reshaped ( to: [ 1 , 4 ] ) ) )
136
139
let actionProbabilitiesPython = actionProbabilities [ 0 ] . makeNumpyArray ( )
137
140
let len = Python . len ( actionProbabilitiesPython)
138
141
assert ( actionCount == Int ( Python . len ( actionProbabilitiesPython) ) )
@@ -169,7 +172,7 @@ let actionCount = Int(env.action_space.n).unwrapped()
169
172
var net = Net ( observationSize: Int ( observationSize) , hiddenSize: hiddenSize, actionCount: actionCount)
170
173
// SGD optimizer reaches convergence with ~125 mini batches, while Adam uses ~25.
171
174
// let optimizer = SGD<Net, Float>(learningRate: 0.1, momentum: 0.9)
172
- let optimizer = Adam < Net , Float > ( learningRate: 0.01 )
175
+ let optimizer = Adam ( for : net , learningRate: 0.01 )
173
176
var batchIndex = 0
174
177
Context . local. learningPhase = . training
175
178
while true {
@@ -181,7 +184,7 @@ while true {
181
184
episodes: episodes, actionCount: actionCount)
182
185
183
186
let gradients = gradient ( at: net) { model -> Tensor < Float > in
184
- let logits = model. applied ( to : input)
187
+ let logits = model ( input)
185
188
let loss = softmaxCrossEntropy ( logits: logits, probabilities: target)
186
189
print ( " loss is \( loss) " )
187
190
return loss
0 commit comments