1
- // Copyright 2018 The TensorFlow Authors. All Rights Reserved.
1
+ // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
4
// you may not use this file except in compliance with the License.
12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
- import Python
15
+ import Foundation
16
16
import TensorFlow
17
17
18
- /// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
19
- /// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
18
+ #if canImport(FoundationNetworking)
19
+ import FoundationNetworking
20
+ #endif
21
+
20
22
/// Downloads and unpacks the binary CIFAR-10 archive into `directory`, unless a
/// `cifar-10-batches-bin` entry already exists there.
///
/// The archive `cifar-10-binary.tar.gz` is fetched only when it is not already present in
/// `directory`. It is then unpacked with the system `tar` and deleted. Every failure is reported
/// on standard output and terminates the process with `exit(-1)`.
///
/// - Parameter directory: Folder that receives the archive and the extracted dataset.
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
    let downloadPath = "\(directory)/cifar-10-batches-bin"
    let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

    guard !directoryExists else { return }

    print("Downloading CIFAR dataset...")
    let archivePath = "\(directory)/cifar-10-binary.tar.gz"
    let archiveExists = FileManager.default.fileExists(atPath: archivePath)
    if !archiveExists {
        print("Archive missing, downloading...")
        do {
            // The URL is a compile-time literal, so a failed parse would be a programmer error.
            let downloadedFile = try Data(
                contentsOf: URL(
                    string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
            // Write atomically so an interrupted run cannot leave a truncated archive that a
            // later run would mistake for a finished download.
            try downloadedFile.write(to: URL(fileURLWithPath: archivePath), options: .atomic)
        } catch {
            print("Could not download CIFAR dataset, error: \(error)")
            exit(-1)
        }
    }

    print("Archive downloaded, processing...")

    #if os(macOS)
        let tarLocation = "/usr/bin/tar"
    #else
        let tarLocation = "/bin/tar"
    #endif

    let task = Process()
    task.executableURL = URL(fileURLWithPath: tarLocation)
    // `-C` makes tar unpack into `directory`; without it the files land in the process's
    // current working directory and `downloadPath` is never created for non-default directories.
    task.arguments = ["xzf", archivePath, "-C", directory]
    do {
        try task.run()
        task.waitUntilExit()
    } catch {
        print("CIFAR extraction failed with error: \(error)")
        exit(-1)
    }

    // `run()` only throws when the process cannot be launched; a tar failure shows up here.
    guard task.terminationStatus == 0 else {
        print("CIFAR extraction failed with exit status: \(task.terminationStatus)")
        exit(-1)
    }

    do {
        try FileManager.default.removeItem(atPath: archivePath)
    } catch {
        print("Could not remove archive, error: \(error)")
        exit(-1)
    }

    print("Unarchiving completed")
}
37
71
38
72
struct Example : TensorGroup {
@@ -53,52 +87,61 @@ struct Example: TensorGroup {
53
87
label = Tensor < Int32 > ( handle: TensorHandle < Int32 > ( handle: _handles [ labelIndex] ) )
54
88
data = Tensor < Float > ( handle: TensorHandle < Float > ( handle: _handles [ dataIndex] ) )
55
89
}
56
-
57
- public var _tensorHandles : [ _AnyTensorHandle ] { [ label. _tfeTensorHandle, data. _tfeTensorHandle] }
58
90
}
59
91
60
- // Each CIFAR data file is provided as a Python pickle of NumPy arrays
61
92
/// Reads one CIFAR-10 binary batch file and converts it into a normalized `Example`.
///
/// `downloadCIFAR10IfNotPresent(to:)` is called first. The file is then read as 10000
/// fixed-size records, each made of one label byte followed by 3 x 32 x 32 image bytes. Images
/// are transposed from N(CHW) to NHWC, scaled by 1/255 and normalized with the per-channel
/// `mean` and `std` below. An unreadable file or an unexpected file size is reported on standard
/// output and terminates the process with `exit(-1)`.
///
/// - Parameters:
///   - name: File name inside `cifar-10-batches-bin`, e.g. `data_batch_1.bin`.
///   - directory: Folder that contains (or will receive) `cifar-10-batches-bin`.
/// - Returns: An `Example` with `Int32` labels and normalized `Float` image data.
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
    downloadCIFAR10IfNotPresent(to: directory)
    let path = "\(directory)/cifar-10-batches-bin/\(name)"

    // Record layout: every size used below is derived from these values.
    let imageCount = 10000
    let channelCount = 3
    let imageSide = 32
    let pixelByteCount = channelCount * imageSide * imageSide
    let imageByteSize = 1 + pixelByteCount  // One label byte precedes the pixel data.
    let expectedByteCount = imageCount * imageByteSize

    guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
        print("Could not read dataset file: \(name)")
        exit(-1)
    }
    guard fileContents.count == expectedByteCount else {
        print(
            "Dataset file \(name) should have \(expectedByteCount) bytes, instead had \(fileContents.count)")
        exit(-1)
    }

    // Final sizes are known up front, so reserve once instead of regrowing inside the loop.
    var bytes: [UInt8] = []
    bytes.reserveCapacity(imageCount * pixelByteCount)
    var labels: [Int64] = []
    labels.reserveCapacity(imageCount)

    for imageIndex in 0..<imageCount {
        let baseAddress = imageIndex * imageByteSize
        labels.append(Int64(fileContents[baseAddress]))
        bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + imageByteSize)])
    }

    let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
    let images = Tensor<UInt8>(
        shape: [imageCount, channelCount, imageSide, imageSide], scalars: bytes)

    // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
    let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))

    let mean = Tensor<Float>([0.485, 0.456, 0.406])
    let std = Tensor<Float>([0.229, 0.224, 0.225])
    let imagesNormalized = ((imageTensor / 255.0) - mean) / std

    return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}
87
129
88
130
/// Loads the five CIFAR-10 training batches and joins them into a single `Example`.
///
/// Files `data_batch_1.bin` through `data_batch_5.bin` are read with `loadCIFARFile(named:)`
/// using its default directory. Labels and image data are each concatenated along dimension 0,
/// in batch order.
func loadCIFARTrainingFiles() -> Example {
    var batches: [Example] = []
    for batchNumber in 1...5 {
        batches.append(loadCIFARFile(named: "data_batch_\(batchNumber).bin"))
    }

    let allLabels = Raw.concat(
        concatDim: Tensor<Int32>(0), batches.map { batch in batch.label })
    let allImages = Raw.concat(
        concatDim: Tensor<Int32>(0), batches.map { batch in batch.data })

    return Example(label: allLabels, data: allImages)
}
95
137
96
138
/// Loads the CIFAR-10 test batch (`test_batch.bin`) via `loadCIFARFile(named:)`, using that
/// function's default directory.
func loadCIFARTestFile() -> Example {
    let testBatch = loadCIFARFile(named: "test_batch.bin")
    return testBatch
}
99
141
100
142
func loadCIFAR10( ) -> (
101
- training: Dataset < Example > , test: Dataset < Example > ) {
143
+ training: Dataset < Example > , test: Dataset < Example >
144
+ ) {
102
145
let trainingDataset = Dataset < Example > ( elements: loadCIFARTrainingFiles ( ) )
103
146
let testDataset = Dataset < Example > ( elements: loadCIFARTestFile ( ) )
104
147
return ( training: trainingDataset, test: testDataset)
0 commit comments