@@ -196,6 +196,7 @@ fit Optional Arguments
- ``logs``: Defaults to True, whether to show logs produced by training
job in the Python session. Only meaningful when wait is True.

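The ``wait`` and ``logs`` arguments described above control blocking and log
streaming. A minimal sketch of a ``fit`` call that sets both (the estimator name
and the S3 path are placeholders, not part of this change):

.. code:: python

    # Hypothetical call, for illustration only: ``wait=True`` blocks the Python
    # session until the job completes, and ``logs=True`` streams its output.
    # ``logs`` has no effect when ``wait`` is False.
    pt_estimator.fit(
        "s3://bucket/path/to/training/data",
        wait=True,
        logs=True,
    )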
+ ----

Distributed PyTorch Training
============================
@@ -262,15 +263,19 @@ during the PyTorch DDP initialization.

.. note::

- The SageMaker PyTorch estimator can operates both ``mpirun`` and ``torchrun`` in the backend for distributed training.
+ The SageMaker PyTorch estimator can operate both ``mpirun`` (for PyTorch 1.12.0 and later)
+ and ``torchrun``
+ (for PyTorch 1.13.1 and later) in the backend for distributed training.

For more information about setting up PyTorch DDP in your training script,
see `Getting Started with Distributed Data Parallel
<https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the
PyTorch documentation.

- The following example shows how to run a PyTorch DDP training in SageMaker
- using two ``ml.p4d.24xlarge`` instances:
+ The following examples show how to set a PyTorch estimator
+ to run a distributed training job on two ``ml.p4d.24xlarge`` instances.
+
+ **Using PyTorch DDP with the ``mpirun`` backend**

.. code:: python

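This hunk stops at the opening of the ``mpirun`` example's code block; only its
closing lines appear in the next hunk. A rough sketch of what that estimator
might look like, assuming the SDK's ``pytorchddp`` distribution option is what
enables the ``mpirun``-based backend (versions and names are illustrative):

.. code:: python

    from sagemaker.pytorch import PyTorch

    # Sketch only: mirrors the torchrun example added below, but with the
    # assumed ``pytorchddp`` distribution setting and a PyTorch 1.12.0 image.
    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        distribution={
            "pytorchddp": {
                "enabled": True
            }
        }
    )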
@@ -290,7 +295,27 @@ using two ``ml.p4d.24xlarge`` instances:
        }
    )

- pt_estimator.fit("s3://bucket/path/to/training/data")
+ **Using PyTorch DDP with the ``torchrun`` backend**
+
+ .. code:: python
+
+     from sagemaker.pytorch import PyTorch
+
+     pt_estimator = PyTorch(
+         entry_point="train_ptddp.py",
+         role="SageMakerRole",
+         framework_version="1.13.1",
+         py_version="py38",
+         instance_count=2,
+         instance_type="ml.p4d.24xlarge",
+         distribution={
+             "torch_distributed": {
+                 "enabled": True
+             }
+         }
+     )
+
+ ----

.. _distributed-pytorch-training-on-trainium:

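Both distribution modes launch the same entry point (``train_ptddp.py`` in these
examples). A minimal sketch of the DDP initialization such a script might perform,
assuming the NCCL backend and the rank environment variables the launcher
provides (see the DDP tutorial linked above for the authoritative version):

.. code:: python

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Sketch only: the launcher supplies RANK, WORLD_SIZE, MASTER_ADDR, and
    # MASTER_PORT through the environment, so env:// initialization suffices.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Wrap a toy model in DDP; a real script would add data loading, a loss,
    # an optimizer, and the training loop.
    model = torch.nn.Linear(10, 10).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])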
@@ -316,14 +341,14 @@ with the ``torch_distributed`` option as the distribution strategy.
.. note::

This ``torch_distributed`` support is available
- in the AWS Deep Learning Containers for PyTorch Neuron starting v1.11.0 and other gpu instances starting v1.13.1.
+ in the AWS Deep Learning Containers for PyTorch Neuron starting v1.11.0.
To find a complete list of supported versions of PyTorch Neuron, see
`Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_
in the *AWS Deep Learning Containers GitHub repository*.

.. note::

- SageMaker Debugger is currently not supported with Trn1 instances.
+ SageMaker Debugger is not compatible with Trn1 instances.

Adapt Your Training Script to Initialize with the XLA backend
-------------------------------------------------------------
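The section introduced above lies outside this diff, but the initialization it
covers can be sketched as follows, assuming the PyTorch/XLA packages shipped in
the Neuron containers (illustrative only, not the section's actual text):

.. code:: python

    import torch.distributed as dist
    import torch_xla.core.xla_model as xm
    # Importing this module registers the "xla" backend with torch.distributed.
    import torch_xla.distributed.xla_backend  # noqa: F401

    # Sketch only: the torch_distributed launcher supplies rank and world size
    # through the environment.
    dist.init_process_group("xla")
    device = xm.xla_device()  # XLA device backing this worker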