Skip to content

Commit 31d744a

Browse files
committed
revise netcdf file generation utility program
1 parent 6e08b77 commit 31d744a

File tree

1 file changed

+111
-65
lines changed

1 file changed

+111
-65
lines changed

examples/MNIST/create_mnist_netcdf.py

Lines changed: 111 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,154 @@
1-
import os
2-
import numpy as np
1+
#
2+
# Copyright (C) 2024, Northwestern University and Argonne National Laboratory
3+
# See COPYRIGHT notice in top-level directory.
4+
#
5+
6+
import os, argparse, struct
37
import numpy as np
4-
import pnetcdf
5-
from mpi4py import MPI
68
from array import array
7-
import struct
9+
10+
from mpi4py import MPI
11+
import pnetcdf
812

913
class MnistDataloader(object):
1014
def __init__(self, training_images_filepath,training_labels_filepath,
1115
test_images_filepath, test_labels_filepath):
16+
1217
self.training_images_filepath = training_images_filepath
1318
self.training_labels_filepath = training_labels_filepath
1419
self.test_images_filepath = test_images_filepath
1520
self.test_labels_filepath = test_labels_filepath
16-
17-
def read_images_labels(self, images_filepath, labels_filepath):
21+
22+
def read_images_labels(self, images_filepath, labels_filepath):
1823
labels = []
1924
with open(labels_filepath, 'rb') as file:
2025
magic, size = struct.unpack(">II", file.read(8))
2126
if magic != 2049:
2227
raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
23-
labels = array("B", file.read())
24-
28+
labels = array("B", file.read())
29+
2530
with open(images_filepath, 'rb') as file:
2631
magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
2732
if magic != 2051:
2833
raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
29-
image_data = array("B", file.read())
34+
image_data = array("B", file.read())
3035
images = []
3136
for i in range(size):
3237
images.append([0] * rows * cols)
3338
for i in range(size):
3439
img = np.array(image_data[i * rows * cols:(i + 1) * rows * cols])
3540
img = img.reshape(28, 28)
36-
images[i][:] = img
37-
41+
images[i][:] = img
42+
3843
return images, labels
39-
44+
4045
def load_data(self):
4146
x_train, y_train = self.read_images_labels(self.training_images_filepath, self.training_labels_filepath)
4247
x_test, y_test = self.read_images_labels(self.test_images_filepath, self.test_labels_filepath)
43-
return (x_train, y_train),(x_test, y_test)
44-
45-
#
46-
# Set file paths based on added MNIST Datasets
47-
#
48-
input_path = '.'
49-
training_images_filepath = os.path.join(input_path, 'train-images-idx3-ubyte/train-images-idx3-ubyte')
50-
training_labels_filepath = os.path.join(input_path, 'train-labels-idx1-ubyte/train-labels-idx1-ubyte')
51-
test_images_filepath = os.path.join(input_path, 't10k-images-idx3-ubyte/t10k-images-idx3-ubyte')
52-
test_labels_filepath = os.path.join(input_path, 't10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte')
53-
54-
#
55-
# Load MINST dataset
56-
#
57-
mnist_dataloader = MnistDataloader(training_images_filepath, training_labels_filepath, test_images_filepath, test_labels_filepath)
58-
(x_train, y_train), (x_test, y_test) = mnist_dataloader.load_data()
48+
return (x_train, y_train),(x_test, y_test)
5949

60-
# use partial dataset
61-
x_train_small = x_train[:60]
62-
y_train_small = y_train[:60]
63-
x_test_small = x_test[:12]
64-
y_test_small = y_test[:12]
6550

66-
def to_nc(train_samples, test_samples, train_labels, test_labels, comm, out_file_path='mnist_images.nc'):
51+
def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path='mnist_images.nc'):
6752
if os.path.exists(out_file_path):
6853
os.remove(out_file_path)
54+
6955
train_labels = list(train_labels)
7056
test_labels = list(test_labels)
71-
with pnetcdf.File(out_file_path, comm= comm, mode = "w", format = "64BIT_DATA") as fnc:
72-
73-
dim_y = fnc.def_dim("Y", 28)
74-
dim_x = fnc.def_dim("X", 28)
75-
dim_num_train = fnc.def_dim("train_idx", len(train_samples))
76-
dim_num_test = fnc.def_dim("test_idx", len(test_samples))
77-
78-
# define nc variable for all imgs
79-
v_train = fnc.def_var("train_images", pnetcdf.NC_UBYTE, (dim_num_train, dim_x, dim_y))
80-
# put labels into attributes
81-
v_label_train = fnc.def_var("train_labels", pnetcdf.NC_UBYTE, (dim_num_train, ))
82-
83-
# define nc variable for all imgs
84-
v_test = fnc.def_var("test_images", pnetcdf.NC_UBYTE, (dim_num_test, dim_x, dim_y))
85-
# put labels into attributes
86-
v_label_test = fnc.def_var("test_labels", pnetcdf.NC_UBYTE, (dim_num_test, ))
87-
88-
# put values into each nc variable
57+
58+
with pnetcdf.File(out_file_path, mode = "w", format = "NC_64BIT_DATA") as fnc:
59+
60+
# Each image is of dimension 28 x 28
61+
dim_y = fnc.def_dim("height", 28)
62+
dim_x = fnc.def_dim("width", 28)
63+
64+
# define number of traing and testing samples
65+
dim_train = fnc.def_dim("train_num", len(train_samples))
66+
dim_test = fnc.def_dim("test_num", len(test_samples))
67+
68+
# define nc variables to store training image samples and labels
69+
train_data = fnc.def_var("train_samples", pnetcdf.NC_UBYTE, (dim_train, dim_y, dim_x))
70+
train_data.long_name = "training data samples"
71+
train_label = fnc.def_var("train_labels", pnetcdf.NC_UBYTE, (dim_train))
72+
train_label.long_name = "labels of training samples"
73+
74+
# define nc variables to store testing image samples and labels
75+
test_data = fnc.def_var("test_samples", pnetcdf.NC_UBYTE, (dim_test, dim_y, dim_x))
76+
test_data.long_name = "testing data samples"
77+
test_label = fnc.def_var("test_labels", pnetcdf.NC_UBYTE, (dim_test))
78+
test_label.long_name = "labels of testing samples"
79+
80+
# exit define mode and enter data mode
8981
fnc.enddef()
90-
v_label_train[:] = np.array(train_labels, dtype = np.uint8)
82+
83+
# write training data samples
9184
for idx, img in enumerate(train_samples):
92-
v_train[idx, :, :] = img
93-
94-
v_label_test[:] = np.array(test_labels, dtype = np.uint8)
85+
train_data[idx, :, :] = img
86+
87+
# write labels of training data samples
88+
train_label[:] = np.array(train_labels, dtype = np.uint8)
89+
90+
# write testing data samples
9591
for idx, img in enumerate(test_samples):
96-
v_test[idx, :, :] = img
92+
test_data[idx, :, :] = img
93+
94+
# write labels of testing data samples
95+
test_label[:] = np.array(test_labels, dtype = np.uint8)
96+
97+
98+
if __name__ == '__main__':
99+
100+
# parse command-line arguments
101+
args = None
102+
parser = argparse.ArgumentParser(description='Store MNIST Datasets to a NetCDF file')
103+
parser.add_argument("--verbose", help="Verbose mode", action="store_true")
104+
parser.add_argument('--train-size', type=int, default=60, metavar='N',
105+
help='Number of training samples extracted from the input file (default: 60)')
106+
parser.add_argument('--test-size', type=int, default=12, metavar='N',
107+
help='Number of testing samples extracted from the input file (default: 12)')
108+
parser.add_argument("--train-data-file", nargs=1, type=str, help="(Optional) input file name of training data",\
109+
default = "train-images-idx3-ubyte")
110+
parser.add_argument("--train-label-file", nargs=1, type=str, help="(Optional) input file name of training labels",\
111+
default = "train-labels-idx1-ubyte")
112+
parser.add_argument("--test-data-file", nargs=1, type=str, help="(Optional) input file name of testing data",\
113+
default = "t10k-images-idx3-ubyte")
114+
parser.add_argument("--test-label-file", nargs=1, type=str, help="(Optional) input file name of testing labels",\
115+
default = "t10k-labels-idx1-ubyte")
116+
args = parser.parse_args()
117+
118+
verbose = True if args.verbose else False
119+
120+
if verbose:
121+
print("Input file of training samples: ", args.train_data_file)
122+
print("Input file of training labels: ", args.train_label_file)
123+
print("Input file of testing samples: ", args.test_data_file)
124+
print("Input file of testing labels: ", args.test_label_file)
125+
126+
#
127+
# Load MINST dataset
128+
#
129+
mnist_dataloader = MnistDataloader(args.train_data_file,
130+
args.train_label_file,
131+
args.test_data_file,
132+
args.test_label_file)
133+
134+
(train_data, train_label), (test_data, test_label) = mnist_dataloader.load_data()
135+
136+
n_train = len(train_data)
137+
if args.train_size > 0 and args.train_size < n_train:
138+
n_train = int(args.train_size)
97139

98-
comm = MPI.COMM_WORLD
99-
rank = comm.Get_rank()
100-
size = comm.Get_size()
140+
n_test = len(test_data)
141+
if args.test_size > 0 and args.test_size < n_test:
142+
n_test = int(args.test_size)
101143

102-
# create mini MNIST file
103-
to_nc(x_train_small, x_test_small, y_train_small, y_test_small, comm, "mnist_images_mini.nc")
144+
if verbose:
145+
print("Number of training samples: ", n_train)
146+
print("Number of testing samples: ", n_test)
104147

105-
# create MNIST file
106-
# to_nc(x_train, x_test, y_train, y_test, comm, "mnist_images.nc")
148+
#
149+
# create mini MNIST file in NetCDF format
150+
#
151+
to_nc(train_data[0:n_train], train_label[0:n_train],
152+
test_data[0:n_test], test_label[0:n_test], "mnist_images.nc")
107153

108154

0 commit comments

Comments
 (0)