@@ -31,21 +31,27 @@ public struct MNIST {
31
31
32
32
public let batchSize : Int
33
33
34
- public init ( batchSize: Int , flattening: Bool = false , normalizing: Bool = false ) {
34
+ public init (
35
+ batchSize: Int , flattening: Bool = false , normalizing: Bool = false ,
36
+ localStorageDirectory: URL = DatasetUtilities . curentWorkingDirectoryURL
37
+ ) {
35
38
self . batchSize = batchSize
36
39
37
- let ( trainingImages, trainingLabels) = readMNIST (
38
- imagesFile: " train-images-idx3-ubyte " ,
39
- labelsFile: " train-labels-idx1-ubyte " ,
40
+ let ( trainingImages, trainingLabels) = fetchDataset (
41
+ localStorageDirectory: localStorageDirectory,
42
+ imagesFilename: " train-images-idx3-ubyte " ,
43
+ labelsFilename: " train-labels-idx1-ubyte " ,
40
44
flattening: flattening,
41
45
normalizing: normalizing)
46
+
42
47
self . trainingImages = trainingImages
43
48
self . trainingLabels = trainingLabels
44
49
self . trainingSize = Int ( trainingLabels. shape [ 0 ] )
45
50
46
- let ( testImages, testLabels) = readMNIST (
47
- imagesFile: " t10k-images-idx3-ubyte " ,
48
- labelsFile: " t10k-labels-idx1-ubyte " ,
51
+ let ( testImages, testLabels) = fetchDataset (
52
+ localStorageDirectory: localStorageDirectory,
53
+ imagesFilename: " t10k-images-idx3-ubyte " ,
54
+ labelsFilename: " t10k-labels-idx1-ubyte " ,
49
55
flattening: flattening,
50
56
normalizing: normalizing)
51
57
self . testImages = testImages
@@ -61,36 +67,31 @@ extension Tensor {
61
67
}
62
68
}
63
69
64
- /// Reads a file into an array of bytes.
65
- func readFile( _ path: String , possibleDirectories: [ String ] ) -> [ UInt8 ] {
66
- for folder in possibleDirectories {
67
- let parent = URL ( fileURLWithPath: folder)
68
- let filePath = parent. appendingPathComponent ( path)
69
- guard FileManager . default. fileExists ( atPath: filePath. path) else {
70
- continue
71
- }
72
- let data = try ! Data ( contentsOf: filePath, options: [ ] )
73
- return [ UInt8] ( data)
70
+ fileprivate func fetchDataset(
71
+ localStorageDirectory: URL ,
72
+ imagesFilename: String ,
73
+ labelsFilename: String ,
74
+ flattening: Bool ,
75
+ normalizing: Bool
76
+ ) -> ( images: Tensor < Float > , labels: Tensor < Int32 > ) {
77
+ guard let remoteRoot: URL = URL ( string: " http://yann.lecun.com/exdb/mnist " ) else {
78
+ fatalError ( " Failed to create MNST root url: http://yann.lecun.com/exdb/mnist " )
74
79
}
75
- print ( " File not found: \( path) " )
76
- exit ( - 1 )
77
- }
78
80
79
- /// Reads MNIST images and labels from specified file paths.
80
- func readMNIST( imagesFile: String , labelsFile: String , flattening: Bool , normalizing: Bool ) -> (
81
- images: Tensor < Float > ,
82
- labels: Tensor < Int32 >
83
- ) {
84
- print ( " Reading data from files: \( imagesFile) , \( labelsFile) . " )
85
- let images = readFile ( imagesFile, possibleDirectories: [ " . " , " ./Datasets/MNIST " ] ) . dropFirst ( 16 )
86
- . map ( Float . init)
87
- let labels = readFile ( labelsFile, possibleDirectories: [ " . " , " ./Datasets/MNIST " ] ) . dropFirst ( 8 )
88
- . map ( Int32 . init)
89
- let rowCount = labels. count
90
- let imageHeight = 28
91
- let imageWidth = 28
81
+ let imagesData = DatasetUtilities . fetchResource (
82
+ filename: imagesFilename,
83
+ remoteRoot: remoteRoot,
84
+ localStorageDirectory: localStorageDirectory)
85
+ let labelsData = DatasetUtilities . fetchResource (
86
+ filename: labelsFilename,
87
+ remoteRoot: remoteRoot,
88
+ localStorageDirectory: localStorageDirectory)
89
+
90
+ let images = [ UInt8] ( imagesData) . dropFirst ( 16 ) . map ( Float . init)
91
+ let labels = [ UInt8] ( labelsData) . dropFirst ( 8 ) . map ( Int32 . init)
92
92
93
- print ( " Constructing data tensors. " )
93
+ let rowCount = labels. count
94
+ let ( imageWidth, imageHeight) = ( 28 , 28 )
94
95
95
96
if flattening {
96
97
var flattenedImages = Tensor ( shape: [ rowCount, imageHeight * imageWidth] , scalars: images)
@@ -101,8 +102,9 @@ func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normali
101
102
return ( images: flattenedImages, labels: Tensor ( labels) )
102
103
} else {
103
104
return (
104
- images: Tensor ( shape: [ rowCount, 1 , imageHeight, imageWidth] , scalars: images)
105
- . transposed ( withPermutations: [ 0 , 2 , 3 , 1 ] ) / 255 , // NHWC
105
+ images:
106
+ Tensor ( shape: [ rowCount, 1 , imageHeight, imageWidth] , scalars: images)
107
+ . transposed ( withPermutations: [ 0 , 2 , 3 , 1 ] ) / 255 , // NHWC
106
108
labels: Tensor ( labels)
107
109
)
108
110
}
0 commit comments