Commit cf838ad

use DDP to train MNIST in parallel
1 parent 7783b31 commit cf838ad

2 files changed (+281, -32 lines)
New file: 211 additions, 0 deletions (the communication module that main.py below imports as comm_file)

import os
import torch
import torch.distributed as dist
from mpi4py import MPI


class distributed():
    def get_size(self):
        if dist.is_available() and dist.is_initialized():
            size = dist.get_world_size()
        else:
            size = 1
        return size

    def get_rank(self):
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        return rank

    def get_local_rank(self):
        if not (dist.is_available() and dist.is_initialized()):
            return 0
        # Map the global rank onto the GPUs of this node
        if torch.cuda.is_available():
            local_rank = dist.get_rank() % torch.cuda.device_count()
        else:
            # running on a CPU device should not call this function
            local_rank = -1
        return local_rank

    def __init__(self, method):
        # MASTER_PORT - required; has to be a free port on the machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of the rank 0 node
        # WORLD_SIZE  - required; can be set either here or in a call to the init function
        # RANK        - required; can be set either here or in a call to the init function

        if method == "nccl-slurm":
            # MASTER_ADDR can be set in the SLURM batch script using the command
            #   scontrol show hostnames $SLURM_JOB_NODELIST
            if "MASTER_ADDR" not in os.environ:
                # Fall back to SLURM_LAUNCH_NODE_IPADDR, the IP address of the node
                # from which the task launch was initiated (where the srun command
                # ran). It may not be the node of rank 0.
                if "SLURM_LAUNCH_NODE_IPADDR" in os.environ:
                    os.environ["MASTER_ADDR"] = os.environ["SLURM_LAUNCH_NODE_IPADDR"]
                else:
                    raise Exception("Error: nccl-slurm - SLURM_LAUNCH_NODE_IPADDR is not set")

            # Use the default PyTorch port
            if "MASTER_PORT" not in os.environ:
                if "SLURM_SRUN_COMM_PORT" in os.environ:
                    os.environ["MASTER_PORT"] = os.environ["SLURM_SRUN_COMM_PORT"]
                else:
                    os.environ["MASTER_PORT"] = "29500"

            # obtain WORLD_SIZE
            if "WORLD_SIZE" not in os.environ:
                if "SLURM_NTASKS" in os.environ:
                    world_size = os.environ["SLURM_NTASKS"]
                else:
                    if "SLURM_JOB_NUM_NODES" in os.environ:
                        num_nodes = os.environ["SLURM_JOB_NUM_NODES"]
                    else:
                        raise Exception("Error: nccl-slurm - SLURM_JOB_NUM_NODES is not set")
                    if "SLURM_NTASKS_PER_NODE" in os.environ:
                        ntasks_per_node = os.environ["SLURM_NTASKS_PER_NODE"]
                    elif "SLURM_TASKS_PER_NODE" in os.environ:
                        ntasks_per_node = os.environ["SLURM_TASKS_PER_NODE"]
                    else:
                        raise Exception("Error: nccl-slurm - SLURM_(N)TASKS_PER_NODE is not set")
                    # the environment variables are strings; convert before multiplying
                    world_size = int(ntasks_per_node) * int(num_nodes)
                os.environ["WORLD_SIZE"] = str(world_size)

            # obtain RANK
            if "RANK" not in os.environ:
                if "SLURM_PROCID" in os.environ:
                    os.environ["RANK"] = os.environ["SLURM_PROCID"]
                else:
                    raise Exception("Error: nccl-slurm - SLURM_PROCID is not set")

            # Initialize DDP module
            dist.init_process_group(backend="nccl", init_method='env://')

        elif method == "nccl-openmpi":
            if "MASTER_ADDR" not in os.environ:
                if "PMIX_SERVER_URI2" in os.environ:
                    # keep only the host part of the URI, e.g. "tcp4://host:port"
                    addr = os.environ["PMIX_SERVER_URI2"]
                    os.environ["MASTER_ADDR"] = addr.split("//")[1].split(":")[0]
                else:
                    raise Exception("Error: nccl-openmpi - PMIX_SERVER_URI2 is not set")

            # Use the default PyTorch port
            if "MASTER_PORT" not in os.environ:
                os.environ["MASTER_PORT"] = "29500"

            if "WORLD_SIZE" not in os.environ:
                if "OMPI_COMM_WORLD_SIZE" not in os.environ:
                    raise Exception("Error: nccl-openmpi - OMPI_COMM_WORLD_SIZE is not set")
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]

            if "RANK" not in os.environ:
                if "OMPI_COMM_WORLD_RANK" not in os.environ:
                    raise Exception("Error: nccl-openmpi - OMPI_COMM_WORLD_RANK is not set")
                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

            # Initialize DDP module
            dist.init_process_group(backend="nccl", init_method='env://')

        elif method == "nccl-mpich":
            if "MASTER_ADDR" not in os.environ:
                os.environ['MASTER_ADDR'] = "localhost"

            # Use the default PyTorch port
            if "MASTER_PORT" not in os.environ:
                os.environ["MASTER_PORT"] = "29500"

            if "WORLD_SIZE" not in os.environ:
                if "PMI_SIZE" in os.environ:
                    world_size = os.environ["PMI_SIZE"]
                elif MPI.Is_initialized():
                    world_size = MPI.COMM_WORLD.Get_size()
                else:
                    world_size = 1
                os.environ["WORLD_SIZE"] = str(world_size)

            if "RANK" not in os.environ:
                if "PMI_RANK" in os.environ:
                    rank = os.environ["PMI_RANK"]
                elif MPI.Is_initialized():
                    rank = MPI.COMM_WORLD.Get_rank()
                else:
                    rank = 0
                os.environ["RANK"] = str(rank)

            # Initialize DDP module
            dist.init_process_group(backend="nccl", init_method='env://')

        elif method == "gloo":
            if "MASTER_ADDR" not in os.environ:
                # check if OpenMPI is used
                if "PMIX_SERVER_URI2" in os.environ:
                    addr = os.environ["PMIX_SERVER_URI2"]
                    addr = addr.split("//")[1].split(":")[0]
                    os.environ["MASTER_ADDR"] = addr
                else:
                    os.environ['MASTER_ADDR'] = "localhost"

            # Use the default PyTorch port
            if "MASTER_PORT" not in os.environ:
                os.environ["MASTER_PORT"] = "29500"

            # obtain WORLD_SIZE
            if "WORLD_SIZE" not in os.environ:
                # check if OpenMPI is used
                if "OMPI_COMM_WORLD_SIZE" in os.environ:
                    world_size = os.environ["OMPI_COMM_WORLD_SIZE"]
                elif "PMI_SIZE" in os.environ:
                    world_size = os.environ["PMI_SIZE"]
                elif MPI.Is_initialized():
                    world_size = MPI.COMM_WORLD.Get_size()
                else:
                    world_size = 1
                os.environ["WORLD_SIZE"] = str(world_size)

            # obtain RANK
            if "RANK" not in os.environ:
                # check if OpenMPI is used
                if "OMPI_COMM_WORLD_RANK" in os.environ:
                    rank = os.environ["OMPI_COMM_WORLD_RANK"]
                elif "PMI_RANK" in os.environ:
                    rank = os.environ["PMI_RANK"]
                elif MPI.Is_initialized():
                    rank = MPI.COMM_WORLD.Get_rank()
                else:
                    rank = 0
                os.environ["RANK"] = str(rank)

            # Initialize DDP module
            dist.init_process_group(backend="gloo", init_method='env://')

        else:
            raise NotImplementedError()

    def finalize(self):
        dist.destroy_process_group()


#----< init_parallel() >-------------------------------------------------------
def init_parallel():
    # check whether a CUDA device is available; fall back to gloo on CPU-only nodes
    ngpu_per_node = torch.cuda.device_count()
    if not torch.cuda.is_available():
        backend = "gloo"
    else:
        backend = "nccl-mpich"

    # initialize the parallel/distributed environment
    comm = distributed(backend)
    rank = comm.get_rank()
    world_size = comm.get_size()
    local_rank = comm.get_local_rank()

    # select the training device: cpu or cuda
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(local_rank))

    return comm, device
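
Before wiring the module into the training script, it can be smoke-tested on its own. The sketch below is a minimal check, assuming the file above is saved as comm_file.py (the name main.py imports below) and that one process per GPU is launched with mpiexec or srun; on a CPU-only node init_parallel() falls back to the gloo backend.

import comm_file

if __name__ == "__main__":
    # set up the process group and pick this rank's device
    comm, device = comm_file.init_parallel()
    print(f"rank {comm.get_rank()} of {comm.get_size()} -> device {device}")
    comm.finalize()

Every process should report a distinct rank, the same world size, and either cpu or its assigned cuda:<local_rank> device.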

examples/MNIST/MNIST_codes/main.py

Lines changed: 70 additions & 32 deletions

@@ -5,7 +5,9 @@
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
-
+import comm_file
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import ReduceOp, all_reduce

 class Net(nn.Module):
     def __init__(self):
@@ -33,40 +35,62 @@ def forward(self, x):
         return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train(args, model, device, train_loader, optimizer, epoch, comm):
     model.train()
+    total_loss = 0.0
+    num_batches = 0
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
-        if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
-            if args.dry_run:
-                break
-
-
-def test(model, device, test_loader):
+
+        total_loss += loss.item()
+        num_batches += 1
+
+    # Compute the average loss for the current epoch
+    avg_loss = total_loss / num_batches
+
+    # Reduce the average loss across all processes
+    avg_loss_tensor = torch.tensor(avg_loss, device=device)
+    all_reduce(avg_loss_tensor, op=ReduceOp.SUM)
+    avg_loss_tensor /= comm.get_size()
+
+    # Print the average loss only from the master process
+    if comm.get_rank() == 0:
+        print(f'Train Epoch: {epoch}\tAverage Loss: {avg_loss_tensor.item():.6f}')
+
+
+def test(model, device, test_loader, comm):
     model.eval()
     test_loss = 0
     correct = 0
+    total_samples = 0
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
             test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
-
-    test_loss /= len(test_loader.dataset)
-
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+            total_samples += data.size(0)
+
+    # Sum the per-rank statistics across all processes before reporting
+    test_loss_tensor = torch.tensor(test_loss, device=device)
+    correct_tensor = torch.tensor(correct, device=device)
+    total_samples_tensor = torch.tensor(total_samples, device=device)
+    all_reduce(test_loss_tensor, op=ReduceOp.SUM)
+    all_reduce(correct_tensor, op=ReduceOp.SUM)
+    all_reduce(total_samples_tensor, op=ReduceOp.SUM)
+    test_loss = test_loss_tensor.item()
+    correct = correct_tensor.item()
+    total_samples = total_samples_tensor.item()
+    avg_loss = test_loss / total_samples
+    accuracy = 100. * correct / total_samples
+
+    if comm.get_rank() == 0:
+        print(f'Test set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total_samples} ({accuracy:.0f}%)\n')


 def main():
@@ -100,45 +124,59 @@ def main():

     torch.manual_seed(args.seed)

-    if use_cuda:
-        device = torch.device("cuda")
-    elif use_mps:
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
+    # initialize the distributed communicator; it also selects this rank's device
+    comm, device = comm_file.init_parallel()
+
+    rank = comm.get_rank()
+    nprocs = comm.get_size()
+
+    print("nprocs = ", nprocs, " rank = ", rank, " device = ", device)

     train_kwargs = {'batch_size': args.batch_size}
     test_kwargs = {'batch_size': args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {'num_workers': 1,
                        'pin_memory': True,
-                       'shuffle': True}
+                       'shuffle': False}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)

     transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
         ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
+    dataset1 = datasets.MNIST('../MNIST_data', train=True, download=True,
                        transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
+    dataset2 = datasets.MNIST('../MNIST_data', train=False,
                        transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
-    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    # shard the training and test sets across ranks with DistributedSampler
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset1, num_replicas=comm.get_size(), rank=comm.get_rank(), shuffle=True)
+    test_sampler = torch.utils.data.distributed.DistributedSampler(
+        dataset2, num_replicas=comm.get_size(), rank=comm.get_rank(), shuffle=False)
+    train_loader = torch.utils.data.DataLoader(dataset1, sampler=train_sampler, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, sampler=test_sampler, **test_kwargs, drop_last=False)

     model = Net().to(device)
+    # wrap the model in DistributedDataParallel so gradients are averaged across ranks
+    model = DDP(model, device_ids=[device] if use_cuda else None)
     optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
-        test(model, device, test_loader)
+        # re-seed the samplers so each epoch draws a different shuffle
+        train_sampler.set_epoch(epoch)
+        test_sampler.set_epoch(epoch)
+
+        train(args, model, device, train_loader, optimizer, epoch, comm)
+        test(model, device, test_loader, comm)
         scheduler.step()

     if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        if rank == 0:
+            torch.save(model.state_dict(), "mnist_cnn.pt")
+
+    comm.finalize()


 if __name__ == '__main__':
-    main()
+    main()
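
The DistributedSampler lines added above are what shard MNIST across ranks; that is also why shuffle moves from the DataLoader into the sampler and why set_epoch() is called at the top of every epoch. The standalone sketch below illustrates the behaviour; the 8-element toy dataset and the 2-rank world are illustrative assumptions, and no process group is needed because num_replicas and rank are passed explicitly.

from torch.utils.data.distributed import DistributedSampler

data = list(range(8))                  # stand-in for the MNIST dataset
for epoch in range(2):
    for rank in range(2):              # pretend the job runs with 2 ranks
        sampler = DistributedSampler(data, num_replicas=2, rank=rank, shuffle=True)
        sampler.set_epoch(epoch)       # reshuffles differently for each epoch
        print(f"epoch {epoch} rank {rank} -> indices {list(sampler)}")

Within one epoch the two ranks receive disjoint halves of the indices; changing the epoch changes the shuffle, which is exactly what train_sampler.set_epoch(epoch) provides in the training loop above.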

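One side effect of saving model.state_dict() after the DDP wrap, as rank 0 does above, is that every parameter key gains a module. prefix. The sketch below is one way to load the checkpoint back into a plain, single-process model; it assumes the run finished, mnist_cnn.pt is in the working directory, and main.py (together with its comm_file dependency) is importable so the Net class can be reused.

import torch
from main import Net   # safe to import: main.py guards main() behind __main__

state = torch.load("mnist_cnn.pt", map_location="cpu")
# strip the "module." prefix added by the DistributedDataParallel wrapper
state = {k[len("module."):] if k.startswith("module.") else k: v
         for k, v in state.items()}

model = Net()
model.load_state_dict(state)
model.eval()

Alternatively, the training script could save model.module.state_dict() on rank 0, which keeps the checkpoint keys identical to those of the original non-distributed version.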