[huggingface] Add torch.distributed support for Trainium and torchrun #3759

Merged: 12 commits, Apr 19, 2023

113 changes: 85 additions & 28 deletions src/sagemaker/huggingface/estimator.py
@@ -17,12 +17,10 @@
import re
from typing import Optional, Union, Dict

from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
warn_if_parameter_server_with_multi_gpu,
validate_smdistributed,
validate_distribution,
)
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -37,6 +35,9 @@ class HuggingFace(Framework):
"""Handle training of custom HuggingFace code."""

_framework_name = "huggingface"
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

def __init__(
self,
@@ -142,6 +143,36 @@ def __init__(
}
}

**To enable PyTorch DDP:**

.. code:: python

{
"pytorchddp": {
"enabled": True
}
}

To learn more, see `Distributed PyTorch Training
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.

**To enable Torch Distributed:**

This is available for general distributed training on
GPU instances with PyTorch v1.13.1 and later.

.. code:: python

{
"torch_distributed": {
"enabled": True
}
}

This option also supports distributed training on Trn1 (Trainium) instances.
To learn more, see `Distributed PyTorch Training on Trainium
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.

To enable distributed training with
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
for Hugging Face Transformers with PyTorch:
@@ -182,29 +213,6 @@ def __init__(

self._validate_args(image_uri=image_uri)

instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)

base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
base_framework_version = (
tensorflow_version if tensorflow_version is not None else pytorch_version
)

if distribution is not None:
validate_smdistributed(
instance_type=instance_type,
framework_name=base_framework_name,
framework_version=base_framework_version,
py_version=self.py_version,
distribution=distribution,
image_uri=image_uri,
)

warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)

if "enable_sagemaker_metrics" not in kwargs:
kwargs["enable_sagemaker_metrics"] = True

@@ -214,6 +222,25 @@
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)

if "entry_point" not in kwargs:
kwargs["entry_point"] = entry_point

self.base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
self.base_framework_version = (
tensorflow_version if tensorflow_version is not None else pytorch_version
)

if distribution is not None:
distribution = validate_distribution(
distribution,
self.instance_groups,
self.base_framework_name,
self.base_framework_version,
py_version,
image_uri,
kwargs,
)

self.distribution = distribution or {}

if compiler_config is not None:
@@ -267,14 +294,44 @@ def _validate_args(self, image_uri):
"transformers_version, tensorflow_version and pytorch_version."
)

def _huggingface_distribution_configuration(self, distribution):
"""Returns a dict of distribution config for Hugging Face training

Args:
distribution (dict): A dictionary with information on how to run distributed training.
Returns:
dict containing the PyTorch DDP or Torch Distributed config
"""
distribution_config = {}
pytorch_ddp_enabled = False
torch_distributed_enabled = False

if "pytorchddp" in distribution:
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
elif "torch_distributed" in distribution:
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)

if pytorch_ddp_enabled:
distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
if self.instance_type is not None:
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
elif torch_distributed_enabled:
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
if self.instance_type is not None:
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
else:
distribution_config = self._distribution_configuration(distribution=distribution)

return distribution_config

def hyperparameters(self):
"""Return hyperparameters used by your custom PyTorch code during model training."""
hyperparameters = super(HuggingFace, self).hyperparameters()
distributed_training_hyperparameters = self._distribution_configuration(
additional_hyperparameters = self._huggingface_distribution_configuration(
distribution=self.distribution
)
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)

if self.compiler_config:
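For context, here is a minimal sketch of how a user would launch a Trainium (Trn1) job with the torch_distributed option added above. The script name, role, version pins, and S3 input path are illustrative assumptions, not values from this PR; substitute ones valid for your account and region.

from sagemaker.huggingface import HuggingFace

# Sketch: HuggingFace estimator with torch_distributed on a Trainium instance.
estimator = HuggingFace(
    entry_point="train.py",  # hypothetical training script
    role="SageMakerRole",  # assumes an existing execution role
    transformers_version="4.26",  # illustrative version pins
    pytorch_version="1.13",
    py_version="py39",
    instance_count=1,
    instance_type="ml.trn1.2xlarge",  # Trainium (Trn1) instance
    distribution={"torch_distributed": {"enabled": True}},
)
estimator.fit("s3://my-bucket/train")  # hypothetical S3 input
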
50 changes: 50 additions & 0 deletions tests/integ/test_huggingface_torch_distributed.py
@@ -0,0 +1,50 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
from sagemaker.huggingface import HuggingFace
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, timeout


def test_huggingface_torch_distributed_g5_glue(
sagemaker_session,
huggingface_training_latest_version,
huggingface_training_pytorch_latest_version,
huggingface_pytorch_latest_training_py_version,
):
with timeout.timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
data_path = os.path.join(DATA_DIR, "huggingface")
estimator = HuggingFace(
py_version=huggingface_pytorch_latest_training_py_version,
entry_point=os.path.join(data_path, "run_glue.py"),
role="SageMakerRole",
transformers_version=huggingface_training_latest_version,
pytorch_version=huggingface_training_pytorch_latest_version,
instance_count=1,
instance_type="ml.g5.12xlarge",
hyperparameters={
"model_name_or_path": "distilbert-base-cased",
"task_name": "wnli",
"do_train": True,
"do_eval": True,
"max_seq_length": 128,
"fp16": True,
"per_device_train_batch_size": 32,
"output_dir": "/opt/ml/model",
},
distribution={"torch_distributed": {"enabled": True}},
sagemaker_session=sagemaker_session,
disable_profiler=True,
)
estimator.fit()
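
For the test above, the new _huggingface_distribution_configuration method would add the entries sketched below, traced from the env names defined in estimator.py; hyperparameters() then JSON-encodes them via EstimatorBase._json_encode_hyperparameters before they reach the training job.

# Additional hyperparameters emitted for
# distribution={"torch_distributed": {"enabled": True}} on ml.g5.12xlarge:
{
    "sagemaker_torch_distributed_enabled": True,
    "sagemaker_instance_type": "ml.g5.12xlarge",
}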