Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 41c4a7a

Browse files
spencerkohan authored and rxwei committed
Download MNIST dataset from remote URL (#215)
This PR is intended to fix the second item listed in [this issue](#206): > it currently relies on hardcoded paths, which will break if you take the dataset outside of that project and try to use it in something like Colab. In this PR, the datasets are instead downloaded from a remote URL, so they no longer depend on the current working directory from which the MNIST dataset is instantiated. The implementation in this PR brings the MNIST dataset in line with how the CIFAR-10 dataset is loaded.
1 parent 3b1f373 commit 41c4a7a

File tree

6 files changed

+150
-36
lines changed

6 files changed

+150
-36
lines changed

Datasets/DatasetUtilities.swift

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
#if canImport(FoundationNetworking)
18+
import FoundationNetworking
19+
#endif
20+
21+
/// Helpers for loading dataset files from local storage, downloading and
/// unpacking them from a remote location first when they are not present.
public struct DatasetUtilities {
    /// File URL of the process's current working directory, captured the first
    /// time this property is accessed.
    ///
    /// NOTE: the identifier is misspelled ("curent"); it is kept as-is because
    /// it is public API that callers reference by name.
    public static let curentWorkingDirectoryURL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath)

    /// Returns the contents of `filename` inside `localStorageDirectory`.
    ///
    /// If the file is not present locally, `<remoteRoot>/<filename>.gz` is
    /// downloaded into `localStorageDirectory` and decompressed with `gunzip`
    /// before the file is read.
    ///
    /// - Parameters:
    ///   - filename: Name of the uncompressed resource file.
    ///   - remoteRoot: URL under which `<filename>.gz` is expected to live.
    ///   - localStorageDirectory: Directory holding (or receiving) the file.
    ///     Defaults to the current working directory.
    /// - Returns: The raw bytes of the uncompressed resource.
    ///
    /// Any download, extraction, or read failure terminates the process via
    /// `fatalError`.
    public static func fetchResource(
        filename: String,
        remoteRoot: URL,
        localStorageDirectory: URL = curentWorkingDirectoryURL
    ) -> Data {
        print("Loading resource: \(filename)")

        let resource = ResourceDefinition(
            filename: filename,
            remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        let localURL = resource.localURL

        if !FileManager.default.fileExists(atPath: localURL.path) {
            print(
                "File does not exist locally at expected path: \(localURL.path) and must be fetched"
            )
            fetchFromRemoteAndSave(resource)
        }

        do {
            print("Loading local data at: \(localURL.path)")
            let data = try Data(contentsOf: localURL)
            print("Succesfully loaded resource: \(filename)")
            return data
        } catch {
            // Include the underlying error so the cause (missing file after a
            // failed extraction, permissions, ...) is visible to the user.
            fatalError(
                "Failed to load contents of resource: \(localURL.path) with error: \(error)")
        }
    }

    /// Describes where a resource lives remotely and where it is stored locally.
    struct ResourceDefinition {
        /// Name of the uncompressed resource file.
        let filename: String
        /// Remote location under which the compressed resource is published.
        let remoteRoot: URL
        /// Local directory in which the archive and extracted file are kept.
        let localStorageDirectory: URL

        /// Local location of the uncompressed file.
        var localURL: URL {
            localStorageDirectory.appendingPathComponent(filename)
        }

        /// Remote location of the gzip archive (`<remoteRoot>/<filename>.gz`).
        var remoteURL: URL {
            remoteRoot.appendingPathComponent(filename).appendingPathExtension("gz")
        }

        /// Local location of the downloaded gzip archive (`<localURL>.gz`).
        var archiveURL: URL {
            localURL.appendingPathExtension("gz")
        }
    }

    /// Downloads the archive for `resource`, stores it next to the expected
    /// local file, and extracts it.
    ///
    /// The local storage directory is created when missing so that a fresh or
    /// custom directory does not make the archive write fail.
    static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
        let remoteLocation = resource.remoteURL
        let archiveLocation = resource.archiveURL

        do {
            print("Fetching URL: \(remoteLocation)...")
            let archiveData = try Data(contentsOf: remoteLocation)
            // No-op when the directory already exists.
            try FileManager.default.createDirectory(
                at: resource.localStorageDirectory, withIntermediateDirectories: true)
            print("Writing fetched archive to: \(archiveLocation.path)")
            try archiveData.write(to: archiveLocation)
        } catch {
            fatalError("Failed to fetch and save resource with error: \(error)")
        }
        print("Archive saved to: \(archiveLocation.path)")

        extractArchive(for: resource)
    }

    /// Decompresses the local archive of `resource` in place by running
    /// `gunzip` on it as a child process.
    ///
    /// Terminates the process via `fatalError` when no `gunzip` executable is
    /// found, when it cannot be launched, or when it exits with a non-zero
    /// status.
    static func extractArchive(for resource: ResourceDefinition) {
        print("Extracting archive...")

        let archivePath = resource.archiveURL.path

        // The platform's usual location is tried first; on non-macOS systems
        // `/usr/bin/gunzip` is accepted as a fallback for distributions that do
        // not provide `/bin/gunzip`.
        #if os(macOS)
            let gunzipCandidates = ["/usr/bin/gunzip"]
        #else
            let gunzipCandidates = ["/bin/gunzip", "/usr/bin/gunzip"]
        #endif

        guard
            let gunzipLocation = gunzipCandidates.first(where: {
                FileManager.default.isExecutableFile(atPath: $0)
            })
        else {
            fatalError(
                "Failed to extract \(archivePath): no gunzip executable found in \(gunzipCandidates)"
            )
        }

        let task = Process()
        task.executableURL = URL(fileURLWithPath: gunzipLocation)
        task.arguments = [archivePath]
        do {
            try task.run()
            task.waitUntilExit()
        } catch {
            fatalError("Failed to extract \(archivePath) with error: \(error)")
        }

        // `run()` only throws when the process cannot be launched; a failed
        // decompression is reported through the exit status instead.
        guard task.terminationStatus == 0 else {
            fatalError(
                "Failed to extract \(archivePath): gunzip exited with status \(task.terminationStatus)"
            )
        }
    }
}

Datasets/MNIST/MNIST.swift

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,27 @@ public struct MNIST {
3131

3232
public let batchSize: Int
3333

34-
public init(batchSize: Int, flattening: Bool = false, normalizing: Bool = false) {
34+
public init(
35+
batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
36+
localStorageDirectory: URL = DatasetUtilities.curentWorkingDirectoryURL
37+
) {
3538
self.batchSize = batchSize
3639

37-
let (trainingImages, trainingLabels) = readMNIST(
38-
imagesFile: "train-images-idx3-ubyte",
39-
labelsFile: "train-labels-idx1-ubyte",
40+
let (trainingImages, trainingLabels) = fetchDataset(
41+
localStorageDirectory: localStorageDirectory,
42+
imagesFilename: "train-images-idx3-ubyte",
43+
labelsFilename: "train-labels-idx1-ubyte",
4044
flattening: flattening,
4145
normalizing: normalizing)
46+
4247
self.trainingImages = trainingImages
4348
self.trainingLabels = trainingLabels
4449
self.trainingSize = Int(trainingLabels.shape[0])
4550

46-
let (testImages, testLabels) = readMNIST(
47-
imagesFile: "t10k-images-idx3-ubyte",
48-
labelsFile: "t10k-labels-idx1-ubyte",
51+
let (testImages, testLabels) = fetchDataset(
52+
localStorageDirectory: localStorageDirectory,
53+
imagesFilename: "t10k-images-idx3-ubyte",
54+
labelsFilename: "t10k-labels-idx1-ubyte",
4955
flattening: flattening,
5056
normalizing: normalizing)
5157
self.testImages = testImages
@@ -61,36 +67,31 @@ extension Tensor {
6167
}
6268
}
6369

64-
/// Reads a file into an array of bytes.
65-
func readFile(_ path: String, possibleDirectories: [String]) -> [UInt8] {
66-
for folder in possibleDirectories {
67-
let parent = URL(fileURLWithPath: folder)
68-
let filePath = parent.appendingPathComponent(path)
69-
guard FileManager.default.fileExists(atPath: filePath.path) else {
70-
continue
71-
}
72-
let data = try! Data(contentsOf: filePath, options: [])
73-
return [UInt8](data)
70+
fileprivate func fetchDataset(
71+
localStorageDirectory: URL,
72+
imagesFilename: String,
73+
labelsFilename: String,
74+
flattening: Bool,
75+
normalizing: Bool
76+
) -> (images: Tensor<Float>, labels: Tensor<Int32>) {
77+
guard let remoteRoot: URL = URL(string: "http://yann.lecun.com/exdb/mnist") else {
78+
fatalError("Failed to create MNST root url: http://yann.lecun.com/exdb/mnist")
7479
}
75-
print("File not found: \(path)")
76-
exit(-1)
77-
}
7880

79-
/// Reads MNIST images and labels from specified file paths.
80-
func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normalizing: Bool) -> (
81-
images: Tensor<Float>,
82-
labels: Tensor<Int32>
83-
) {
84-
print("Reading data from files: \(imagesFile), \(labelsFile).")
85-
let images = readFile(imagesFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(16)
86-
.map(Float.init)
87-
let labels = readFile(labelsFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(8)
88-
.map(Int32.init)
89-
let rowCount = labels.count
90-
let imageHeight = 28
91-
let imageWidth = 28
81+
let imagesData = DatasetUtilities.fetchResource(
82+
filename: imagesFilename,
83+
remoteRoot: remoteRoot,
84+
localStorageDirectory: localStorageDirectory)
85+
let labelsData = DatasetUtilities.fetchResource(
86+
filename: labelsFilename,
87+
remoteRoot: remoteRoot,
88+
localStorageDirectory: localStorageDirectory)
89+
90+
let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
91+
let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)
9292

93-
print("Constructing data tensors.")
93+
let rowCount = labels.count
94+
let (imageWidth, imageHeight) = (28, 28)
9495

9596
if flattening {
9697
var flattenedImages = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
@@ -101,8 +102,9 @@ func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normali
101102
return (images: flattenedImages, labels: Tensor(labels))
102103
} else {
103104
return (
104-
images: Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
105-
.transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
105+
images:
106+
Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
107+
.transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
106108
labels: Tensor(labels)
107109
)
108110
}

Datasets/MNIST/t10k-images-idx3-ubyte

-7.48 MB
Binary file not shown.

Datasets/MNIST/t10k-labels-idx1-ubyte

-9.77 KB
Binary file not shown.
-44.9 MB
Binary file not shown.
-58.6 KB
Binary file not shown.

0 commit comments

Comments
 (0)