Skip to content

Commit 73ade20

Browse files
author
Kewei Wang
committed
set device and remove shuffle for training with CUDA
1 parent 25e92d9 commit 73ade20

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

examples/MNIST/mnist.patch

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,15 +62,29 @@
6262
args = parser.parse_args()
6363
use_cuda = not args.no_cuda and torch.cuda.is_available()
6464
use_mps = not args.no_mps and torch.backends.mps.is_available()
65-
@@ -107,7 +122,7 @@
65+
@@ -101,18 +116,18 @@
66+
torch.manual_seed(args.seed)
67+
68+
if use_cuda:
69+
- device = torch.device("cuda")
70+
+ torch.cuda.set_device(rank) # Set the GPU device by rank
71+
+ device = torch.device(f"cuda:{rank}")
72+
elif use_mps:
73+
device = torch.device("mps")
6674
else:
6775
device = torch.device("cpu")
68-
76+
6977
- train_kwargs = {'batch_size': args.batch_size}
7078
+ train_kwargs = {'batch_size': args.batch_size//nprocs}
7179
test_kwargs = {'batch_size': args.test_batch_size}
7280
if use_cuda:
7381
cuda_kwargs = {'num_workers': 1,
82+
- 'pin_memory': True,
83+
- 'shuffle': True}
84+
+ 'pin_memory': True}
85+
train_kwargs.update(cuda_kwargs)
86+
test_kwargs.update(cuda_kwargs)
87+
7488
@@ -120,25 +135,53 @@
7589
transforms.ToTensor(),
7690
transforms.Normalize((0.1307,), (0.3081,))

0 commit comments

Comments
 (0)