import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
-
+import comm_file
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import ReduceOp, all_reduce

class Net(nn.Module):
    def __init__(self):
@@ -33,40 +35,62 @@ def forward(self, x):
        return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train(args, model, device, train_loader, optimizer, epoch, comm):
    model.train()
+    total_loss = 0.0
+    num_batches = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
-        if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
-            if args.dry_run:
-                break
-
-
-def test(model, device, test_loader):
+
+        total_loss += loss.item()
+        num_batches += 1
+
+    # Compute this rank's average loss for the current epoch
+    avg_loss = total_loss / num_batches
+
+    # Sum the per-rank averages across all processes, then divide by world size
+    avg_loss_tensor = torch.tensor(avg_loss, device=device)
+    all_reduce(avg_loss_tensor, op=ReduceOp.SUM)
+    avg_loss_tensor /= comm.get_size()
+
+    # Print the average loss only from the master process (rank 0)
+    if comm.get_rank() == 0:
+        print(f'Train Epoch: {epoch}\tAverage Loss: {avg_loss_tensor.item():.6f}')
+
+
+def test(model, device, test_loader, comm):
    model.eval()
    test_loss = 0
    correct = 0
+    total_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
-
-    test_loss /= len(test_loader.dataset)
-
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+            total_samples += data.size(0)
+
+    # Aggregate the loss, correct-prediction, and sample counts across all processes
+    test_loss_tensor = torch.tensor(test_loss, device=device)
+    correct_tensor = torch.tensor(correct, device=device)
+    total_samples_tensor = torch.tensor(total_samples, device=device)
+    all_reduce(test_loss_tensor, op=ReduceOp.SUM)
+    all_reduce(correct_tensor, op=ReduceOp.SUM)
+    all_reduce(total_samples_tensor, op=ReduceOp.SUM)
+    test_loss = test_loss_tensor.item()
+    correct = correct_tensor.item()
+    total_samples = total_samples_tensor.item()
+    avg_loss = test_loss / total_samples
+    accuracy = 100. * correct / total_samples
+
+    if comm.get_rank() == 0:
+        print(f'Test set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total_samples} ({accuracy:.0f}%)\n')


def main():
@@ -100,45 +124,59 @@ def main():

    torch.manual_seed(args.seed)

-    if use_cuda:
-        device = torch.device("cuda")
-    elif use_mps:
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
+    # Initialize the communicator; it also selects this process's device
+    comm, device = comm_file.init_parallel()
+
+    rank = comm.get_rank()
+    nprocs = comm.get_size()
+
+    print("nprocs =", nprocs, "rank =", rank, "device =", device)

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
-                       'shuffle': True}
+                       'shuffle': False}  # the DistributedSampler handles shuffling
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
+    dataset1 = datasets.MNIST('../MNIST_data', train=True, download=True,
                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
+    dataset2 = datasets.MNIST('../MNIST_data', train=False,
                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
-    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    # Shard the datasets across processes with distributed samplers
+    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset1, num_replicas=comm.get_size(), rank=comm.get_rank(), shuffle=True)
+    test_sampler = torch.utils.data.distributed.DistributedSampler(dataset2, num_replicas=comm.get_size(), rank=comm.get_rank(), shuffle=False)
+    train_loader = torch.utils.data.DataLoader(dataset1, sampler=train_sampler, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, sampler=test_sampler, **test_kwargs, drop_last=False)

    model = Net().to(device)
+    # Wrap the model in DDP so gradients are synchronized across processes
+    model = DDP(model, device_ids=[device] if use_cuda else None)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
-        test(model, device, test_loader)
+        # Tell the samplers which epoch this is so shuffling differs per epoch
+        train_sampler.set_epoch(epoch)
+        test_sampler.set_epoch(epoch)
+
+        train(args, model, device, train_loader, optimizer, epoch, comm)
+        test(model, device, test_loader, comm)
        scheduler.step()

    if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        if rank == 0:
+            # Save the unwrapped model so the checkpoint keys have no 'module.' prefix
+            torch.save(model.module.state_dict(), "mnist_cnn.pt")
+
+    comm.finalize()


if __name__ == '__main__':
-    main()
+    main()
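
Note: the patch depends on a helper module, comm_file, that is not part of this diff. Based on how main() uses it (init_parallel(), get_rank(), get_size(), finalize()), a minimal sketch could look like the following, assuming it wraps torch.distributed with an env:// rendezvous; the repo's actual implementation may differ:

# comm_file.py -- hypothetical sketch, inferred from the calls made in main()
import os
import torch
import torch.distributed as dist

class Comm:
    """Thin wrapper over the default process group."""
    def get_rank(self):
        return dist.get_rank()

    def get_size(self):
        return dist.get_world_size()

    def finalize(self):
        dist.destroy_process_group()

def init_parallel():
    # torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK for each process
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
    return Comm(), device

Under that assumption, the script would be launched with one process per GPU, e.g. torchrun --nproc_per_node=4 mnist_ddp.py (the script name here is illustrative).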