Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Replaced Python with Swift in CIFAR10 dataset loading #178

Merged
merged 6 commits into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
.DS_Store
.swiftpm
cifar-10-batches-py/
cifar-10-batches-bin/
117 changes: 80 additions & 37 deletions CIFAR/Data.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,27 +12,61 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Python
import Foundation
import TensorFlow

/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

func downloadCIFAR10IfNotPresent(to directory: String = ".") {
let subprocess = Python.import("subprocess")
let path = Python.import("os.path")
let filepath = "\(directory)/cifar-10-batches-py"
let isdir = Bool(path.isdir(filepath))!
if !isdir {
print("Downloading CIFAR data...")
let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)"
subprocess.call(command, shell: true)
let downloadPath = "\(directory)/cifar-10-batches-bin"
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

guard !directoryExists else { return }

print("Downloading CIFAR dataset...")
let archivePath = "\(directory)/cifar-10-binary.tar.gz"
let archiveExists = FileManager.default.fileExists(atPath: archivePath)
if !archiveExists {
print("Archive missing, downloading...")
do {
let downloadedFile = try Data(
contentsOf: URL(
string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
} catch {
print("Could not download CIFAR dataset, error: \(error)")
exit(-1)
}
}

print("Archive downloaded, processing...")

#if os(macOS)
let tarLocation = "/usr/bin/tar"
#else
let tarLocation = "/bin/tar"
#endif

let task = Process()
task.executableURL = URL(fileURLWithPath: tarLocation)
task.arguments = ["xzf", archivePath]
do {
try task.run()
task.waitUntilExit()
} catch {
print("CIFAR extraction failed with error: \(error)")
}
}

extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
public var _tfeTensorHandle: _AnyTensorHandle {
TFETensorHandle(_owning: handle._cTensorHandle)
do {
try FileManager.default.removeItem(atPath: archivePath)
} catch {
print("Could not remove archive, error: \(error)")
exit(-1)
}

print("Unarchiving completed")
}

struct Example: TensorGroup {
Expand All @@ -53,52 +87,61 @@ struct Example: TensorGroup {
label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
}

public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
}

// Each CIFAR data file is provided as a Python pickle of NumPy arrays
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
downloadCIFAR10IfNotPresent(to: directory)
let np = Python.import("numpy")
let pickle = Python.import("pickle")
let path = "\(directory)/cifar-10-batches-py/\(name)"
let f = Python.open(path, "rb")
let res = pickle.load(f, encoding: "bytes")
let path = "\(directory)/cifar-10-batches-bin/\(name)"

let bytes = res[Python.bytes("data", encoding: "utf8")]
let labels = res[Python.bytes("labels", encoding: "utf8")]
let imageCount = 10000
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
print("Could not read dataset file: \(name)")
exit(-1)
}
guard fileContents.count == 30_730_000 else {
print(
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
exit(-1)
}

var bytes: [UInt8] = []
var labels: [Int64] = []

let imageByteSize = 3073
for imageIndex in 0..<imageCount {
let baseAddress = imageIndex * imageByteSize
labels.append(Int64(fileContents[baseAddress]))
bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
}

let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
let images = Tensor<UInt8>(numpy: bytes)!
let imageCount = images.shape[0]
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)

// reshape and transpose from the provided N(CHW) to TF default NHWC
let imageTensor = Tensor<Float>(images
.reshaped(to: [imageCount, 3, 32, 32])
.transposed(withPermutations: [0, 2, 3, 1]))
// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))

let mean = Tensor<Float>([0.485, 0.456, 0.406])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let imagesNormalized = ((imageTensor / 255.0) - mean) / std

return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
return Example(
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
)
}

func loadCIFARTestFile() -> Example {
return loadCIFARFile(named: "test_batch")
return loadCIFARFile(named: "test_batch.bin")
}

func loadCIFAR10() -> (
training: Dataset<Example>, test: Dataset<Example>) {
training: Dataset<Example>, test: Dataset<Example>
) {
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
return (training: trainingDataset, test: testDataset)
Expand Down
7 changes: 1 addition & 6 deletions CIFAR/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@ classification on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) da
## Setup

You'll need [the latest version][INSTALL] of Swift for TensorFlow
installed and added to your path. Additionally, the data loader requires Python
3.x (rather than Python 2.7), `wget`, and `numpy`.

> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
> TensorFlow find the `libpython3.<minor-version>.dylib` file, e.g., in
> `homebrew`.
installed and added to your path.

To train the default model, run:

Expand Down
2 changes: 0 additions & 2 deletions CIFAR/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
// limitations under the License.

import TensorFlow
import Python
PythonLibrary.useVersion(3)

let batchSize = 100

Expand Down
5 changes: 4 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// swift-tools-version:4.2
// swift-tools-version:5.0
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "TensorFlowModels",
platforms: [
.macOS(.v10_13)
],
products: [
.executable(name: "MNIST", targets: ["MNIST"]),
.executable(name: "CIFAR", targets: ["CIFAR"]),
Expand Down
119 changes: 80 additions & 39 deletions ResNet/Data.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,61 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Python
import Foundation
import TensorFlow

/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

func downloadCIFAR10IfNotPresent(to directory: String = ".") {
let subprocess = Python.import("subprocess")
let path = Python.import("os.path")
let filepath = "\(directory)/cifar-10-batches-py"
let isdir = Bool(path.isdir(filepath))!
if !isdir {
print("Downloading CIFAR data...")
let command = """
wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - \
-C \(directory)
"""
subprocess.call(command, shell: true)
let downloadPath = "\(directory)/cifar-10-batches-bin"
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

guard !directoryExists else { return }

print("Downloading CIFAR dataset...")
let archivePath = "\(directory)/cifar-10-binary.tar.gz"
let archiveExists = FileManager.default.fileExists(atPath: archivePath)
if !archiveExists {
print("Archive missing, downloading...")
do {
let downloadedFile = try Data(
contentsOf: URL(
string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
} catch {
print("Could not download CIFAR dataset, error: \(error)")
exit(-1)
}
}

print("Archive downloaded, processing...")

#if os(macOS)
let tarLocation = "/usr/bin/tar"
#else
let tarLocation = "/bin/tar"
#endif

let task = Process()
task.executableURL = URL(fileURLWithPath: tarLocation)
task.arguments = ["xzf", archivePath]
do {
try task.run()
task.waitUntilExit()
} catch {
print("CIFAR extraction failed with error: \(error)")
}
}

extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
public var _tfeTensorHandle: _AnyTensorHandle {
TFETensorHandle(_owning: handle._cTensorHandle)
do {
try FileManager.default.removeItem(atPath: archivePath)
} catch {
print("Could not remove archive, error: \(error)")
exit(-1)
}

print("Unarchiving completed")
}

struct Example: TensorGroup {
Expand All @@ -56,51 +87,61 @@ struct Example: TensorGroup {
label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
}

public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
}

// Each CIFAR data file is provided as a Python pickle of NumPy arrays
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
downloadCIFAR10IfNotPresent(to: directory)
let np = Python.import("numpy")
let pickle = Python.import("pickle")
let path = "\(directory)/cifar-10-batches-py/\(name)"
let f = Python.open(path, "rb")
let res = pickle.load(f, encoding: "bytes")
let path = "\(directory)/cifar-10-batches-bin/\(name)"

let bytes = res[Python.bytes("data", encoding: "utf8")]
let labels = res[Python.bytes("labels", encoding: "utf8")]
let imageCount = 10000
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
print("Could not read dataset file: \(name)")
exit(-1)
}
guard fileContents.count == 30_730_000 else {
print(
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
exit(-1)
}

var bytes: [UInt8] = []
var labels: [Int64] = []

let imageByteSize = 3073
for imageIndex in 0..<imageCount {
let baseAddress = imageIndex * imageByteSize
labels.append(Int64(fileContents[baseAddress]))
bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
}

let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
let images = Tensor<UInt8>(numpy: bytes)!
let imageCount = images.shape[0]
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)

// reshape and transpose from the provided N(CHW) to TF default NHWC
let imageTensor = Tensor<Float>(images
.reshaped(to: [imageCount, 3, 32, 32])
.transposed(withPermutations: [0, 2, 3, 1]))
// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))

let mean = Tensor<Float>([0.485, 0.456, 0.406])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let imagesNormalized = ((imageTensor / 255.0) - mean) / std

return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
return Example(
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
)
}

func loadCIFARTestFile() -> Example {
return loadCIFARFile(named: "test_batch")
return loadCIFARFile(named: "test_batch.bin")
}

func loadCIFAR10() -> (training: Dataset<Example>, test: Dataset<Example>) {
func loadCIFAR10() -> (
training: Dataset<Example>, test: Dataset<Example>
) {
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
return (training: trainingDataset, test: testDataset)
Expand Down
9 changes: 2 additions & 7 deletions ResNet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@ dataset.
## Setup

You'll need [the latest version][INSTALL] of Swift for TensorFlow
installed and added to your path. Additionally, the data loader requires Python
3.x (rather than Python 2.7), `wget`, and `numpy`.

> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
> TensorFlow find the `libpython3.<minor-version>.dylib` file, e.g., in
> `homebrew`.
installed and added to your path.

To train the model on CIFAR-10, run:

```
cd swift-models
swift run ResNet
swift run -c release ResNet
```

[INSTALL]: (https://github.com/tensorflow/swift/blob/master/Installation.md)
Loading