Skip to content

Commit bff8af7

Browse files
committed
add code for generating MNIST netcdf file
1 parent 1e28572 commit bff8af7

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

examples/MNIST/create_mnist_netcdf.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
import numpy as np
3+
import numpy as np
4+
import pnetcdf
5+
from mpi4py import MPI
6+
from array import array
7+
import struct
8+
9+
class MnistDataloader(object):
    """Reader for the four MNIST files stored in the IDX (ubyte) format.

    The constructor only records the file paths; nothing is read from disk
    until ``load_data`` or ``read_images_labels`` is called.
    """

    def __init__(self, training_images_filepath, training_labels_filepath,
                 test_images_filepath, test_labels_filepath):
        self.training_images_filepath = training_images_filepath
        self.training_labels_filepath = training_labels_filepath
        self.test_images_filepath = test_images_filepath
        self.test_labels_filepath = test_labels_filepath

    def read_images_labels(self, images_filepath, labels_filepath):
        """Read one IDX image file together with its IDX label file.

        Args:
            images_filepath: path of the image file (header magic 2051).
            labels_filepath: path of the label file (header magic 2049).

        Returns:
            ``(images, labels)``. ``images`` is a list with one entry per
            image; each entry is a list of ``rows`` uint8 NumPy arrays of
            length ``cols``, where ``rows`` and ``cols`` come from the image
            file header. ``labels`` is an ``array('B')`` holding every byte
            that follows the label header.

        Raises:
            ValueError: if either magic number is wrong, or if the image
                file holds fewer pixel bytes than its header announces.
        """
        with open(labels_filepath, 'rb') as file:
            # The second header field (label count) is not needed here.
            magic, _ = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
            labels = array("B", file.read())

        with open(images_filepath, 'rb') as file:
            magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
            image_data = file.read()

        expected_bytes = size * rows * cols
        if len(image_data) < expected_bytes:
            raise ValueError(
                'Image file is truncated, expected {} pixel bytes, got {}'.format(
                    expected_bytes, len(image_data)))

        # Going through a bytearray keeps the resulting arrays writable.
        # Use the header's dimensions instead of assuming 28 x 28 images;
        # any bytes beyond the announced payload are ignored.
        pixels = np.frombuffer(bytearray(image_data[:expected_bytes]), dtype=np.uint8)
        pixels = pixels.reshape(size, rows, cols)

        # One list of row arrays per image.
        images = [list(image) for image in pixels]

        return images, labels

    def load_data(self):
        """Read the training and the test set.

        Returns:
            ``((x_train, y_train), (x_test, y_test))`` where every pair is
            the result of ``read_images_labels`` for the matching files.
        """
        x_train, y_train = self.read_images_labels(self.training_images_filepath, self.training_labels_filepath)
        x_test, y_test = self.read_images_labels(self.test_images_filepath, self.test_labels_filepath)
        return (x_train, y_train), (x_test, y_test)
44+
45+
#
46+
# Set file paths based on added MNIST Datasets
47+
#
48+
# Root directory of the downloaded data. Each IDX file sits inside a
# sub-directory that carries the same name as the file itself.
input_path = '.'
training_images_filepath = os.path.join(
    input_path, 'train-images-idx3-ubyte/train-images-idx3-ubyte')
training_labels_filepath = os.path.join(
    input_path, 'train-labels-idx1-ubyte/train-labels-idx1-ubyte')
test_images_filepath = os.path.join(
    input_path, 't10k-images-idx3-ubyte/t10k-images-idx3-ubyte')
test_labels_filepath = os.path.join(
    input_path, 't10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte')

#
# Load MNIST dataset
#
mnist_dataloader = MnistDataloader(
    training_images_filepath=training_images_filepath,
    training_labels_filepath=training_labels_filepath,
    test_images_filepath=test_images_filepath,
    test_labels_filepath=test_labels_filepath,
)
(x_train, y_train), (x_test, y_test) = mnist_dataloader.load_data()

# Partial dataset: the first 60 training and the first 12 test samples.
x_train_small, y_train_small = x_train[:60], y_train[:60]
x_test_small, y_test_small = x_test[:12], y_test[:12]
65+
66+
def to_nc(train_samples, test_samples, train_labels, test_labels, comm, out_file_path='mnist_images.nc'):
    """Write MNIST images and labels into one NetCDF file through PnetCDF.

    The file gets four variables of type NC_UBYTE: ``train_images`` and
    ``test_images`` (one 28 x 28 slab per sample) plus ``train_labels`` and
    ``test_labels`` (one value per sample).

    Every rank of ``comm`` must call this function: it synchronises on
    ``comm.Barrier()`` and opens the output file on ``comm``.

    Args:
        train_samples: sequence of 28 x 28 training images.
        test_samples: sequence of 28 x 28 test images.
        train_labels: iterable of training labels, one per training image.
        test_labels: iterable of test labels, one per test image.
        comm: MPI communicator used for the barrier and to open the file.
        out_file_path: path of the file to create; an existing file at
            this path is removed first.
    """
    # Only one rank removes a stale file. When every rank did this on its
    # own, a rank could find the file already deleted (FileNotFoundError) or
    # delete it after another rank had started to create the new one. The
    # barrier keeps all ranks from opening the file before the removal is
    # finished.
    if comm.Get_rank() == 0 and os.path.exists(out_file_path):
        os.remove(out_file_path)
    comm.Barrier()

    train_labels = list(train_labels)
    test_labels = list(test_labels)
    with pnetcdf.File(out_file_path, comm=comm, mode="w", format="64BIT_DATA") as fnc:

        # NOTE(review): the image variables list their dimensions as
        # (sample, X, Y). Both lengths are 28, so the stored layout does not
        # depend on the order - confirm that the naming is intended.
        dim_y = fnc.def_dim("Y", 28)
        dim_x = fnc.def_dim("X", 28)
        dim_num_train = fnc.def_dim("train_idx", len(train_samples))
        dim_num_test = fnc.def_dim("test_idx", len(test_samples))

        # variables holding all training images and their labels
        v_train = fnc.def_var("train_images", pnetcdf.NC_UBYTE, (dim_num_train, dim_x, dim_y))
        v_label_train = fnc.def_var("train_labels", pnetcdf.NC_UBYTE, (dim_num_train, ))

        # variables holding all test images and their labels
        v_test = fnc.def_var("test_images", pnetcdf.NC_UBYTE, (dim_num_test, dim_x, dim_y))
        v_label_test = fnc.def_var("test_labels", pnetcdf.NC_UBYTE, (dim_num_test, ))

        # leave define mode, then put values into each variable
        fnc.enddef()
        v_label_train[:] = np.array(train_labels, dtype=np.uint8)
        for idx, img in enumerate(train_samples):
            v_train[idx, :, :] = img

        v_label_test[:] = np.array(test_labels, dtype=np.uint8)
        for idx, img in enumerate(test_samples):
            v_test[idx, :, :] = img
97+
98+
# Communicator handed to to_nc(); rank and size are looked up as well but
# are not used by the statements below.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# create mini MNIST file
to_nc(
    train_samples=x_train_small,
    test_samples=x_test_small,
    train_labels=y_train_small,
    test_labels=y_test_small,
    comm=comm,
    out_file_path="mnist_images_mini.nc",
)

# create MNIST file
# to_nc(x_train, x_test, y_train, y_test, comm, "mnist_images.nc")
107+
108+

0 commit comments

Comments
 (0)