Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 08c80a5

Browse files
authored
Replaced Python with Swift in CIFAR10 dataset loading (#178)
* Replaced all Python code in the CIFAR10 and ResNet examples, removing Python 3 dependency. * Needed to import FoundationNetworking on Linux. * Added a check for FoundationNetworking, added an early exit for cached directory check. * Removed macOS availability check by targeting 10.13 in the package. * Style and formatting fixes. * Removed no-longer-needed _tensorHandles and supporting code.
1 parent 2fa11ba commit 08c80a5

File tree

8 files changed

+168
-94
lines changed

8 files changed

+168
-94
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
.DS_Store
88
.swiftpm
99
cifar-10-batches-py/
10+
cifar-10-batches-bin/

CIFAR/Data.swift

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,27 +12,61 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import Python
15+
import Foundation
1616
import TensorFlow
1717

18-
/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
19-
/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
18+
#if canImport(FoundationNetworking)
19+
import FoundationNetworking
20+
#endif
21+
2022
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
21-
let subprocess = Python.import("subprocess")
22-
let path = Python.import("os.path")
23-
let filepath = "\(directory)/cifar-10-batches-py"
24-
let isdir = Bool(path.isdir(filepath))!
25-
if !isdir {
26-
print("Downloading CIFAR data...")
27-
let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)"
28-
subprocess.call(command, shell: true)
23+
let downloadPath = "\(directory)/cifar-10-batches-bin"
24+
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
25+
26+
guard !directoryExists else { return }
27+
28+
print("Downloading CIFAR dataset...")
29+
let archivePath = "\(directory)/cifar-10-binary.tar.gz"
30+
let archiveExists = FileManager.default.fileExists(atPath: archivePath)
31+
if !archiveExists {
32+
print("Archive missing, downloading...")
33+
do {
34+
let downloadedFile = try Data(
35+
contentsOf: URL(
36+
string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
37+
try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
38+
} catch {
39+
print("Could not download CIFAR dataset, error: \(error)")
40+
exit(-1)
41+
}
42+
}
43+
44+
print("Archive downloaded, processing...")
45+
46+
#if os(macOS)
47+
let tarLocation = "/usr/bin/tar"
48+
#else
49+
let tarLocation = "/bin/tar"
50+
#endif
51+
52+
let task = Process()
53+
task.executableURL = URL(fileURLWithPath: tarLocation)
54+
task.arguments = ["xzf", archivePath]
55+
do {
56+
try task.run()
57+
task.waitUntilExit()
58+
} catch {
59+
print("CIFAR extraction failed with error: \(error)")
2960
}
30-
}
3161

32-
extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
33-
public var _tfeTensorHandle: _AnyTensorHandle {
34-
TFETensorHandle(_owning: handle._cTensorHandle)
62+
do {
63+
try FileManager.default.removeItem(atPath: archivePath)
64+
} catch {
65+
print("Could not remove archive, error: \(error)")
66+
exit(-1)
3567
}
68+
69+
print("Unarchiving completed")
3670
}
3771

3872
struct Example: TensorGroup {
@@ -53,52 +87,61 @@ struct Example: TensorGroup {
5387
label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
5488
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
5589
}
56-
57-
public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
5890
}
5991

60-
// Each CIFAR data file is provided as a Python pickle of NumPy arrays
6192
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
6293
downloadCIFAR10IfNotPresent(to: directory)
63-
let np = Python.import("numpy")
64-
let pickle = Python.import("pickle")
65-
let path = "\(directory)/cifar-10-batches-py/\(name)"
66-
let f = Python.open(path, "rb")
67-
let res = pickle.load(f, encoding: "bytes")
94+
let path = "\(directory)/cifar-10-batches-bin/\(name)"
6895

69-
let bytes = res[Python.bytes("data", encoding: "utf8")]
70-
let labels = res[Python.bytes("labels", encoding: "utf8")]
96+
let imageCount = 10000
97+
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
98+
print("Could not read dataset file: \(name)")
99+
exit(-1)
100+
}
101+
guard fileContents.count == 30_730_000 else {
102+
print(
103+
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
104+
exit(-1)
105+
}
106+
107+
var bytes: [UInt8] = []
108+
var labels: [Int64] = []
109+
110+
let imageByteSize = 3073
111+
for imageIndex in 0..<imageCount {
112+
let baseAddress = imageIndex * imageByteSize
113+
labels.append(Int64(fileContents[baseAddress]))
114+
bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
115+
}
71116

72-
let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
73-
let images = Tensor<UInt8>(numpy: bytes)!
74-
let imageCount = images.shape[0]
117+
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
118+
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)
75119

76-
// reshape and transpose from the provided N(CHW) to TF default NHWC
77-
let imageTensor = Tensor<Float>(images
78-
.reshaped(to: [imageCount, 3, 32, 32])
79-
.transposed(withPermutations: [0, 2, 3, 1]))
120+
// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
121+
let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))
80122

81123
let mean = Tensor<Float>([0.485, 0.456, 0.406])
82-
let std = Tensor<Float>([0.229, 0.224, 0.225])
124+
let std = Tensor<Float>([0.229, 0.224, 0.225])
83125
let imagesNormalized = ((imageTensor / 255.0) - mean) / std
84126

85127
return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
86128
}
87129

88130
func loadCIFARTrainingFiles() -> Example {
89-
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
131+
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
90132
return Example(
91133
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
92134
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
93135
)
94136
}
95137

96138
func loadCIFARTestFile() -> Example {
97-
return loadCIFARFile(named: "test_batch")
139+
return loadCIFARFile(named: "test_batch.bin")
98140
}
99141

100142
func loadCIFAR10() -> (
101-
training: Dataset<Example>, test: Dataset<Example>) {
143+
training: Dataset<Example>, test: Dataset<Example>
144+
) {
102145
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
103146
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
104147
return (training: trainingDataset, test: testDataset)

CIFAR/README.md

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@ classification on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) da
66
## Setup
77

88
You'll need [the latest version][INSTALL] of Swift for TensorFlow
9-
installed and added to your path. Additionally, the data loader requires Python
10-
3.x (rather than Python 2.7), `wget`, and `numpy`.
11-
12-
> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
13-
> TensorFlow find the `libpython3.<minor-version>.dylib` file, e.g., in
14-
> `homebrew`.
9+
installed and added to your path.
1510

1611
To train the default model, run:
1712

CIFAR/main.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
// limitations under the License.
1414

1515
import TensorFlow
16-
import Python
17-
PythonLibrary.useVersion(3)
1816

1917
let batchSize = 100
2018

Package.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
// swift-tools-version:4.2
1+
// swift-tools-version:5.0
22
// The swift-tools-version declares the minimum version of Swift required to build this package.
33

44
import PackageDescription
55

66
let package = Package(
77
name: "TensorFlowModels",
8+
platforms: [
9+
.macOS(.v10_13)
10+
],
811
products: [
912
.executable(name: "MNIST", targets: ["MNIST"]),
1013
.executable(name: "CIFAR", targets: ["CIFAR"]),

ResNet/Data.swift

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,61 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import Python
15+
import Foundation
1616
import TensorFlow
1717

18-
/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
19-
/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
18+
#if canImport(FoundationNetworking)
19+
import FoundationNetworking
20+
#endif
21+
2022
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
21-
let subprocess = Python.import("subprocess")
22-
let path = Python.import("os.path")
23-
let filepath = "\(directory)/cifar-10-batches-py"
24-
let isdir = Bool(path.isdir(filepath))!
25-
if !isdir {
26-
print("Downloading CIFAR data...")
27-
let command = """
28-
wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - \
29-
-C \(directory)
30-
"""
31-
subprocess.call(command, shell: true)
23+
let downloadPath = "\(directory)/cifar-10-batches-bin"
24+
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
25+
26+
guard !directoryExists else { return }
27+
28+
print("Downloading CIFAR dataset...")
29+
let archivePath = "\(directory)/cifar-10-binary.tar.gz"
30+
let archiveExists = FileManager.default.fileExists(atPath: archivePath)
31+
if !archiveExists {
32+
print("Archive missing, downloading...")
33+
do {
34+
let downloadedFile = try Data(
35+
contentsOf: URL(
36+
string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
37+
try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
38+
} catch {
39+
print("Could not download CIFAR dataset, error: \(error)")
40+
exit(-1)
41+
}
42+
}
43+
44+
print("Archive downloaded, processing...")
45+
46+
#if os(macOS)
47+
let tarLocation = "/usr/bin/tar"
48+
#else
49+
let tarLocation = "/bin/tar"
50+
#endif
51+
52+
let task = Process()
53+
task.executableURL = URL(fileURLWithPath: tarLocation)
54+
task.arguments = ["xzf", archivePath]
55+
do {
56+
try task.run()
57+
task.waitUntilExit()
58+
} catch {
59+
print("CIFAR extraction failed with error: \(error)")
3260
}
33-
}
3461

35-
extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
36-
public var _tfeTensorHandle: _AnyTensorHandle {
37-
TFETensorHandle(_owning: handle._cTensorHandle)
62+
do {
63+
try FileManager.default.removeItem(atPath: archivePath)
64+
} catch {
65+
print("Could not remove archive, error: \(error)")
66+
exit(-1)
3867
}
68+
69+
print("Unarchiving completed")
3970
}
4071

4172
struct Example: TensorGroup {
@@ -56,51 +87,61 @@ struct Example: TensorGroup {
5687
label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
5788
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
5889
}
59-
60-
public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
6190
}
6291

63-
// Each CIFAR data file is provided as a Python pickle of NumPy arrays
6492
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
6593
downloadCIFAR10IfNotPresent(to: directory)
66-
let np = Python.import("numpy")
67-
let pickle = Python.import("pickle")
68-
let path = "\(directory)/cifar-10-batches-py/\(name)"
69-
let f = Python.open(path, "rb")
70-
let res = pickle.load(f, encoding: "bytes")
94+
let path = "\(directory)/cifar-10-batches-bin/\(name)"
7195

72-
let bytes = res[Python.bytes("data", encoding: "utf8")]
73-
let labels = res[Python.bytes("labels", encoding: "utf8")]
96+
let imageCount = 10000
97+
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
98+
print("Could not read dataset file: \(name)")
99+
exit(-1)
100+
}
101+
guard fileContents.count == 30_730_000 else {
102+
print(
103+
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
104+
exit(-1)
105+
}
106+
107+
var bytes: [UInt8] = []
108+
var labels: [Int64] = []
109+
110+
let imageByteSize = 3073
111+
for imageIndex in 0..<imageCount {
112+
let baseAddress = imageIndex * imageByteSize
113+
labels.append(Int64(fileContents[baseAddress]))
114+
bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
115+
}
74116

75-
let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
76-
let images = Tensor<UInt8>(numpy: bytes)!
77-
let imageCount = images.shape[0]
117+
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
118+
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)
78119

79-
// reshape and transpose from the provided N(CHW) to TF default NHWC
80-
let imageTensor = Tensor<Float>(images
81-
.reshaped(to: [imageCount, 3, 32, 32])
82-
.transposed(withPermutations: [0, 2, 3, 1]))
120+
// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
121+
let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))
83122

84123
let mean = Tensor<Float>([0.485, 0.456, 0.406])
85-
let std = Tensor<Float>([0.229, 0.224, 0.225])
124+
let std = Tensor<Float>([0.229, 0.224, 0.225])
86125
let imagesNormalized = ((imageTensor / 255.0) - mean) / std
87126

88127
return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
89128
}
90129

91130
func loadCIFARTrainingFiles() -> Example {
92-
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
131+
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
93132
return Example(
94133
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
95134
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
96135
)
97136
}
98137

99138
func loadCIFARTestFile() -> Example {
100-
return loadCIFARFile(named: "test_batch")
139+
return loadCIFARFile(named: "test_batch.bin")
101140
}
102141

103-
func loadCIFAR10() -> (training: Dataset<Example>, test: Dataset<Example>) {
142+
func loadCIFAR10() -> (
143+
training: Dataset<Example>, test: Dataset<Example>
144+
) {
104145
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
105146
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
106147
return (training: trainingDataset, test: testDataset)

ResNet/README.md

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,13 @@ dataset.
77
## Setup
88

99
You'll need [the latest version][INSTALL] of Swift for TensorFlow
10-
installed and added to your path. Additionally, the data loader requires Python
11-
3.x (rather than Python 2.7), `wget`, and `numpy`.
12-
13-
> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
14-
> TensorFlow find the `libpython3.<minor-version>.dylib` file, e.g., in
15-
> `homebrew`.
10+
installed and added to your path.
1611

1712
To train the model on CIFAR-10, run:
1813

1914
```
2015
cd swift-models
21-
swift run ResNet
16+
swift run -c release ResNet
2217
```
2318

2419
[INSTALL]: (https://github.com/tensorflow/swift/blob/master/Installation.md)

0 commit comments

Comments
 (0)