File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change 62
62
args = parser.parse_args()
63
63
use_cuda = not args.no_cuda and torch.cuda.is_available()
64
64
use_mps = not args.no_mps and torch.backends.mps.is_available()
65
- @@ -107,7 +122,7 @@
65
+ @@ -101,18 +116,18 @@
66
+ torch.manual_seed(args.seed)
67
+
68
+ if use_cuda:
69
+ - device = torch.device("cuda")
70
+ + torch.cuda.set_device(rank) # Set the GPU device by rank
71
+ + device = torch.device(f"cuda:{rank}")
72
+ elif use_mps:
73
+ device = torch.device("mps")
66
74
else:
67
75
device = torch.device("cpu")
68
-
76
+
69
77
- train_kwargs = {'batch_size': args.batch_size}
70
78
+ train_kwargs = {'batch_size': args.batch_size//nprocs}
71
79
test_kwargs = {'batch_size': args.test_batch_size}
72
80
if use_cuda:
73
81
cuda_kwargs = {'num_workers': 1,
82
+ - 'pin_memory': True,
83
+ - 'shuffle': True}
84
+ + 'pin_memory': True}
85
+ train_kwargs.update(cuda_kwargs)
86
+ test_kwargs.update(cuda_kwargs)
87
+
74
88
@@ -120,25 +135,53 @@
75
89
transforms.ToTensor(),
76
90
transforms.Normalize((0.1307,), (0.3081,))
You can’t perform that action at this time.
0 commit comments