
Commit a4cd4f7

switch to use PnetCDF-Python to load MNIST data

1 parent cf838ad

1 file changed: +50 -8 lines

examples/MNIST/MNIST_codes/main.py: 50 additions & 8 deletions
@@ -8,6 +8,40 @@
 import comm_file
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import ReduceOp, all_reduce
+from pnetcdf import File
+from mpi4py import MPI
+
+class PnetCDFDataset(torch.utils.data.Dataset):
+    def __init__(self, netcdf_file, data_var, label_var, transform=None, comm=MPI.COMM_WORLD):
+        self.netcdf_file = netcdf_file
+        self.data_var = data_var
+        self.label_var = label_var
+        self.transform = transform
+        self.comm = comm
+
+        # Open the NetCDF file
+        self.f = File(self.netcdf_file, mode='r', comm=self.comm)
+        self.f.begin_indep()  # use independent I/O mode
+
+        # Get the shapes of the data and label variables
+        self.data_shape = self.f.variables[self.data_var].shape
+        self.label_shape = self.f.variables[self.label_var].shape
+
+    def __len__(self):
+        return self.data_shape[0]
+
+    def __getitem__(self, idx):
+        # Read the image and its label at the given index
+        image = self.f.variables[self.data_var][idx, ...]
+        label = self.f.variables[self.label_var][idx]
+
+        if self.transform:
+            image = self.transform(image)
+
+        return image, label
+
+    def close(self):
+        self.f.close()

 class Net(nn.Module):
     def __init__(self):
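
For context: a minimal smoke test for the new PnetCDFDataset class, run under MPI (e.g. mpiexec -n 4 python check_dataset.py). The file path and variable names follow defaults introduced elsewhere in this commit; the script name, batch size, and DataLoader settings are illustrative assumptions, not part of the commit. The begin_indep() call in the class is what lets each rank read arbitrary sample indices on its own, since PnetCDF otherwise operates in collective I/O mode.

    # check_dataset.py -- hypothetical smoke test, not part of this commit
    import torch
    from mpi4py import MPI
    from main import PnetCDFDataset  # the class added above

    comm = MPI.COMM_WORLD
    dataset = PnetCDFDataset("../MNIST_data/mnist_images.nc",
                             "train_images", "train_labels", comm=comm)
    # No transform here: samples come back as numpy arrays, which the
    # DataLoader's default collate function stacks into tensors.
    loader = torch.utils.data.DataLoader(dataset, batch_size=32)
    images, labels = next(iter(loader))
    print(f"rank {comm.Get_rank()}: images {tuple(images.shape)}, labels {tuple(labels.shape)}")
    dataset.close()
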
@@ -118,21 +152,26 @@ def main():
                         help='how many batches to wait before logging training status')
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
+    parser.add_argument('--netcdf-file', type=str, default="../MNIST_data/mnist_images.nc",
+                        help='netcdf file storing train and test data')
     args = parser.parse_args()
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

     torch.manual_seed(args.seed)
-
+
     ## init comm, rank, nprocs
     comm, device = comm_file.init_parallel()

     rank = comm.get_rank()
     nprocs = comm.get_size()
+    mpi_comm = MPI.COMM_WORLD
+    mpi_rank = mpi_comm.Get_rank()
+    mpi_size = mpi_comm.Get_size()

-    print("nprocs = ", nprocs, " rank = ", rank, " device = ", device)
+    print("nprocs = ", nprocs, " rank = ", rank, " device = ", device, " mpi_size = ", mpi_size, " mpi_rank = ", mpi_rank)

-    train_kwargs = {'batch_size': args.batch_size}
+    train_kwargs = {'batch_size': args.batch_size // nprocs}
     test_kwargs = {'batch_size': args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {'num_workers': 1,
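
Note on the train_kwargs change above: --batch-size now acts as the global batch size, with each rank drawing batch_size // nprocs samples per step; for example, 4 ranks with --batch-size 64 gives 16 samples per rank, so adding ranks does not inflate the effective batch. A hypothetical launch line (process count and flag values are illustrative):

    mpiexec -n 4 python main.py --batch-size 64 --netcdf-file ../MNIST_data/mnist_images.nc
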
@@ -145,11 +184,12 @@ def main():
145184
transforms.ToTensor(),
146185
transforms.Normalize((0.1307,), (0.3081,))
147186
])
148-
dataset1 = datasets.MNIST('../MNIST_data', train=True, download=True,
149-
transform=transform)
150-
dataset2 = datasets.MNIST('../MNIST_data', train=False,
151-
transform=transform)
152187

188+
# pnetcdf MNIST datasets
189+
netcdf_file = args.netcdf_file
190+
dataset1 = PnetCDFDataset(netcdf_file, 'train_images', 'train_labels', transform, mpi_comm)
191+
dataset2 = PnetCDFDataset(netcdf_file, 'test_images', 'test_labels', transform, mpi_comm)
192+
153193
# add train distributed sampler
154194
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset1, num_replicas=comm.get_size(), rank=comm.get_rank(), shuffle=True)
155195
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset2, num_replicas=comm.get_size(), rank=comm.get_rank(), shuffle=False)
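
The code above assumes ../MNIST_data/mnist_images.nc already exists and holds train_images, train_labels, test_images, and test_labels variables. As a rough sketch only: such a file could be written with PnetCDF-Python along the following lines, where the dimension names, dtypes, zero-filled arrays, and the def_dim/def_var/enddef/slice-write calls are my assumptions about the API, not the repository's actual converter script.

    # make_mnist_nc.py -- hypothetical writer, not part of this commit
    import numpy as np
    from mpi4py import MPI
    from pnetcdf import File

    comm = MPI.COMM_WORLD
    images = np.zeros((60000, 28, 28), dtype="u1")  # stand-in for real MNIST pixels
    labels = np.zeros((60000,), dtype="u1")         # stand-in for real labels

    f = File("../MNIST_data/mnist_images.nc", mode="w", comm=comm)
    f.def_dim("num_train", 60000)  # hypothetical dimension names
    f.def_dim("height", 28)
    f.def_dim("width", 28)
    v_img = f.def_var("train_images", "u1", ("num_train", "height", "width"))
    v_lab = f.def_var("train_labels", "u1", ("num_train",))
    f.enddef()           # leave define mode before writing data

    v_img[:] = images    # collective write of the whole variable
    v_lab[:] = labels    # (test_images/test_labels would be defined the same way)
    f.close()
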
@@ -175,8 +215,10 @@ def main():
     if rank == 0:
         torch.save(model.state_dict(), "mnist_cnn.pt")

+    # close the file
+    dataset1.close()
+    dataset2.close()
     comm.finalize()

-
 if __name__ == '__main__':
     main()
