|
1 |
| -import os |
2 |
| -import numpy as np |
| 1 | +# |
| 2 | +# Copyright (C) 2024, Northwestern University and Argonne National Laboratory |
| 3 | +# See COPYRIGHT notice in top-level directory. |
| 4 | +# |
| 5 | + |
| 6 | +import os, argparse, struct |
3 | 7 | import numpy as np
|
4 |
| -import pnetcdf |
5 |
| -from mpi4py import MPI |
6 | 8 | from array import array
|
7 |
| -import struct |
| 9 | + |
| 10 | +from mpi4py import MPI |
| 11 | +import pnetcdf |
8 | 12 |
|
class MnistDataloader(object):
    """Read MNIST training and test sets from IDX-format files.

    The four constructor arguments are the paths of the training images,
    training labels, test images and test labels files, in that order.
    """

    def __init__(self, training_images_filepath, training_labels_filepath,
                 test_images_filepath, test_labels_filepath):
        self.training_images_filepath = training_images_filepath
        self.training_labels_filepath = training_labels_filepath
        self.test_images_filepath = test_images_filepath
        self.test_labels_filepath = test_labels_filepath

    def read_images_labels(self, images_filepath, labels_filepath):
        """Read one image file and its matching label file.

        Both files start with big-endian unsigned 32-bit header fields:
        the label file holds (magic=2049, count), the image file holds
        (magic=2051, count, rows, cols).  The remaining bytes are the
        unsigned 8-bit labels / pixels.

        Args:
            images_filepath: path of the image file.
            labels_filepath: path of the label file.

        Returns:
            A tuple ``(images, labels)``.  ``images`` is a list with one
            ``rows x cols`` uint8 array per sample; ``labels`` is an
            ``array('B')`` of the label bytes.

        Raises:
            ValueError: if a magic number is wrong, if the two files declare
                different sample counts, or if the image file holds fewer
                pixels than its header declares.
        """
        with open(labels_filepath, 'rb') as file:
            magic, label_count = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
            labels = array("B", file.read())

        with open(images_filepath, 'rb') as file:
            magic, image_count, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
            image_data = array("B", file.read())

        # The two headers are read independently; refuse mismatched files
        # instead of silently pairing images with the wrong labels.
        if label_count != image_count:
            raise ValueError('Sample count mismatch, {} labels but {} images'
                             .format(label_count, image_count))

        n_pixels = image_count * rows * cols
        if len(image_data) < n_pixels:
            raise ValueError('Image file is truncated, expected {} pixels, got {}'
                             .format(n_pixels, len(image_data)))

        # Use the image size declared in the header rather than assuming
        # 28 x 28, and reshape all samples in a single step.
        pixels = np.array(image_data[:n_pixels], dtype=np.uint8)
        images = list(pixels.reshape(image_count, rows, cols))

        return images, labels

    def load_data(self):
        """Load both data sets.

        Returns:
            ``((x_train, y_train), (x_test, y_test))`` where each pair is
            the result of ``read_images_labels`` on the corresponding files.
        """
        x_train, y_train = self.read_images_labels(self.training_images_filepath,
                                                   self.training_labels_filepath)
        x_test, y_test = self.read_images_labels(self.test_images_filepath,
                                                 self.test_labels_filepath)
        return (x_train, y_train), (x_test, y_test)
59 | 49 |
|
60 |
| -# use partial dataset |
61 |
| -x_train_small = x_train[:60] |
62 |
| -y_train_small = y_train[:60] |
63 |
| -x_test_small = x_test[:12] |
64 |
| -y_test_small = y_test[:12] |
65 | 50 |
|
66 |
def _sample_shape(*sample_sets, default=(28, 28)):
    """Return (height, width) of the first available sample, else *default*."""
    for samples in sample_sets:
        if len(samples) > 0:
            height, width = np.asarray(samples[0]).shape
            return int(height), int(width)
    return default


def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path='mnist_images.nc'):
    """Write image samples and their labels to a new NetCDF file.

    The file gets four NC_UBYTE variables: ``train_samples`` and
    ``test_samples`` with dimensions (sample, height, width), and
    ``train_labels`` and ``test_labels`` with one value per sample.

    Args:
        train_samples: sequence of 2-D training images.
        train_labels: iterable of training labels, one per image.
        test_samples: sequence of 2-D testing images.
        test_labels: iterable of testing labels, one per image.
        out_file_path: path of the output file; an existing file at this
            path is removed first.
    """
    # EAFP removal: no window between an existence check and the delete.
    try:
        os.remove(out_file_path)
    except FileNotFoundError:
        pass

    train_labels = list(train_labels)
    test_labels = list(test_labels)

    # Take the image size from the data; fall back to 28 x 28 when both
    # sample sets are empty.
    height, width = _sample_shape(train_samples, test_samples)

    with pnetcdf.File(out_file_path, mode = "w", format = "NC_64BIT_DATA") as fnc:

        # dimensions of a single image
        dim_y = fnc.def_dim("height", height)
        dim_x = fnc.def_dim("width", width)

        # define number of training and testing samples
        dim_train = fnc.def_dim("train_num", len(train_samples))
        dim_test = fnc.def_dim("test_num", len(test_samples))

        # define nc variables to store training image samples and labels
        train_data = fnc.def_var("train_samples", pnetcdf.NC_UBYTE, (dim_train, dim_y, dim_x))
        train_data.long_name = "training data samples"
        # note the trailing comma: a one-element tuple of dimensions
        train_label = fnc.def_var("train_labels", pnetcdf.NC_UBYTE, (dim_train,))
        train_label.long_name = "labels of training samples"

        # define nc variables to store testing image samples and labels
        test_data = fnc.def_var("test_samples", pnetcdf.NC_UBYTE, (dim_test, dim_y, dim_x))
        test_data.long_name = "testing data samples"
        test_label = fnc.def_var("test_labels", pnetcdf.NC_UBYTE, (dim_test,))
        test_label.long_name = "labels of testing samples"

        # exit define mode and enter data mode
        fnc.enddef()

        # write training data samples
        for idx, img in enumerate(train_samples):
            train_data[idx, :, :] = img

        # write labels of training data samples
        train_label[:] = np.array(train_labels, dtype = np.uint8)

        # write testing data samples
        for idx, img in enumerate(test_samples):
            test_data[idx, :, :] = img

        # write labels of testing data samples
        test_label[:] = np.array(test_labels, dtype = np.uint8)
| 96 | + |
| 97 | + |
if __name__ == '__main__':

    # parse command-line arguments
    parser = argparse.ArgumentParser(description='Store MNIST Datasets to a NetCDF file')
    parser.add_argument("--verbose", help="Verbose mode", action="store_true")
    parser.add_argument('--train-size', type=int, default=60, metavar='N',
                        help='Number of training samples extracted from the input file (default: 60)')
    parser.add_argument('--test-size', type=int, default=12, metavar='N',
                        help='Number of testing samples extracted from the input file (default: 12)')
    # The file options take a single plain string.  With nargs=1 argparse
    # would store a one-element list, which open() cannot accept.
    parser.add_argument("--train-data-file", type=str, help="(Optional) input file name of training data",
                        default="train-images-idx3-ubyte")
    parser.add_argument("--train-label-file", type=str, help="(Optional) input file name of training labels",
                        default="train-labels-idx1-ubyte")
    parser.add_argument("--test-data-file", type=str, help="(Optional) input file name of testing data",
                        default="t10k-images-idx3-ubyte")
    parser.add_argument("--test-label-file", type=str, help="(Optional) input file name of testing labels",
                        default="t10k-labels-idx1-ubyte")
    args = parser.parse_args()

    verbose = args.verbose

    if verbose:
        print("Input file of training samples: ", args.train_data_file)
        print("Input file of training labels: ", args.train_label_file)
        print("Input file of testing samples: ", args.test_data_file)
        print("Input file of testing labels: ", args.test_label_file)

    #
    # Load MNIST dataset
    #
    mnist_dataloader = MnistDataloader(args.train_data_file,
                                       args.train_label_file,
                                       args.test_data_file,
                                       args.test_label_file)

    (train_data, train_label), (test_data, test_label) = mnist_dataloader.load_data()

    # A non-positive or too-large requested size means "use every sample".
    n_train = len(train_data)
    if 0 < args.train_size < n_train:
        n_train = args.train_size

    n_test = len(test_data)
    if 0 < args.test_size < n_test:
        n_test = args.test_size

    if verbose:
        print("Number of training samples: ", n_train)
        print("Number of testing samples: ", n_test)

    #
    # create mini MNIST file in NetCDF format
    #
    to_nc(train_data[0:n_train], train_label[0:n_train],
          test_data[0:n_test], test_label[0:n_test], "mnist_images.nc")
107 | 153 |
|
108 | 154 |
|
0 commit comments