This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Convert MNIST to Epochs #497

Merged · 4 commits · May 7, 2020
Autoencoder/Autoencoder1D/main.swift (2 changes: 1 addition & 1 deletion)

```diff
@@ -24,7 +24,7 @@ let imageHeight = 28
 let imageWidth = 28
 
 let outputFolder = "./output/"
-let dataset = FashionMNIST(batchSize: batchSize, flattening: true)
+let dataset = OldFashionMNIST(batchSize: batchSize, flattening: true)
 // An autoencoder.
 var autoencoder = Sequential {
   // The encoder.
```
Autoencoder/Autoencoder2D/main.swift (2 changes: 1 addition & 1 deletion)

```diff
@@ -26,7 +26,7 @@ let imageHeight = 28
 let imageWidth = 28
 
 let outputFolder = "./output/"
-let dataset = KuzushijiMNIST(batchSize: batchSize, flattening: true)
+let dataset = OldKuzushijiMNIST(batchSize: batchSize, flattening: true)
 
 // An autoencoder.
 struct Autoencoder2D: Layer {
```
Autoencoder/VAE1D/main.swift (4 changes: 2 additions & 2 deletions)

```diff
@@ -26,7 +26,7 @@ let imageHeight = 28
 let imageWidth = 28
 
 let outputFolder = "./output/"
-let dataset = MNIST(batchSize: 128, flattening: true)
+let dataset = OldMNIST(batchSize: 128, flattening: true)
 
 let inputDim = 784 // 28*28 for any MNIST
 let hiddenDim = 400
@@ -84,7 +84,7 @@ func vaeLossFunction(
 }
 
 // TODO: Find a cleaner way of extracting individual images that doesn't require a second dataset.
-let singleImageDataset = MNIST(batchSize: 1, flattening: true)
+let singleImageDataset = OldMNIST(batchSize: 1, flattening: true)
 let individualTestImages = singleImageDataset.test
 var testImageIterator = individualTestImages.sequenced()
```
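The renames above are mechanical: each example keeps its existing training loop and simply switches to the `Old*`-prefixed dataset, which retains the previous `Batcher`-based API while the unprefixed types move to the epochs API. A minimal sketch of that retained surface, assuming `training` exposes the same `sequenced()` iteration that the VAE1D call sites use on `test`, and that each batch is a collated `TensorPair` with `first`/`second` members (both assumptions, inferred from the diffs, not confirmed by them):

```swift
import Datasets

// Sketch only: assumes OldMNIST keeps the pre-PR Batcher API, where
// `training` and `test` are Batcher<[TensorPair<Float, Int32>]> values.
let dataset = OldMNIST(batchSize: 128, flattening: true)

for batch in dataset.training.sequenced() {
  // Assumed TensorPair layout: `.first` holds the flattened images,
  // `.second` holds the Int32 labels.
  let images = batch.first
  let labels = batch.second
  // ... run one training step on (images, labels) ...
  _ = (images, labels)
}
```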
Benchmarks/Models/LeNetMnist.swift (4 changes: 2 additions & 2 deletions)

```diff
@@ -37,11 +37,11 @@ enum LeNetMNIST: BenchmarkModel {
   }
 
   static func makeInferenceBenchmark(settings: BenchmarkSettings) -> Benchmark {
-    return ImageClassificationInference<LeNet, MNIST>(settings: settings)
+    return ImageClassificationInference<LeNet, OldMNIST>(settings: settings)
   }
 
   static func makeTrainingBenchmark(settings: BenchmarkSettings) -> Benchmark {
-    return ImageClassificationTraining<LeNet, MNIST>(settings: settings)
+    return ImageClassificationTraining<LeNet, OldMNIST>(settings: settings)
   }
 }
```
DCGAN/main.swift (2 changes: 1 addition & 1 deletion)

```diff
@@ -18,7 +18,7 @@ import ModelSupport
 import TensorFlow
 
 let batchSize = 512
-let mnist = MNIST(batchSize: batchSize, flattening: false, normalizing: true)
+let mnist = OldMNIST(batchSize: batchSize, flattening: false, normalizing: true)
 
 let outputFolder = "./output/"
```
Datasets/CIFAR10/CIFAR10.swift (21 changes: 12 additions & 9 deletions)

```diff
@@ -48,12 +48,13 @@ public struct CIFAR10<Entropy: RandomNumberGenerator> {
     self.init(
       batchSize: batchSize,
       entropy: entropy,
+      device: Device.default,
       remoteBinaryArchiveLocation: URL(
         string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!,
       normalizing: true)
   }
 
-  /// Creates an instance with `batchSize` using `remoteBinaryArchiveLocation`.
+  /// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`.
   ///
   /// - Parameters:
   ///   - entropy: a source of randomness used to shuffle sample ordering. It
@@ -65,6 +66,7 @@ public struct CIFAR10<Entropy: RandomNumberGenerator> {
   public init(
     batchSize: Int,
     entropy: Entropy,
+    device: Device,
     remoteBinaryArchiveLocation: URL,
     localStorageDirectory: URL = DatasetUtilities.defaultDirectory
       .appendingPathComponent("CIFAR10", isDirectory: true),
@@ -76,13 +78,13 @@ public struct CIFAR10<Entropy: RandomNumberGenerator> {
     let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
     training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
       .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
-        return batches.lazy.map{ makeBatch(samples: $0, normalizing: normalizing) }
+        return batches.lazy.map{ makeBatch(samples: $0, normalizing: normalizing, device: device) }
       }
 
     // Validation data
     let validationSamples = loadCIFARTestFile(in: localStorageDirectory)
     validation = validationSamples.inBatches(of: batchSize).lazy.map {
-      makeBatch(samples: $0, normalizing: normalizing)
+      makeBatch(samples: $0, normalizing: normalizing, device: device)
     }
   }
 }
@@ -145,19 +147,20 @@ func loadCIFARTestFile(in localStorageDirectory: URL) -> [(data: [UInt8], label:
   return loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory)
 }
 
-func makeBatch<BatchSamples: Collection>(samples: BatchSamples, normalizing: Bool) -> LabeledImage
-where BatchSamples.Element == (data: [UInt8], label: Int32) {
+fileprivate func makeBatch<BatchSamples: Collection>(
+  samples: BatchSamples, normalizing: Bool, device: Device
+) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
   let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
-  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes)
+  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)
 
   var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
   imageTensor /= 255.0
   if normalizing {
-    let mean = Tensor<Float>([0.4913996898, 0.4821584196, 0.4465309242])
-    let std = Tensor<Float>([0.2470322324, 0.2434851280, 0.2615878417])
+    let mean = Tensor<Float>([0.4913996898, 0.4821584196, 0.4465309242], on: device)
+    let std = Tensor<Float>([0.2470322324, 0.2434851280, 0.2615878417], on: device)
     imageTensor = (imageTensor - mean) / std
   }
 
   let labels = Tensor<Int32>(samples.map(\.label))
   return LabeledImage(data: imageTensor, label: labels)
 }
```
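With `device:` threaded through `makeBatch`, CIFAR10 batches are materialized directly on the target device instead of defaulting to the host. A usage sketch under the signatures visible in this diff — the `init(batchSize:entropy:)` convenience and the `LabeledImage(data:label:)` fields come from the hunks above; the loop bounds and print are illustrative:

```swift
import TensorFlow
import Datasets

// Convenience initializer from the first hunk: places batches on
// Device.default and uses the hosted CIFAR-10 binary archive.
let dataset = CIFAR10(batchSize: 128, entropy: SystemRandomNumberGenerator())

// `training` yields epochs indefinitely under this design, so bound the
// loop explicitly with prefix(_:).
for (epochIndex, epoch) in dataset.training.prefix(2).enumerated() {
  for batch in epoch {
    // `batch.data` and `batch.label` were built with `on: device`, so no
    // extra host-to-device copy should be needed here.
    print("epoch \(epochIndex): batch of \(batch.data.shape[0]) images")
  }
}
```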
Datasets/CMakeLists.txt (4 changes: 4 additions & 0 deletions)

```diff
@@ -15,6 +15,10 @@ add_library(Datasets
   MNIST/MNIST.swift
   MNIST/FashionMNIST.swift
   MNIST/KuzushijiMNIST.swift
+  MNIST/OldMNISTDatasetHandler.swift
+  MNIST/OldMNIST.swift
+  MNIST/OldFashionMNIST.swift
+  MNIST/OldKuzushijiMNIST.swift
   ObjectDetectionDataset.swift
   BostonHousing/BostonHousing.swift
   TextUnsupervised/TextUnsupervised.swift
```
Datasets/MNIST/FashionMNIST.swift (106 changes: 72 additions & 34 deletions)

```diff
@@ -21,41 +21,79 @@ import Foundation
 import TensorFlow
 import Batcher
 
-public struct FashionMNIST: ImageClassificationDataset {
-  public typealias SourceDataSet = [TensorPair<Float, Int32>]
-  public let training: Batcher<SourceDataSet>
-  public let test: Batcher<SourceDataSet>
-
-  public init(batchSize: Int) {
-    self.init(batchSize: batchSize, flattening: false, normalizing: false)
-  }
-
-  public init(
-    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
-    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
-      .appendingPathComponent("FashionMNIST", isDirectory: true)
-  ) {
-    training = Batcher<SourceDataSet>(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
-        imagesFilename: "train-images-idx3-ubyte",
-        labelsFilename: "train-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
-      shuffle: true)
-
-    test = Batcher<SourceDataSet>(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
-        imagesFilename: "t10k-images-idx3-ubyte",
-        labelsFilename: "t10k-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
-  }
-}
+public struct FashionMNIST<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collections of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
+    LazyMapSequence<Batches, LabeledImage>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[(data: [UInt8], label: Int32)]>, LabeledImage>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
+
+  /// Creates an instance with `batchSize`.
+  ///
+  /// - Parameter entropy: a source of randomness used to shuffle sample
+  ///   ordering. It will be stored in `self`, so if it is only pseudorandom
+  ///   and has value semantics, the sequence of epochs is deterministic and not
+  ///   dependent on other operations.
+  public init(batchSize: Int, entropy: Entropy) {
+    self.init(batchSize: batchSize, device: Device.default, entropy: entropy,
+      flattening: false, normalizing: false)
+  }
+
+  /// Creates an instance with `batchSize` on `device`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - flattening: flattens the data to be a 2d-tensor iff `true`. The default
+  ///     value is `false`.
+  ///   - normalizing: normalizes the batches to have values from -1.0 to 1.0 iff
+  ///     `true`. The default value is `false`.
+  ///   - localStorageDirectory: the directory in which the dataset is stored.
+  public init(
+    batchSize: Int, device: Device, entropy: Entropy, flattening: Bool = false,
+    normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("FashionMNIST", isDirectory: true)
+  ) {
+    training = TrainingEpochs(
+      samples: fetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte"),
+      batchSize: batchSize, entropy: entropy
+    ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
+      return batches.lazy.map{ makeMNISTBatch(
+        samples: $0, flattening: flattening, normalizing: normalizing, device: device
+      )}
+    }
+
+    validation = fetchMNISTDataset(
+      localStorageDirectory: localStorageDirectory,
+      remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+      imagesFilename: "t10k-images-idx3-ubyte",
+      labelsFilename: "t10k-labels-idx1-ubyte"
+    ).inBatches(of: batchSize).lazy.map {
+      makeMNISTBatch(samples: $0, flattening: flattening, normalizing: normalizing,
+        device: device)
+    }
+  }
+}
+
+extension FashionMNIST: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`.
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator())
+  }
+}
```
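Under the new API, `training` is a lazy sequence of shuffled epochs rather than a `Batcher`, so a training loop bounds it explicitly. A minimal loop sketch built only from names visible in this diff (`training`, `validation`, the `ImageClassificationData` convenience initializer, and `LabeledImage`'s `data`/`label` fields from the CIFAR10 hunk); the training step itself is elided:

```swift
import Datasets

// Uses the init(batchSize:) convenience from the extension above
// (Entropy == SystemRandomNumberGenerator).
let dataset = FashionMNIST(batchSize: 64)

// TrainingEpochs yields epochs indefinitely in this design; take what you
// need with prefix(_:).
for (epochIndex, epoch) in dataset.training.prefix(5).enumerated() {
  for batch in epoch {
    // Each batch is a LabeledImage: images in `data`, labels in `label`.
    // ... run one training step ...
    _ = (batch.data, batch.label)
  }
  print("finished epoch \(epochIndex)")

  // `validation` is a plain collection of batches and can be re-iterated
  // after every epoch.
  for batch in dataset.validation {
    _ = batch.data
  }
}
```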
Datasets/MNIST/KuzushijiMNIST.swift (106 changes: 72 additions & 34 deletions)

```diff
@@ -20,41 +20,79 @@ import Foundation
 import TensorFlow
 import Batcher
 
-public struct KuzushijiMNIST: ImageClassificationDataset {
-  public typealias SourceDataSet = [TensorPair<Float, Int32>]
-  public let training: Batcher<SourceDataSet>
-  public let test: Batcher<SourceDataSet>
-
-  public init(batchSize: Int) {
-    self.init(batchSize: batchSize, flattening: false, normalizing: false)
-  }
-
-  public init(
-    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
-    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
-      .appendingPathComponent("KuzushijiMNIST", isDirectory: true)
-  ) {
-    training = Batcher<SourceDataSet>(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
-        imagesFilename: "train-images-idx3-ubyte",
-        labelsFilename: "train-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
-      shuffle: true)
-
-    test = Batcher<SourceDataSet>(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
-        imagesFilename: "t10k-images-idx3-ubyte",
-        labelsFilename: "t10k-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
-  }
-}
+public struct KuzushijiMNIST<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collections of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
+    LazyMapSequence<Batches, LabeledImage>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[(data: [UInt8], label: Int32)]>, LabeledImage>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
+
+  /// Creates an instance with `batchSize`.
+  ///
+  /// - Parameter entropy: a source of randomness used to shuffle sample
+  ///   ordering. It will be stored in `self`, so if it is only pseudorandom
+  ///   and has value semantics, the sequence of epochs is deterministic and not
+  ///   dependent on other operations.
+  public init(batchSize: Int, entropy: Entropy) {
+    self.init(batchSize: batchSize, device: Device.default, entropy: entropy,
+      flattening: false, normalizing: false)
+  }
+
+  /// Creates an instance with `batchSize` on `device`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - flattening: flattens the data to be a 2d-tensor iff `true`. The default
+  ///     value is `false`.
+  ///   - normalizing: normalizes the batches to have values from -1.0 to 1.0 iff
+  ///     `true`. The default value is `false`.
+  ///   - localStorageDirectory: the directory in which the dataset is stored.
+  public init(
+    batchSize: Int, device: Device, entropy: Entropy, flattening: Bool = false,
+    normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("KuzushijiMNIST", isDirectory: true)
+  ) {
+    training = TrainingEpochs(
+      samples: fetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte"),
+      batchSize: batchSize, entropy: entropy
+    ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
+      return batches.lazy.map{ makeMNISTBatch(
+        samples: $0, flattening: flattening, normalizing: normalizing, device: device
+      )}
+    }
+
+    validation = fetchMNISTDataset(
+      localStorageDirectory: localStorageDirectory,
+      remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
+      imagesFilename: "t10k-images-idx3-ubyte",
+      labelsFilename: "t10k-labels-idx1-ubyte"
+    ).inBatches(of: batchSize).lazy.map {
+      makeMNISTBatch(samples: $0, flattening: flattening, normalizing: normalizing,
+        device: device)
+    }
+  }
+}
+
+extension KuzushijiMNIST: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`.
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator())
+  }
+}
```
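Because the generic `Entropy` value is stored in the dataset, pairing it with a seeded, value-semantic generator makes the epoch shuffle order reproducible, as the doc comments above note. A sketch of that, assuming swift-apis' `ARC4RandomNumberGenerator` and its integer-seed initializer are available (any seedable, value-semantic `RandomNumberGenerator` would do); parameter order follows the full initializer in this diff:

```swift
import TensorFlow
import Datasets

// Reproducibility sketch: a fixed seed plus a value-semantic generator
// stored in the dataset pins down the shuffle order across runs.
let dataset = KuzushijiMNIST(
  batchSize: 64,
  device: Device.default,
  entropy: ARC4RandomNumberGenerator(seed: 42),
  flattening: true)

// Reconstructing the dataset with the same seed visits batches in the
// same order.
for epoch in dataset.training.prefix(1) {
  for batch in epoch {
    _ = (batch.data, batch.label)
  }
}
```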