// See the License for the specific language governing permissions and
// limitations under the License.

+import Datasets
import Foundation
+import ModelSupport
import TensorFlow
-import Python

-// Import Python modules
-let matplotlib = Python.import("matplotlib")
-let np = Python.import("numpy")
-let plt = Python.import("matplotlib.pyplot")
-
-// Turn off using display on server / linux
-matplotlib.use("Agg")
-
-// Some globals
let epochCount = 10
let batchSize = 100
-let outputFolder = "./output/"
-let imageHeight = 28, imageWidth = 28
-
-func plot(image: [Float], name: String) {
-    // Create figure
-    let ax = plt.gca()
-    let array = np.array([image])
-    let pixels = array.reshape([imageHeight, imageWidth])
-    if !FileManager.default.fileExists(atPath: outputFolder) {
-        try! FileManager.default.createDirectory(atPath: outputFolder,
-                                                 withIntermediateDirectories: false,
-                                                 attributes: nil)
-    }
-    ax.imshow(pixels, cmap: "gray")
-    plt.savefig("\(outputFolder)\(name).png", dpi: 300)
-    plt.close()
-}
+let imageHeight = 28
+let imageWidth = 28

-/// Reads a file into an array of bytes.
-func readFile(_ filename: String) -> [UInt8] {
-    let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
-    for folder in possibleFolders {
-        let parent = URL(fileURLWithPath: folder)
-        let filePath = parent.appendingPathComponent(filename).path
-        guard FileManager.default.fileExists(atPath: filePath) else {
-            continue
-        }
-        let d = Python.open(filePath, "rb").read()
-        return Array(numpy: np.frombuffer(d, dtype: np.uint8))!
-    }
-    print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
-    exit(-1)
-}
-
-/// Reads MNIST images and labels from specified file paths.
-func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
-                                                           labels: Tensor<Int32>) {
-    print("Reading data.")
-    let images = readFile(imagesFile).dropFirst(16).map { Float($0) }
-    let labels = readFile(labelsFile).dropFirst(8).map { Int32($0) }
-    let rowCount = labels.count
-
-    print("Constructing data tensors.")
-    return (
-        images: Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0,
-        labels: Tensor(labels)
-    )
-}
-
-/// An autoencoder.
-struct Autoencoder: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
-                                activation: relu)
-    var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
-    var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
-    var encoder4 = Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
-
-    var decoder1 = Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
-    var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
-    var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
-    var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
-                                activation: tanh)
-
-    @differentiable
-    func call(_ input: Input) -> Output {
-        let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
-        return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
-    }
-}
-
-// MNIST data logic
-func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
-    let start = index * batchSize
-    return x[start..<start+batchSize]
+let outputFolder = "./output/"
+let dataset = MNIST(batchSize: batchSize, flattening: true)
+// An autoencoder.
+var autoencoder = Sequential {
+    // The encoder.
+    Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128, activation: relu)
+    Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
+    Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
+    Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
+    // The decoder.
+    Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
+    Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
+    Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
+    Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth, activation: tanh)
}
-
-let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
-                                        labelsFile: "train-labels-idx1-ubyte")
-let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
-
-var autoencoder = Autoencoder()
let optimizer = RMSProp(for: autoencoder)

// Training loop
for epoch in 1...epochCount {
-    let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: images[epoch].scalars)
+    let sampleImage = Tensor(
+        shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
    let testImage = autoencoder(sampleImage)

-    plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
-    plot(image: testImage.scalars, name: "epoch-\(epoch)-output")
+    do {
+        try saveImage(
+            sampleImage, size: (imageWidth, imageHeight), directory: outputFolder,
+            name: "epoch-\(epoch)-input")
+        try saveImage(
+            testImage, size: (imageWidth, imageHeight), directory: outputFolder,
+            name: "epoch-\(epoch)-output")
+    } catch {
+        print("Could not save image with error: \(error)")
+    }

    let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
    print("[Epoch: \(epoch)] Loss: \(sampleLoss)")

-    for i in 0 ..< Int(labels.shape[0]) / batchSize {
-        let x = minibatch(in: images, at: i)
+    for i in 0 ..< dataset.trainingSize / batchSize {
+        let x = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)

        let 𝛁model = autoencoder.gradient { autoencoder -> Tensor<Float> in
            let image = autoencoder(x)
            return meanSquaredError(predicted: image, expected: x)
        }

-        optimizer.update(&autoencoder.allDifferentiableVariables, along: 𝛁model)
+        optimizer.update(&autoencoder, along: 𝛁model)
    }
}
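
Note: a minimal sanity-check sketch for the refactored model, not part of this commit. It assumes the definitions from the new file above (imageHeight, imageWidth, and the Sequential autoencoder) and simply feeds one random flattened image through the model to confirm the reconstruction keeps the input shape.

// Illustrative only: push a random flattened "image" through the autoencoder
// defined above and verify the reconstruction has the same shape as the input.
let probe = Tensor<Float>(randomNormal: [1, imageHeight * imageWidth])
let reconstruction = autoencoder(probe)
assert(reconstruction.shape == TensorShape([1, imageHeight * imageWidth]))
print("Reconstruction shape: \(reconstruction.shape)")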