Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 84d741a

Browse files
authored
Fix up models so they can run. (#143)
As part of the migration to a single SwiftPM-based package for the models repository, the working directory where people run `swift run` has changed. As a result, a number of relative-path lookups were broken when actually running models. This commit fixes them up so they all properly run.
1 parent 51393cf commit 84d741a

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33

44
.build
55
*.xcodeproj
6+
*.png
67
.DS_Store
8+
cifar-10-batches-py/

Autoencoder/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func plot(image: [Float], name: String) {
4747

4848
/// Reads a file into an array of bytes.
4949
func readFile(_ filename: String) -> [UInt8] {
50-
let possibleFolders = [".", "Resources"]
50+
let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
5151
for folder in possibleFolders {
5252
let parent = URL(fileURLWithPath: folder)
5353
let filePath = parent.appendingPathComponent(filename).path

MNIST/main.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@ import TensorFlow
1717

1818
/// Reads a file into an array of bytes.
1919
func readFile(_ path: String) -> [UInt8] {
20-
let url = URL(fileURLWithPath: path)
21-
let data = try! Data(contentsOf: url, options: [])
22-
return [UInt8](data)
20+
let possibleFolders = [".", "MNIST"]
21+
for folder in possibleFolders {
22+
let parent = URL(fileURLWithPath: folder)
23+
let filePath = parent.appendingPathComponent(path)
24+
guard FileManager.default.fileExists(atPath: filePath.path) else {
25+
continue
26+
}
27+
let data = try! Data(contentsOf: filePath, options: [])
28+
return [UInt8](data)
29+
}
30+
fatalError("Filename not found: \(path)")
2331
}
2432

2533
/// Reads MNIST images and labels from specified file paths.
@@ -76,6 +84,8 @@ let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
7684
var classifier = Classifier()
7785
let optimizer = RMSProp(for: classifier)
7886

87+
print("Beginning training...")
88+
7989
// The training loop.
8090
for epoch in 1...epochCount {
8191
var correctGuessCount = 0

0 commit comments

Comments
 (0)