Commit fe175ac

download mnist pytorch file and apply patch

* add patch file, mnist.patch
* add Makefile to enable running 'make check'
* put the PnetCDF-IO part into a separate file, pnetcdf_io.py

1 parent da841d7, commit fe175ac

File tree

5 files changed: +413 −0 lines changed


examples/MNIST/Makefile

Lines changed: 25 additions & 0 deletions
#
# Copyright (C) 2024, Northwestern University and Argonne National Laboratory
# See COPYRIGHT notice in top-level directory.
#

check_PROGRAMS = mnist_main.py

MNIST_URL = https://raw.githubusercontent.com/pytorch/examples/main/mnist/main.py

mnist_main.py:
	curl -Ls $(MNIST_URL) -o $@
	patch -st $@ < mnist.patch

all:

ptests check: mnist_main.py mnist_images.nc
	@echo "======================================================================"
	@echo " examples/MNIST: Parallel testing on 4 MPI processes"
	@echo "======================================================================"
	@mpiexec -n 4 python mnist_main.py --batch-size 4 --test-batch-size 2 --epochs 3 --input-file mnist_images.nc
	@echo ""

clean:
	rm -rf mnist_main.py

examples/MNIST/comm_file.py

Lines changed: 212 additions & 0 deletions
import os
import torch
import torch.distributed as dist
from mpi4py import MPI

class distributed():
    def get_size(self):
        if dist.is_available() and dist.is_initialized():
            size = dist.get_world_size()
        else:
            size = 1
        return size

    def get_rank(self):
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        return rank

    def get_local_rank(self):
        if not (dist.is_available() and dist.is_initialized()):
            return 0
        # Number of GPUs per node
        if torch.cuda.is_available():
            local_rank = dist.get_rank() % torch.cuda.device_count()
        else:
            # running on a cpu device should not call this function
            local_rank = -1
        return local_rank

    def __init__(self, method):
        # MASTER_PORT - required; has to be a free port on machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of rank 0 node
        # WORLD_SIZE - required; can be set either here, or in a call to init function
        # RANK       - required; can be set either here, or in a call to init function
        self.mpi_comm = MPI.COMM_WORLD

        if method == "nccl-slurm":
            # MASTER_ADDR can be set in the slurm batch script using command
            #   scontrol show hostnames $SLURM_JOB_NODELIST
            if "MASTER_ADDR" not in os.environ:
                # Fall back to SLURM_LAUNCH_NODE_IPADDR, but note it is the IP
                # address of the node from which the task launch was initiated
                # (where the srun command ran). It may not be the node of rank 0.
                if "SLURM_LAUNCH_NODE_IPADDR" in os.environ:
                    os.environ["MASTER_ADDR"] = os.environ["SLURM_LAUNCH_NODE_IPADDR"]
                else:
                    raise Exception("Error: nccl-slurm - SLURM_LAUNCH_NODE_IPADDR is not set")

            # Use the default pytorch port
            if "MASTER_PORT" not in os.environ:
                if "SLURM_SRUN_COMM_PORT" in os.environ:
                    os.environ["MASTER_PORT"] = os.environ["SLURM_SRUN_COMM_PORT"]
                else:
                    os.environ["MASTER_PORT"] = "29500"

            # obtain WORLD_SIZE
            if "WORLD_SIZE" not in os.environ:
                if "SLURM_NTASKS" in os.environ:
                    world_size = os.environ["SLURM_NTASKS"]
                else:
                    if "SLURM_JOB_NUM_NODES" in os.environ:
                        num_nodes = os.environ["SLURM_JOB_NUM_NODES"]
                    else:
                        raise Exception("Error: nccl-slurm - SLURM_JOB_NUM_NODES is not set")
                    if "SLURM_NTASKS_PER_NODE" in os.environ:
                        ntasks_per_node = os.environ["SLURM_NTASKS_PER_NODE"]
                    elif "SLURM_TASKS_PER_NODE" in os.environ:
                        ntasks_per_node = os.environ["SLURM_TASKS_PER_NODE"]
                    else:
                        raise Exception("Error: nccl-slurm - SLURM_(N)TASKS_PER_NODE is not set")
                    # environment variables are strings; convert before multiplying
                    world_size = int(ntasks_per_node) * int(num_nodes)
                os.environ["WORLD_SIZE"] = str(world_size)

            # obtain RANK
            if "RANK" not in os.environ:
                if "SLURM_PROCID" in os.environ:
                    os.environ["RANK"] = os.environ["SLURM_PROCID"]
                else:
                    raise Exception("Error: nccl-slurm - SLURM_PROCID is not set")

            # Initialize DDP module
            dist.init_process_group(backend="nccl", init_method='env://')

        elif method == "nccl-openmpi":
            if "MASTER_ADDR" not in os.environ:
                if "PMIX_SERVER_URI2" in os.environ:
                    # strip the scheme prefix and port number to obtain the host address
                    addr = os.environ["PMIX_SERVER_URI2"]
                    os.environ["MASTER_ADDR"] = addr.split("//")[1].split(":")[0]
                else:
                    raise Exception("Error: nccl-openmpi - PMIX_SERVER_URI2 is not set")

            # Use the default pytorch port
            if "MASTER_PORT" not in os.environ:
                os.environ["MASTER_PORT"] = "29500"

            if "WORLD_SIZE" not in os.environ:
                if "OMPI_COMM_WORLD_SIZE" not in os.environ:
                    raise Exception("Error: nccl-openmpi - OMPI_COMM_WORLD_SIZE is not set")
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]

            if "RANK" not in os.environ:
                if "OMPI_COMM_WORLD_RANK" not in os.environ:
                    raise Exception("Error: nccl-openmpi - OMPI_COMM_WORLD_RANK is not set")
                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

            # Initialize DDP module
            dist.init_process_group(backend="nccl", init_method='env://')

        elif method == "nccl-mpich":
            if "MASTER_ADDR" not in os.environ:
                os.environ["MASTER_ADDR"] = "localhost"

            # Use the default pytorch port
            if "MASTER_PORT" not in os.environ:
                os.environ["MASTER_PORT"] = "29500"

            if "WORLD_SIZE" not in os.environ:
                if "PMI_SIZE" in os.environ:
                    world_size = os.environ["PMI_SIZE"]
                elif MPI.Is_initialized():
                    world_size = MPI.COMM_WORLD.Get_size()
                else:
                    world_size = 1
                os.environ["WORLD_SIZE"] = str(world_size)

            if "RANK" not in os.environ:
                if "PMI_RANK" in os.environ:
                    rank = os.environ["PMI_RANK"]
                elif MPI.Is_initialized():
                    rank = MPI.COMM_WORLD.Get_rank()
                else:
                    rank = 0
                os.environ["RANK"] = str(rank)

            # Initialize DDP module
            dist.init_process_group(backend="nccl", init_method='env://')

        elif method == "gloo":
            if "MASTER_ADDR" not in os.environ:
                # check if OpenMPI is used
                if "PMIX_SERVER_URI2" in os.environ:
                    addr = os.environ["PMIX_SERVER_URI2"]
                    addr = addr.split("//")[1].split(":")[0]
                    os.environ["MASTER_ADDR"] = addr
                else:
                    os.environ["MASTER_ADDR"] = "localhost"

            # Use the default pytorch port
            if "MASTER_PORT" not in os.environ:
                os.environ["MASTER_PORT"] = "29500"

            # obtain WORLD_SIZE
            if "WORLD_SIZE" not in os.environ:
                # check if OpenMPI is used
                if "OMPI_COMM_WORLD_SIZE" in os.environ:
                    world_size = os.environ["OMPI_COMM_WORLD_SIZE"]
                elif "PMI_SIZE" in os.environ:
                    world_size = os.environ["PMI_SIZE"]
                elif MPI.Is_initialized():
                    world_size = MPI.COMM_WORLD.Get_size()
                else:
                    world_size = 1
                os.environ["WORLD_SIZE"] = str(world_size)

            # obtain RANK
            if "RANK" not in os.environ:
                # check if OpenMPI is used
                if "OMPI_COMM_WORLD_RANK" in os.environ:
                    rank = os.environ["OMPI_COMM_WORLD_RANK"]
                elif "PMI_RANK" in os.environ:
                    rank = os.environ["PMI_RANK"]
                elif MPI.Is_initialized():
                    rank = MPI.COMM_WORLD.Get_rank()
                else:
                    rank = 0
                os.environ["RANK"] = str(rank)

            # Initialize DDP module
            dist.init_process_group(backend="gloo", init_method='env://')

        else:
            raise NotImplementedError()

    def finalize(self):
        dist.destroy_process_group()

#----< init_parallel() >-------------------------------------------------------
def init_parallel():
    # check if cuda device is available
    ngpu_per_node = torch.cuda.device_count()
    if not torch.cuda.is_available():
        backend = "gloo"
    else:
        backend = "nccl-mpich"

    # initialize parallel/distributed environment
    comm = distributed(backend)
    rank = comm.get_rank()
    world_size = comm.get_size()
    local_rank = comm.get_local_rank()

    # select training device: cpu or cuda
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(local_rank))

    return comm, device
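
The patched mnist_main.py (see mnist.patch below) drives this module only through init_parallel(), get_rank(), get_size(), and finalize(). A minimal usage sketch, assuming the script is launched with an MPI launcher such as the mpiexec command in the Makefile's check target:

# usage sketch: how a training script is expected to use comm_file
# (mirrors the calls added by mnist.patch; launch with e.g. mpiexec -n 4)
import comm_file

comm, device = comm_file.init_parallel()   # picks "gloo" on CPU, "nccl-mpich" on GPU

rank   = comm.get_rank()                   # rank of this MPI process
nprocs = comm.get_size()                   # total number of MPI processes

if rank == 0:
    print("world size =", nprocs, "device =", device)

# ... build the model, wrap it in DistributedDataParallel, train ...

comm.finalize()                            # destroy the torch.distributed process group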

examples/MNIST/mnist.patch

Lines changed: 134 additions & 0 deletions
--- mnist_main_original.py 2024-08-10 17:30:08.552324326 -0500
+++ pnetcdf_mnist.py 2024-08-10 18:02:49.008705003 -0500
@@ -1,3 +1,8 @@
+#
+# Copyright (C) 2024, Northwestern University and Argonne National Laboratory
+# See COPYRIGHT notice in top-level directory.
+#
+
 import argparse
 import torch
 import torch.nn as nn
@@ -5,7 +10,11 @@
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
 
+import comm_file, pnetcdf_io
+from mpi4py import MPI
 
 class Net(nn.Module):
     def __init__(self):
@@ -42,14 +51,13 @@
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
-        if batch_idx % args.log_interval == 0:
+        if rank == 0 and batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
             if args.dry_run:
                 break
 
-
 def test(model, device, test_loader):
     model.eval()
     test_loss = 0
@@ -62,9 +70,14 @@
             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
+    # aggregate loss among all ranks
+    test_loss = comm.mpi_comm.allreduce(test_loss, op=MPI.SUM)
+    correct = comm.mpi_comm.allreduce(correct, op=MPI.SUM)
+
     test_loss /= len(test_loader.dataset)
 
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+    if rank == 0:
+        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
         100. * correct / len(test_loader.dataset)))
 
@@ -94,6 +107,8 @@
                         help='how many batches to wait before logging training status')
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
+    parser.add_argument('--input-file', type=str, required=True,
+                        help='NetCDF file storing train and test samples')
     args = parser.parse_args()
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
@@ -107,7 +122,7 @@
     else:
         device = torch.device("cpu")
 
-    train_kwargs = {'batch_size': args.batch_size}
+    train_kwargs = {'batch_size': args.batch_size//nprocs}
     test_kwargs = {'batch_size': args.test_batch_size}
     if use_cuda:
         cuda_kwargs = {'num_workers': 1,
@@ -120,25 +135,53 @@
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
         ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
-    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    # Open files storing training and testing samples
+    infile = args.input_file
+    train_file = pnetcdf_io.dataset(infile, 'train_images', 'train_labels', transform, comm.mpi_comm)
+    test_file = pnetcdf_io.dataset(infile, 'test_images', 'test_labels', transform, comm.mpi_comm)
+
+    # create distributed samplers
+    train_sampler = DistributedSampler(train_file, num_replicas=nprocs, rank=rank, shuffle=True)
+    test_sampler = DistributedSampler(test_file, num_replicas=nprocs, rank=rank, shuffle=False)
+
+    # add distributed samplers to DataLoaders
+    train_loader = torch.utils.data.DataLoader(train_file, sampler=train_sampler, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(test_file, sampler=test_sampler, **test_kwargs, drop_last=False)
 
     model = Net().to(device)
+
+    # use DDP
+    model = DDP(model, device_ids=[device] if use_cuda else None)
+
     optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
     for epoch in range(1, args.epochs + 1):
+        # train sampler set epoch
+        train_sampler.set_epoch(epoch)
+        test_sampler.set_epoch(epoch)
+
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()
 
     if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        if rank == 0:
+            torch.save(model.state_dict(), "mnist_cnn.pt")
 
+    # close files
+    train_file.close()
+    test_file.close()
 
 if __name__ == '__main__':
+    ## initialize parallel environment
+    comm, device = comm_file.init_parallel()
+
+    rank = comm.get_rank()
+    nprocs = comm.get_size()
+
     main()
+
+    comm.finalize()
+
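
The patch imports a helper module, pnetcdf_io, whose contents are not part of the hunks shown above. From the call sites (pnetcdf_io.dataset(infile, 'train_images', 'train_labels', transform, comm.mpi_comm), use with DistributedSampler/DataLoader, and close()), it must behave like a map-style torch Dataset backed by the NetCDF file. The sketch below is a hypothetical illustration of such a class, not the actual pnetcdf_io.py added by this commit; the PnetCDF-Python calls (pnetcdf.File, File.variables, per-index reads) and the variable layout (one sample per record along the first dimension) are assumptions.

# Hypothetical sketch of a pnetcdf_io.dataset-like class -- NOT the actual
# examples/MNIST/pnetcdf_io.py.  Assumed: PnetCDF-Python's File/variables
# interface and image/label variables storing one sample per record along
# the first dimension of mnist_images.nc.
import numpy as np
import torch
import pnetcdf                      # PnetCDF-Python (assumed available)

class dataset(torch.utils.data.Dataset):
    def __init__(self, filename, image_var, label_var, transform, comm):
        # open the file for reading on the given MPI communicator
        self.fh = pnetcdf.File(filename, mode='r', comm=comm)
        self.images = self.fh.variables[image_var]   # e.g. shape (N, 28, 28)
        self.labels = self.fh.variables[label_var]   # e.g. shape (N,)
        self.transform = transform

    def __len__(self):
        # number of samples = extent of the first dimension
        return self.images.shape[0]

    def __getitem__(self, idx):
        # read one sample; DistributedSampler decides which indices each rank gets
        img = np.asarray(self.images[idx], dtype=np.uint8)
        label = int(self.labels[idx])
        if self.transform is not None:
            img = self.transform(img)                # e.g. ToTensor + Normalize
        return img, label

    def close(self):
        self.fh.close()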

examples/MNIST/mnist_images.nc

55.7 KB
Binary file not shown.
