Skip to content

Commit 85a2212

Browse files
committed
add torch distributed
1 parent 0473a8e commit 85a2212

File tree

1 file changed

+77
-23
lines changed

1 file changed

+77
-23
lines changed

src/sagemaker/huggingface/estimator.py

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
import re
1818
from typing import Optional, Union, Dict
1919

20-
from sagemaker.deprecations import renamed_kwargs
2120
from sagemaker.estimator import Framework, EstimatorBase
2221
from sagemaker.fw_utils import (
2322
framework_name_from_image,
24-
warn_if_parameter_server_with_multi_gpu,
25-
validate_smdistributed,
23+
validate_distribution,
2624
)
2725
from sagemaker.huggingface.model import HuggingFaceModel
2826
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -37,6 +35,9 @@ class HuggingFace(Framework):
3735
"""Handle training of custom HuggingFace code."""
3836

3937
_framework_name = "huggingface"
38+
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
39+
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
40+
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
4041

4142
def __init__(
4243
self,
@@ -142,6 +143,36 @@ def __init__(
142143
}
143144
}
144145
146+
**To enable PyTorch DDP:**
147+
148+
.. code:: python
149+
150+
{
151+
"pytorchddp": {
152+
"enabled": True
153+
}
154+
}
155+
156+
To learn more, see `Distributed PyTorch Training
157+
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
158+
159+
**To enable Torch Distributed:**
160+
161+
This is available for general distributed training on
162+
GPU instances from PyTorch v1.13.1 and later.
163+
164+
.. code:: python
165+
166+
{
167+
"torch_distributed": {
168+
"enabled": True
169+
}
170+
}
171+
172+
This option also supports distributed training on Trn1.
173+
To learn more, see `Distributed PyTorch Training on Trainium
174+
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
175+
145176
To enable distributed training with
146177
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
147178
for Hugging Face Transformers with PyTorch:
@@ -182,28 +213,23 @@ def __init__(
182213

183214
self._validate_args(image_uri=image_uri)
184215

185-
instance_type = renamed_kwargs(
186-
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
187-
)
188-
189-
base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
190-
base_framework_version = (
216+
self.base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
217+
self.base_framework_version = (
191218
tensorflow_version if tensorflow_version is not None else pytorch_version
192219
)
193220

194221
if distribution is not None:
195-
validate_smdistributed(
196-
instance_type=instance_type,
197-
framework_name=base_framework_name,
198-
framework_version=base_framework_version,
199-
py_version=self.py_version,
200-
distribution=distribution,
201-
image_uri=image_uri,
222+
distribution = validate_distribution(
223+
distribution,
224+
self.instance_groups,
225+
self.base_framework_name,
226+
self.base_framework_version,
227+
py_version,
228+
image_uri,
229+
kwargs,
202230
)
203231

204-
warn_if_parameter_server_with_multi_gpu(
205-
training_instance_type=instance_type, distribution=distribution
206-
)
232+
self.distribution = distribution or {}
207233

208234
if "enable_sagemaker_metrics" not in kwargs:
209235
kwargs["enable_sagemaker_metrics"] = True
@@ -214,8 +240,6 @@ def __init__(
214240
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
215241
)
216242

217-
self.distribution = distribution or {}
218-
219243
if compiler_config is not None:
220244
if not isinstance(compiler_config, TrainingCompilerConfig):
221245
error_string = (
@@ -267,14 +291,44 @@ def _validate_args(self, image_uri):
267291
"transformers_version, tensorflow_version and pytorch_version."
268292
)
269293

294+
def _huggingface_distribution_configuration(self, distribution):
295+
"""Returns a dict of distribution config for Hugging Face training
296+
297+
Args:
298+
distribution (dict): A dictionary with information on how to run distributed training.
299+
Returns:
300+
dict containing PyTorch DDP, Torch Distributed, or fallback distribution config
301+
"""
302+
distribution_config = {}
303+
pytorch_ddp_enabled = False
304+
torch_distributed_enabled = False
305+
306+
if "pytorchddp" in distribution:
307+
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
308+
elif "torch_distributed" in distribution:
309+
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
310+
311+
if pytorch_ddp_enabled:
312+
distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
313+
if self.instance_type is not None:
314+
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
315+
elif torch_distributed_enabled:
316+
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
317+
if self.instance_type is not None:
318+
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
319+
else:
320+
distribution_config = self._distribution_configuration(distribution=distribution)
321+
322+
return distribution_config
323+
270324
def hyperparameters(self):
271325
"""Return hyperparameters used by your custom PyTorch code during model training."""
272326
hyperparameters = super(HuggingFace, self).hyperparameters()
273-
distributed_training_hyperparameters = self._distribution_configuration(
327+
additional_hyperparameters = self._huggingface_distribution_configuration(
274328
distribution=self.distribution
275329
)
276330
hyperparameters.update(
277-
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
331+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
278332
)
279333

280334
if self.compiler_config:

0 commit comments

Comments
 (0)