12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ import Datasets
15
16
import Foundation
17
+ import ModelSupport
16
18
import TensorFlow
17
- import Python
18
- import Datasets
19
-
20
- // Import Python modules.
21
- let matplotlib = Python . import ( " matplotlib " )
22
- let np = Python . import ( " numpy " )
23
-
24
- // Use the AGG renderer for saving images to disk.
25
- matplotlib. use ( " Agg " )
26
-
27
- let plt = Python . import ( " matplotlib.pyplot " )
28
19
29
20
// Training hyperparameters.
let epochCount = 10  // number of full passes over the training set
let batchSize = 32  // minibatch size used for both generator and discriminator updates

// Directory the per-epoch image grids are written into (created by the image saver).
let outputFolder = "./output/"

// MNIST image geometry; images are processed as flat vectors of
// imageSize = imageHeight * imageWidth pixels.
let imageHeight = 28
let imageWidth = 28
let imageSize = imageHeight * imageWidth

// Dimensionality of the random noise vector fed to the generator.
let latentSize = 64
35
27
36
- func plotImage( _ image: Tensor < Float > , name: String ) {
37
- // Create figure.
38
- let ax = plt. gca ( )
39
- let array = np. array ( [ image. scalars] )
40
- let pixels = array. reshape ( image. shape)
41
- if !FileManager. default. fileExists ( atPath: outputFolder) {
42
- try ! FileManager . default. createDirectory (
43
- atPath: outputFolder,
44
- withIntermediateDirectories: false ,
45
- attributes: nil )
46
- }
47
- ax. imshow ( pixels, cmap: " gray " )
48
- plt. savefig ( " \( outputFolder) \( name) .png " , dpi: 300 )
49
- plt. close ( )
50
- }
51
-
52
28
// Models
53
29
54
30
struct Generator : Layer {
55
- var dense1 = Dense < Float > ( inputSize: latentSize, outputSize: latentSize * 2 ,
56
- activation: { leakyRelu ( $0) } )
57
- var dense2 = Dense < Float > ( inputSize: latentSize * 2 , outputSize: latentSize * 4 ,
58
- activation: { leakyRelu ( $0) } )
59
- var dense3 = Dense < Float > ( inputSize: latentSize * 4 , outputSize: latentSize * 8 ,
60
- activation: { leakyRelu ( $0) } )
61
- var dense4 = Dense < Float > ( inputSize: latentSize * 8 , outputSize: imageSize,
62
- activation: tanh)
63
-
31
+ var dense1 = Dense < Float > (
32
+ inputSize: latentSize, outputSize: latentSize * 2 ,
33
+ activation: { leakyRelu ( $0) } )
34
+
35
+ var dense2 = Dense < Float > (
36
+ inputSize: latentSize * 2 , outputSize: latentSize * 4 ,
37
+ activation: { leakyRelu ( $0) } )
38
+
39
+ var dense3 = Dense < Float > (
40
+ inputSize: latentSize * 4 , outputSize: latentSize * 8 ,
41
+ activation: { leakyRelu ( $0) } )
42
+
43
+ var dense4 = Dense < Float > (
44
+ inputSize: latentSize * 8 , outputSize: imageSize,
45
+ activation: tanh)
46
+
64
47
var batchnorm1 = BatchNorm < Float > ( featureCount: latentSize * 2 )
65
48
var batchnorm2 = BatchNorm < Float > ( featureCount: latentSize * 4 )
66
49
var batchnorm3 = BatchNorm < Float > ( featureCount: latentSize * 8 )
67
-
50
+
68
51
@differentiable
69
52
func callAsFunction( _ input: Tensor < Float > ) -> Tensor < Float > {
70
53
let x1 = batchnorm1 ( dense1 ( input) )
@@ -75,15 +58,22 @@ struct Generator: Layer {
75
58
}
76
59
77
60
struct Discriminator : Layer {
78
- var dense1 = Dense < Float > ( inputSize: imageSize, outputSize: 256 ,
79
- activation: { leakyRelu ( $0) } )
80
- var dense2 = Dense < Float > ( inputSize: 256 , outputSize: 64 ,
81
- activation: { leakyRelu ( $0) } )
82
- var dense3 = Dense < Float > ( inputSize: 64 , outputSize: 16 ,
83
- activation: { leakyRelu ( $0) } )
84
- var dense4 = Dense < Float > ( inputSize: 16 , outputSize: 1 ,
85
- activation: identity)
86
-
61
+ var dense1 = Dense < Float > (
62
+ inputSize: imageSize, outputSize: 256 ,
63
+ activation: { leakyRelu ( $0) } )
64
+
65
+ var dense2 = Dense < Float > (
66
+ inputSize: 256 , outputSize: 64 ,
67
+ activation: { leakyRelu ( $0) } )
68
+
69
+ var dense3 = Dense < Float > (
70
+ inputSize: 64 , outputSize: 16 ,
71
+ activation: { leakyRelu ( $0) } )
72
+
73
+ var dense4 = Dense < Float > (
74
+ inputSize: 16 , outputSize: 1 ,
75
+ activation: identity)
76
+
87
77
@differentiable
88
78
func callAsFunction( _ input: Tensor < Float > ) -> Tensor < Float > {
89
79
input. sequenced ( through: dense1, dense2, dense3, dense4)
@@ -94,16 +84,19 @@ struct Discriminator: Layer {
94
84
95
85
// Generator loss: binary cross-entropy that pushes the discriminator's
// logits for generated images toward the "real" label (all ones).
@differentiable
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
    let realLabels = Tensor<Float>(ones: fakeLogits.shape)
    return sigmoidCrossEntropy(logits: fakeLogits, labels: realLabels)
}
100
91
101
92
// Discriminator loss: sum of the binary cross-entropy on real images
// (target label 1) and on generated images (target label 0).
@differentiable
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
    let lossOnReal = sigmoidCrossEntropy(
        logits: realLogits,
        labels: Tensor<Float>(ones: realLogits.shape))
    let lossOnFake = sigmoidCrossEntropy(
        logits: fakeLogits,
        labels: Tensor<Float>(zeros: fakeLogits.shape))
    return lossOnReal + lossOnFake
}
109
102
@@ -123,18 +116,28 @@ let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
123
116
// Noise vectors and plot function for testing
124
117
let testImageGridSize = 4
125
118
let testVector = sampleVector ( size: testImageGridSize * testImageGridSize)
126
- func plotTestImage( _ testImage: Tensor < Float > , name: String ) {
127
- var gridImage = testImage. reshaped ( to: [ testImageGridSize, testImageGridSize,
128
- imageHeight, imageWidth] )
119
+
120
/// Tiles a batch of generated images into a single
/// `testImageGridSize` x `testImageGridSize` grid and writes it to
/// `outputFolder` as a PNG called `name`.
///
/// - Parameters:
///   - testImage: Generator output holding
///     `testImageGridSize * testImageGridSize` flattened images in [-1, 1].
///   - name: Base file name for the saved image.
/// - Throws: Any error raised by `saveImage` while writing the file.
func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
    // Split the flat batch into a (rows, cols, height, width) volume.
    var grid = testImage.reshaped(
        to: [testImageGridSize, testImageGridSize, imageHeight, imageWidth])
    // One-pixel white border around each tile.
    grid = grid.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
    // Interleave grid rows with image rows so the tiles form one contiguous image.
    grid = grid.transposed(withPermutations: [0, 2, 1, 3])
    grid = grid.reshaped(
        to: [
            (imageHeight + 2) * testImageGridSize,
            (imageWidth + 2) * testImageGridSize,
        ])
    // Map generator output from [-1, 1] into the [0, 1] range expected by saveImage.
    grid = (grid + 1) / 2
    try saveImage(
        grid, size: (grid.shape[0], grid.shape[1]), directory: outputFolder,
        name: name)
}
139
142
140
143
print ( " Start training... " )
@@ -147,20 +150,20 @@ for epoch in 1...epochCount {
147
150
// Perform alternative update.
148
151
// Update generator.
149
152
let vec1 = sampleVector ( size: batchSize)
150
-
153
+
151
154
let 𝛁generator = generator. gradient { generator - > Tensor< Float> in
152
155
let fakeImages = generator ( vec1)
153
156
let fakeLogits = discriminator ( fakeImages)
154
157
let loss = generatorLoss ( fakeLogits: fakeLogits)
155
158
return loss
156
159
}
157
160
optG. update ( & generator, along: 𝛁generator)
158
-
161
+
159
162
// Update discriminator.
160
163
let realImages = dataset. trainingImages. minibatch ( at: i, batchSize: batchSize)
161
164
let vec2 = sampleVector ( size: batchSize)
162
165
let fakeImages = generator ( vec2)
163
-
166
+
164
167
let 𝛁discriminator = discriminator. gradient { discriminator - > Tensor< Float> in
165
168
let realLogits = discriminator ( realImages)
166
169
let fakeLogits = discriminator ( fakeImages)
@@ -169,12 +172,17 @@ for epoch in 1...epochCount {
169
172
}
170
173
optD. update ( & discriminator, along: 𝛁discriminator)
171
174
}
172
-
175
+
173
176
// Start inference phase.
174
177
Context . local. learningPhase = . inference
175
178
let testImage = generator ( testVector)
176
- plotTestImage ( testImage, name: " epoch- \( epoch) -output " )
177
-
179
+
180
+ do {
181
+ try saveImageGrid ( testImage, name: " epoch- \( epoch) -output " )
182
+ } catch {
183
+ print ( " Could not save image grid with error: \( error) " )
184
+ }
185
+
178
186
let lossG = generatorLoss ( fakeLogits: testImage)
179
187
print ( " [Epoch: \( epoch) ] Loss-G: \( lossG) " )
180
188
}
0 commit comments