Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Download MNIST dataset from remote URL #215

Merged
merged 10 commits into from
Oct 15, 2019
Merged
112 changes: 112 additions & 0 deletions Datasets/DatasetUtilities.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

public struct DatasetUtilities {
public static let curentWorkingDirectoryURL = URL(
fileURLWithPath: FileManager.default.currentDirectoryPath)

public static func fetchResource(
filename: String,
remoteRoot: URL,
localStorageDirectory: URL = curentWorkingDirectoryURL
) -> Data {
print("Loading resource: \(filename)")

let resource = ResourceDefinition(
filename: filename,
remoteRoot: remoteRoot,
localStorageDirectory: localStorageDirectory)

let localURL = resource.localURL

if !FileManager.default.fileExists(atPath: localURL.path) {
print(
"File does not exist locally at expected path: \(localURL.path) and must be fetched"
)
fetchFromRemoteAndSave(resource)
}

do {
print("Loading local data at: \(localURL.path)")
let data = try Data(contentsOf: localURL)
print("Succesfully loaded resource: \(filename)")
return data
} catch {
fatalError("Failed to contents of resource: \(localURL)")
}
}

struct ResourceDefinition {
let filename: String
let remoteRoot: URL
let localStorageDirectory: URL

var localURL: URL {
localStorageDirectory.appendingPathComponent(filename)
}

var remoteURL: URL {
remoteRoot.appendingPathComponent(filename).appendingPathExtension("gz")
}

var archiveURL: URL {
localURL.appendingPathExtension("gz")
}
}

static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
let remoteLocation = resource.remoteURL
let archiveLocation = resource.archiveURL

do {
print("Fetching URL: \(remoteLocation)...")
let archiveData = try Data(contentsOf: remoteLocation)
print("Writing fetched archive to: \(archiveLocation.path)")
try archiveData.write(to: archiveLocation)
} catch {
fatalError("Failed to fetch and save resource with error: \(error)")
}
print("Archive saved to: \(archiveLocation.path)")

extractArchive(for: resource)
}

static func extractArchive(for resource: ResourceDefinition) {
print("Extracting archive...")

let archivePath = resource.archiveURL.path

#if os(macOS)
let gunzipLocation = "/usr/bin/gunzip"
#else
let gunzipLocation = "/bin/gunzip"
#endif

let task = Process()
task.executableURL = URL(fileURLWithPath: gunzipLocation)
task.arguments = [archivePath]
do {
try task.run()
task.waitUntilExit()
} catch {
fatalError("Failed to extract \(archivePath) with error: \(error)")
}
}
}
74 changes: 38 additions & 36 deletions Datasets/MNIST/MNIST.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,27 @@ public struct MNIST {

public let batchSize: Int

public init(batchSize: Int, flattening: Bool = false, normalizing: Bool = false) {
public init(
batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
localStorageDirectory: URL = DatasetUtilities.curentWorkingDirectoryURL
) {
self.batchSize = batchSize

let (trainingImages, trainingLabels) = readMNIST(
imagesFile: "train-images-idx3-ubyte",
labelsFile: "train-labels-idx1-ubyte",
let (trainingImages, trainingLabels) = fetchDataset(
localStorageDirectory: localStorageDirectory,
imagesFilename: "train-images-idx3-ubyte",
labelsFilename: "train-labels-idx1-ubyte",
flattening: flattening,
normalizing: normalizing)

self.trainingImages = trainingImages
self.trainingLabels = trainingLabels
self.trainingSize = Int(trainingLabels.shape[0])

let (testImages, testLabels) = readMNIST(
imagesFile: "t10k-images-idx3-ubyte",
labelsFile: "t10k-labels-idx1-ubyte",
let (testImages, testLabels) = fetchDataset(
localStorageDirectory: localStorageDirectory,
imagesFilename: "t10k-images-idx3-ubyte",
labelsFilename: "t10k-labels-idx1-ubyte",
flattening: flattening,
normalizing: normalizing)
self.testImages = testImages
Expand All @@ -61,36 +67,31 @@ extension Tensor {
}
}

/// Reads a file into an array of bytes.
func readFile(_ path: String, possibleDirectories: [String]) -> [UInt8] {
for folder in possibleDirectories {
let parent = URL(fileURLWithPath: folder)
let filePath = parent.appendingPathComponent(path)
guard FileManager.default.fileExists(atPath: filePath.path) else {
continue
}
let data = try! Data(contentsOf: filePath, options: [])
return [UInt8](data)
fileprivate func fetchDataset(
localStorageDirectory: URL,
imagesFilename: String,
labelsFilename: String,
flattening: Bool,
normalizing: Bool
) -> (images: Tensor<Float>, labels: Tensor<Int32>) {
guard let remoteRoot: URL = URL(string: "http://yann.lecun.com/exdb/mnist") else {
fatalError("Failed to create MNST root url: http://yann.lecun.com/exdb/mnist")
}
print("File not found: \(path)")
exit(-1)
}

/// Reads MNIST images and labels from specified file paths.
func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normalizing: Bool) -> (
images: Tensor<Float>,
labels: Tensor<Int32>
) {
print("Reading data from files: \(imagesFile), \(labelsFile).")
let images = readFile(imagesFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(16)
.map(Float.init)
let labels = readFile(labelsFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(8)
.map(Int32.init)
let rowCount = labels.count
let imageHeight = 28
let imageWidth = 28
let imagesData = DatasetUtilities.fetchResource(
filename: imagesFilename,
remoteRoot: remoteRoot,
localStorageDirectory: localStorageDirectory)
let labelsData = DatasetUtilities.fetchResource(
filename: labelsFilename,
remoteRoot: remoteRoot,
localStorageDirectory: localStorageDirectory)

let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)

print("Constructing data tensors.")
let rowCount = labels.count
let (imageWidth, imageHeight) = (28, 28)

if flattening {
var flattenedImages = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
Expand All @@ -101,8 +102,9 @@ func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normali
return (images: flattenedImages, labels: Tensor(labels))
} else {
return (
images: Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
.transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
images:
Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
.transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
labels: Tensor(labels)
)
}
Expand Down
Binary file removed Datasets/MNIST/t10k-images-idx3-ubyte
Binary file not shown.
Binary file removed Datasets/MNIST/t10k-labels-idx1-ubyte
Binary file not shown.
Binary file removed Datasets/MNIST/train-images-idx3-ubyte
Binary file not shown.
Binary file removed Datasets/MNIST/train-labels-idx1-ubyte
Binary file not shown.