Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Continuing repository reorganization: extracting common CIFAR-10 and ResNet code #185

Merged
merged 3 commits into from
Aug 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 0 additions & 51 deletions CIFAR/Helpers.swift

This file was deleted.

18 changes: 0 additions & 18 deletions CIFAR/README.md

This file was deleted.

119 changes: 0 additions & 119 deletions CIFAR/ResNet.swift

This file was deleted.

53 changes: 20 additions & 33 deletions CIFAR/Data.swift → Datasets/CIFAR10/CIFAR10.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Original source:
// "The CIFAR-10 dataset"
// Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
// https://www.cs.toronto.edu/~kriz/cifar.html

import Foundation
import TensorFlow

#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

/// The CIFAR-10 dataset, exposed as separate training and test `Dataset`s.
///
/// Both datasets are populated eagerly when a value is created, by calling
/// `loadCIFARTrainingFiles()` and `loadCIFARTestFile()`.
public struct CIFAR10 {
    /// The examples returned by `loadCIFARTrainingFiles()`, wrapped in a `Dataset`.
    public let trainingDataset: Dataset<CIFARExample>
    /// The examples returned by `loadCIFARTestFile()`, wrapped in a `Dataset`.
    public let testDataset: Dataset<CIFARExample>

    /// Loads the training files first, then the test file, and wraps each
    /// result in a `Dataset`.
    public init() {
        let trainingExamples = loadCIFARTrainingFiles()
        trainingDataset = Dataset(elements: trainingExamples)

        let testExamples = loadCIFARTestFile()
        testDataset = Dataset(elements: testExamples)
    }
}

func downloadCIFAR10IfNotPresent(to directory: String = ".") {
let downloadPath = "\(directory)/cifar-10-batches-bin"
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
Expand Down Expand Up @@ -69,27 +84,7 @@ func downloadCIFAR10IfNotPresent(to directory: String = ".") {
print("Unarchiving completed")
}

/// A pairing of integer labels with floating-point data, usable as the
/// element type of a `Dataset` through its `TensorGroup` conformance.
struct Example: TensorGroup {
    /// Labels stored as 32-bit integers.
    var label: Tensor<Int32>
    /// Data stored as floats.
    var data: Tensor<Float>

    /// Creates an example from an existing label tensor and data tensor.
    init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    /// Rebuilds an example from raw tensor handles.
    ///
    /// - Parameter _handles: Exactly two handles; the first is read as the
    ///   label and the second as the data.
    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let first = _handles.startIndex
        let second = _handles.index(after: first)
        let labelHandle = TensorHandle<Int32>(handle: _handles[first])
        let dataHandle = TensorHandle<Float>(handle: _handles[second])
        label = Tensor<Int32>(handle: labelHandle)
        data = Tensor<Float>(handle: dataHandle)
    }
}

func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
func loadCIFARFile(named name: String, in directory: String = ".") -> CIFARExample {
downloadCIFAR10IfNotPresent(to: directory)
let path = "\(directory)/cifar-10-batches-bin/\(name)"

Expand Down Expand Up @@ -124,25 +119,17 @@ func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
let std = Tensor<Float>([0.229, 0.224, 0.225])
let imagesNormalized = ((imageTensor / 255.0) - mean) / std

return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
return CIFARExample(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
func loadCIFARTrainingFiles() -> CIFARExample {
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
return Example(
return CIFARExample(
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
)
}

func loadCIFARTestFile() -> Example {
func loadCIFARTestFile() -> CIFARExample {
return loadCIFARFile(named: "test_batch.bin")
}

/// Loads the CIFAR-10 training files and test file and wraps each in a `Dataset`.
///
/// - Returns: A labeled tuple holding the training dataset and the test dataset.
func loadCIFAR10() -> (
    training: Dataset<Example>, test: Dataset<Example>
) {
    // Training data is loaded before test data, matching the original call order.
    let training: Dataset<Example> = .init(elements: loadCIFARTrainingFiles())
    let test: Dataset<Example> = .init(elements: loadCIFARTestFile())
    return (training, test)
}
35 changes: 35 additions & 0 deletions Datasets/CIFAR10/CIFARExample.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// A pairing of integer labels with floating-point data, usable as the
/// element type of a `Dataset` through its `TensorGroup` conformance.
public struct CIFARExample: TensorGroup {
    /// Labels stored as 32-bit integers.
    public var label: Tensor<Int32>
    /// Data stored as floats.
    public var data: Tensor<Float>

    /// Creates an example from an existing label tensor and data tensor.
    public init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    /// Rebuilds an example from raw tensor handles.
    ///
    /// - Parameter _handles: Exactly two handles; the first is read as the
    ///   label and the second as the data.
    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let labelPosition = _handles.startIndex
        let dataPosition = _handles.index(after: labelPosition)
        let labelHandle = TensorHandle<Int32>(handle: _handles[labelPosition])
        let dataHandle = TensorHandle<Float>(handle: _handles[dataPosition])
        label = Tensor<Int32>(handle: labelHandle)
        data = Tensor<Float>(handle: dataHandle)
    }
}
5 changes: 5 additions & 0 deletions Datasets/MNIST/MNIST.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Original source:
// "The MNIST database of handwritten digits"
// Yann LeCun, Corinna Cortes, and Christopher J.C. Burges
// http://yann.lecun.com/exdb/mnist/

import Foundation
import TensorFlow

Expand Down
File renamed without changes.
18 changes: 18 additions & 0 deletions Examples/Custom-CIFAR10/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# CIFAR-10 with custom models

This example demonstrates how to train custom-defined models (based on examples from [PyTorch](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) and [Keras](https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py)) against the [CIFAR-10 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).

Two custom models are defined, and one is applied to an instance of the CIFAR-10 dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.

## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

To train the model, run:

```sh
cd swift-models
swift run -c release Custom-CIFAR10
```
17 changes: 10 additions & 7 deletions CIFAR/main.swift → Examples/Custom-CIFAR10/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import TensorFlow

let batchSize = 100

let cifarDataset = loadCIFAR10()
let testBatches = cifarDataset.test.batched(batchSize)
let dataset = CIFAR10()
let testBatches = dataset.testDataset.batched(batchSize)

var model = KerasModel()
let optimizer = RMSProp(for: model, learningRate: 0.0001, decay: 1e-6)
Expand All @@ -28,7 +29,7 @@ Context.local.learningPhase = .training
for epoch in 1...100 {
var trainingLossSum: Float = 0
var trainingBatchCount = 0
let trainingShuffled = cifarDataset.training.shuffled(
let trainingShuffled = dataset.trainingDataset.shuffled(
sampleCount: 50000, randomSeed: Int64(epoch))
for batch in trainingShuffled.batched(batchSize) {
let (labels, images) = (batch.label, batch.data)
Expand All @@ -52,15 +53,17 @@ for epoch in 1...100 {
testBatchCount += 1

let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
correctGuessCount = correctGuessCount +
Int(Tensor<Int32>(correctPredictions).sum().scalarized())
correctGuessCount = correctGuessCount + Int(
Tensor<Int32>(correctPredictions).sum().scalarized())
totalGuessCount = totalGuessCount + batchSize
}

let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
print("""
print(
"""
[Epoch \(epoch)] \
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
Loss: \(testLossSum / Float(testBatchCount))
""")
"""
)
}
Loading