Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Fix up models so they can run. #143

Merged
merged 3 commits into from
May 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@

.build
*.xcodeproj
*.png
.DS_Store
cifar-10-batches-py/
2 changes: 1 addition & 1 deletion Autoencoder/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func plot(image: [Float], name: String) {

/// Reads a file into an array of bytes.
func readFile(_ filename: String) -> [UInt8] {
let possibleFolders = [".", "Resources"]
let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
for folder in possibleFolders {
let parent = URL(fileURLWithPath: folder)
let filePath = parent.appendingPathComponent(filename).path
Expand Down
16 changes: 13 additions & 3 deletions MNIST/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ import TensorFlow

/// Reads a file into an array of bytes.
func readFile(_ path: String) -> [UInt8] {
let url = URL(fileURLWithPath: path)
let data = try! Data(contentsOf: url, options: [])
return [UInt8](data)
let possibleFolders = [".", "MNIST"]
for folder in possibleFolders {
let parent = URL(fileURLWithPath: folder)
let filePath = parent.appendingPathComponent(path)
guard FileManager.default.fileExists(atPath: filePath.path) else {
continue
}
let data = try! Data(contentsOf: filePath, options: [])
return [UInt8](data)
}
fatalError("Filename not found: \(path)")
}

/// Reads MNIST images and labels from specified file paths.
Expand Down Expand Up @@ -76,6 +84,8 @@ let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
var classifier = Classifier()
let optimizer = RMSProp(for: classifier)

print("Beginning training...")

// The training loop.
for epoch in 1...epochCount {
var correctGuessCount = 0
Expand Down