8
8
import comm_file
9
9
from torch .nn .parallel import DistributedDataParallel as DDP
10
10
from torch .distributed import ReduceOp , all_reduce
11
+ from pnetcdf import File
12
+ from mpi4py import MPI
13
+
14
class PnetCDFDataset(torch.utils.data.Dataset):
    """Map-style Dataset that serves samples stored in a NetCDF file via PnetCDF.

    The file is opened on the given MPI communicator, then switched to
    independent I/O mode so that ``__getitem__`` can read individual records
    without coordinating with other ranks.

    Parameters
    ----------
    netcdf_file : str
        Path to the NetCDF file holding the data and label variables.
    data_var : str
        Name of the variable containing the samples; the leading dimension
        is assumed to index samples (TODO confirm against the file layout).
    label_var : str
        Name of the variable containing the labels.
    transform : callable, optional
        If given, applied to each sample before it is returned.
    comm : MPI communicator, default ``MPI.COMM_WORLD``
        Communicator used to open the file. (Note: the default is a
        module-level singleton, evaluated once at import — intentional here.)
    """

    def __init__(self, netcdf_file, data_var, label_var, transform=None,
                 comm=MPI.COMM_WORLD):
        self.netcdf_file = netcdf_file
        self.data_var = data_var
        self.label_var = label_var
        self.transform = transform
        self.comm = comm

        # Open the NetCDF file, then switch to independent I/O mode so each
        # rank can read its own records in __getitem__ without collective calls.
        self.f = File(self.netcdf_file, mode='r', comm=self.comm)
        self.f.begin_indep()

        # Cache the variable shapes; the leading dimension of data_var is
        # used as the sample count in __len__.
        self.data_shape = self.f.variables[self.data_var].shape
        self.label_shape = self.f.variables[self.label_var].shape

    def __len__(self):
        # Number of samples = size of the leading dimension of the data variable.
        return self.data_shape[0]

    def __getitem__(self, idx):
        # Read one sample and its label directly from the open file handle.
        image = self.f.variables[self.data_var][idx, ...]
        label = self.f.variables[self.label_var][idx]

        if self.transform:
            image = self.transform(image)

        return image, label

    def close(self):
        """Close the underlying NetCDF file. Safe to call more than once.

        Fix over the original: the handle is cleared after closing so a
        second call is a no-op instead of operating on a closed file.
        """
        if self.f is not None:
            self.f.close()
            self.f = None

    # Context-manager support so callers can write
    # ``with PnetCDFDataset(...) as ds:`` and get deterministic cleanup.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
11
45
12
46
class Net (nn .Module ):
13
47
def __init__ (self ):
@@ -118,21 +152,26 @@ def main():
118
152
help = 'how many batches to wait before logging training status' )
119
153
parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
120
154
help = 'For Saving the current Model' )
155
+ parser .add_argument ('--netcdf-file' , type = str , default = "../MNIST_data/mnist_images.nc" ,
156
+ help = 'netcdf file storing train and test data' )
121
157
args = parser .parse_args ()
122
158
use_cuda = not args .no_cuda and torch .cuda .is_available ()
123
159
use_mps = not args .no_mps and torch .backends .mps .is_available ()
124
160
125
161
torch .manual_seed (args .seed )
126
-
162
+
127
163
## init comm, rank, nprocs
128
164
comm , device = comm_file .init_parallel ()
129
165
130
166
rank = comm .get_rank ()
131
167
nprocs = comm .get_size ()
168
+ mpi_comm = MPI .COMM_WORLD
169
+ mpi_rank = mpi_comm .Get_rank ()
170
+ mpi_size = mpi_comm .Get_size ()
132
171
133
- print ("nprocs = " , nprocs , " rank = " ,rank ," device = " , device )
172
+ print ("nprocs = " , nprocs , " rank = " ,rank ," device = " , device , " mpi_size = " , mpi_size , " mpi_rank = " , mpi_rank )
134
173
135
- train_kwargs = {'batch_size' : args .batch_size }
174
+ train_kwargs = {'batch_size' : args .batch_size // nprocs }
136
175
test_kwargs = {'batch_size' : args .test_batch_size }
137
176
if use_cuda :
138
177
cuda_kwargs = {'num_workers' : 1 ,
@@ -145,11 +184,12 @@ def main():
145
184
transforms .ToTensor (),
146
185
transforms .Normalize ((0.1307 ,), (0.3081 ,))
147
186
])
148
- dataset1 = datasets .MNIST ('../MNIST_data' , train = True , download = True ,
149
- transform = transform )
150
- dataset2 = datasets .MNIST ('../MNIST_data' , train = False ,
151
- transform = transform )
152
187
188
+ # pnetcdf MNIST datasets
189
+ netcdf_file = args .netcdf_file
190
+ dataset1 = PnetCDFDataset (netcdf_file , 'train_images' , 'train_labels' , transform , mpi_comm )
191
+ dataset2 = PnetCDFDataset (netcdf_file , 'test_images' , 'test_labels' , transform , mpi_comm )
192
+
153
193
# add train distributed sampler
154
194
train_sampler = torch .utils .data .distributed .DistributedSampler (dataset1 , num_replicas = comm .get_size (), rank = comm .get_rank (), shuffle = True )
155
195
test_sampler = torch .utils .data .distributed .DistributedSampler (dataset2 , num_replicas = comm .get_size (), rank = comm .get_rank (), shuffle = False )
@@ -175,8 +215,10 @@ def main():
175
215
if rank == 0 :
176
216
torch .save (model .state_dict (), "mnist_cnn.pt" )
177
217
218
+ # close the file
219
+ dataset1 .close ()
220
+ dataset2 .close ()
178
221
comm .finalize ()
179
222
180
-
181
223
if __name__ == '__main__' :
182
224
main ()
0 commit comments