Use this guide to learn how to use the SageMaker distributed
data parallel library API for PyTorch.

.. contents:: Topics

The distributed data parallel library works as a backend of the PyTorch distributed package.
See `SageMaker distributed data parallel PyTorch examples <https://sagemaker-examples.readthedocs.io/en/latest/training/distributed_training/index.html#pytorch-distributed>`__
for additional details on how to use the library.

1. Import the SageMaker distributed data parallel library’s PyTorch client.
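
   In the library's recent releases, this client ships as a single module whose
   import registers the ``smddp`` backend with ``torch.distributed``; the module
   path below reflects that packaging and may differ in other versions.

   .. code:: python

      # Importing the client registers "smddp" as a torch.distributed backend.
      # The module path may vary across library versions.
      import smdistributed.dataparallel.torch.torch_smddp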

2. Import the PyTorch distributed modules.

   .. code:: python

      import torch
      import torch.distributed as dist
      from torch.nn.parallel import DistributedDataParallel as DDP

3. Set the backend of ``torch.distributed`` as ``smddp``.

   .. code:: python

      dist.init_process_group(backend='smddp')

4. After parsing arguments and defining a batch size parameter
   (for example, ``batch_size=args.batch_size``), add two lines of code to
   resize the batch size per worker (GPU). PyTorch's DataLoader operation
   does not automatically handle the batch resizing for distributed training.

   .. code:: python

      batch_size //= dist.get_world_size()
      batch_size = max(batch_size, 1)

5. Pin each GPU to a single SageMaker data parallel library process with
   ``local_rank``. This refers to the relative rank of the process within a given node.

   You can retrieve the rank of the process from the ``LOCAL_RANK`` environment variable.

   .. code:: python

      import os

      # LOCAL_RANK is exported per process; cast it to an integer GPU index.
      local_rank = int(os.environ["LOCAL_RANK"])
      torch.cuda.set_device(local_rank)

6. After defining a model, wrap it with the PyTorch DDP.

   .. code:: python

      model = ...

      # Wrap the model with the PyTorch DistributedDataParallel API
      model = DDP(model)
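
   If the model is not already on the GPU pinned in step 5, you may also move it
   to the ``local_rank`` device before (or while) wrapping it; the line below is
   a minimal sketch of that pattern, not a required replacement for the code above.

   .. code:: python

      # Optional: place the model on this process's GPU before wrapping it.
      model = DDP(model.to(local_rank))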

7. When you call the ``torch.utils.data.distributed.DistributedSampler`` API,
   specify the total number of processes (GPUs) participating in training across
   all the nodes in the cluster. This is called ``world_size``, and you can retrieve
   the number from the ``torch.distributed.get_world_size()`` API. Also, specify
   the rank of each process among all processes using the ``torch.distributed.get_rank()`` API.

   .. code:: python

      from torch.utils.data.distributed import DistributedSampler

      train_sampler = DistributedSampler(
          train_dataset,
          num_replicas=dist.get_world_size(),
          rank=dist.get_rank()
      )
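
   You can then pass the sampler, together with the per-worker batch size from
   step 4, to a PyTorch ``DataLoader``. The snippet below is a minimal sketch;
   ``train_dataset`` and ``batch_size`` are the placeholder names used in the
   earlier steps.

   .. code:: python

      from torch.utils.data import DataLoader

      # Each process loads only the shard of data assigned by the sampler.
      train_loader = DataLoader(
          train_dataset,
          batch_size=batch_size,
          sampler=train_sampler
      )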

8. Modify your script to save checkpoints only on the leader process (rank 0).
   The leader process has a synchronized model. This also avoids other processes
   overwriting the checkpoints and possibly corrupting them.
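
   A minimal sketch of this guard is shown below; the file name is arbitrary, and
   ``model.module`` is used only to save the state dict without the DDP wrapper.

   .. code:: python

      # Save the checkpoint from the leader process only, so that workers do not
      # write to the same file concurrently.
      if dist.get_rank() == 0:
          torch.save(model.module.state_dict(), "model.pt")
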
The following example code shows the structure of a PyTorch training script with DDP and smddp as the backend.

.. code:: python

        test(...)
        scheduler.step()

    # SageMaker data parallel: Save model on the leader node (rank 0).
    if dist.get_rank() == 0:
        torch.save(...)

.. warning::

   The following APIs for the ``smdistributed`` implementation of the PyTorch
   distributed modules are deprecated.