@@ -196,6 +196,7 @@ fit Optional Arguments
 - ``logs``: Defaults to True, whether to show logs produced by training
   job in the Python session. Only meaningful when wait is True.

+----

 Distributed PyTorch Training
 ============================
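The ``logs`` argument in the hunk above is passed to ``fit``. A minimal sketch of how it might be used, assuming an estimator named ``pt_estimator`` has already been constructed and using a placeholder S3 input path:

.. code:: python

    # Block until the training job completes, but do not stream its CloudWatch
    # logs into the Python session (logs is only meaningful when wait=True).
    pt_estimator.fit("s3://bucket/path/to/training/data", wait=True, logs=False)
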
@@ -262,16 +263,18 @@ during the PyTorch DDP initialization.

 .. note::

-  The SageMaker PyTorch estimator operates ``mpirun`` in the backend.
-  It doesn’t use ``torchrun`` for distributed training.
+  The SageMaker PyTorch estimator can operate both ``mpirun`` (for PyTorch 1.12.0 and later)
+  and ``torchrun`` (for PyTorch 1.13.1 and later) in the backend for distributed training.

 For more information about setting up PyTorch DDP in your training script,
 see `Getting Started with Distributed Data Parallel
 <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the
 PyTorch documentation.

-The following example shows how to run a PyTorch DDP training in SageMaker
-using two ``ml.p4d.24xlarge`` instances:
+The following examples show how to set up a PyTorch estimator
+to run a distributed training job on two ``ml.p4d.24xlarge`` instances.
+
+**Using PyTorch DDP with the mpirun backend**

 .. code:: python

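For comparison with the torchrun example added in the next hunk, a minimal sketch of the ``pytorchddp`` (mpirun backend) distribution option; the entry point, role, framework version, and instance settings are illustrative placeholders mirroring the torchrun example, not the exact code elided between these hunks:

.. code:: python

    from sagemaker.pytorch import PyTorch

    # Sketch: enable the mpirun-based PyTorch DDP launcher via the
    # "pytorchddp" distribution option.
    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        distribution={
            "pytorchddp": {
                "enabled": True
            }
        }
    )
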
@@ -291,7 +294,34 @@ using two ``ml.p4d.24xlarge`` instances:
         }
     )

-    pt_estimator.fit("s3://bucket/path/to/training/data")
+**Using PyTorch DDP with the torchrun backend**
+
+.. code:: python
+
+    from sagemaker.pytorch import PyTorch
+
+    pt_estimator = PyTorch(
+        entry_point="train_ptddp.py",
+        role="SageMakerRole",
+        framework_version="1.13.1",
+        py_version="py38",
+        instance_count=2,
+        instance_type="ml.p4d.24xlarge",
+        distribution={
+            "torch_distributed": {
+                "enabled": True
+            }
+        }
+    )
+
+
+.. note::
+
+  For more information about setting up ``torchrun`` in your training script,
+  see `torchrun (Elastic Launch) <https://pytorch.org/docs/stable/elastic/run.html>`_ in *the
+  PyTorch documentation*.
+
+----

 .. _distributed-pytorch-training-on-trainium:

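As a complement to the ``torchrun`` note above, a minimal sketch of what an entry point launched through ``torchrun`` (such as the ``train_ptddp.py`` placeholder used in the example) might contain; the model and training loop are illustrative only:

.. code:: python

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP


    def main():
        # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT
        # for each worker, so the default env:// initialization picks them up.
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        # Placeholder model wrapped in DDP so gradients are synchronized
        # across all processes on both instances.
        model = torch.nn.Linear(10, 1).to(local_rank)
        ddp_model = DDP(model, device_ids=[local_rank])

        # ... data loading and the actual training loop would go here ...

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()
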
@@ -324,7 +354,7 @@ with the ``torch_distributed`` option as the distribution strategy.

 .. note::

-  SageMaker Debugger is currently not supported with Trn1 instances.
+  SageMaker Debugger is not compatible with Trn1 instances.

 Adapt Your Training Script to Initialize with the XLA backend
 -------------------------------------------------------------
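The initialization snippet for this section falls outside this diff; as a hedged sketch, PyTorch/XLA training scripts conventionally register the ``xla`` process-group backend and initialize it as follows:

.. code:: python

    import torch.distributed as dist

    # Importing this module registers the "xla" backend with torch.distributed,
    # allowing the process group to be initialized for Trn1 (Trainium) training.
    import torch_xla.distributed.xla_backend  # noqa: F401

    dist.init_process_group("xla")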