
change: allow smdistributed to be enabled with torch_distributed. #4129


Merged · 7 commits · Oct 23, 2023
41 changes: 35 additions & 6 deletions src/sagemaker/estimator.py
@@ -101,6 +101,7 @@
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline

logger = logging.getLogger(__name__)
@@ -3198,6 +3199,7 @@ class Framework(EstimatorBase):
"""

_framework_name = None
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ("2.0.1-gpu-py310-cu121", "2.0-gpu-py310-cu121")

def __init__(
self,
@@ -3843,16 +3845,43 @@ def _distribution_configuration(self, distribution):
"custom_mpi_options", ""
)

if get_mp_parameters(distribution):
distribution_config["mp_parameters"] = get_mp_parameters(distribution)

elif "modelparallel" in distribution.get("smdistributed", {}):
raise ValueError("Cannot use Model Parallelism without MPI enabled!")

if "smdistributed" in distribution:
# smdistributed strategy selected
if get_mp_parameters(distribution):
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
# first make sure torch_distributed is enabled if instance type is p5
torch_distributed_enabled = False
if "torch_distributed" in distribution:
torch_distributed_enabled = distribution.get("torch_distributed").get(
"enabled", False
)
smdistributed = distribution["smdistributed"]
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
if isinstance(self.instance_type, ParameterString):
p5_enabled = "p5.48xlarge" in self.instance_type.default_value
elif isinstance(self.instance_type, str):
p5_enabled = "p5.48xlarge" in self.instance_type
else:
raise ValueError(
"Invalid object type for instance_type argument. Expected "
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
)
img_uri = "" if self.image_uri is None else self.image_uri
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
if (
unsupported_image in img_uri and not torch_distributed_enabled
): # disabling DLC images with CUDA12
raise ValueError(
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
"(Could be due to CUDA version being greater than 11.)"
)
if (
not torch_distributed_enabled and p5_enabled
): # disabling p5 when torch distributed is disabled
raise ValueError(
"SMModelParallel and SMDataParallel currently do not support p5 instances."
)
# smdistributed strategy selected with supported instance type
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
distribution_config[self.INSTANCE_TYPE] = self.instance_type
if smdataparallel_enabled:
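Taken together, these checks reject smdistributed on p5 instances and on the CUDA 12 DLC images listed in UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM unless torch_distributed is also enabled. A minimal sketch of distribution dicts that now pass or fail this validation (the instance type and image URI are illustrative):

# Accepted on ml.p5.48xlarge or a *-cu121 DLC image:
# torch_distributed is enabled alongside smdistributed.
allowed = {
    "torch_distributed": {"enabled": True},
    "smdistributed": {"dataparallel": {"enabled": True}},
}

# Raises ValueError in the same setting: smdistributed alone,
# without torch_distributed enabled.
rejected = {
    "smdistributed": {"dataparallel": {"enabled": True}},
}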
3 changes: 3 additions & 0 deletions src/sagemaker/pytorch/estimator.py
@@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
if self.instance_type is not None:
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
elif torch_distributed_enabled:
if "smdistributed" in distribution:
# Enable torch_distributed for smdistributed.
distribution_config = self._distribution_configuration(distribution=distribution)
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
if self.instance_type is not None:
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
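With this dispatch, a combined torch_distributed + smdistributed distribution is validated by the base class first and then gets the torch_distributed launcher flag layered on top. A rough, simplified sketch of the result, assuming the standard flag names behind LAUNCH_SM_DDP_ENV_NAME, INSTANCE_TYPE_ENV_NAME, and LAUNCH_TORCH_DISTRIBUTED_ENV_NAME (the exact keys are an assumption here, not shown in this diff; smdataparallel custom MPI options omitted):

distribution = {
    "torch_distributed": {"enabled": True},
    "smdistributed": {"dataparallel": {"enabled": True}},
}

# Sketch of what _pytorch_distribution_configuration would return on ml.p5.48xlarge:
expected_config = {
    "sagemaker_distributed_dataparallel_enabled": True,  # set by _distribution_configuration
    "sagemaker_instance_type": "ml.p5.48xlarge",         # set by _distribution_configuration
    "sagemaker_torch_distributed_enabled": True,         # LAUNCH_TORCH_DISTRIBUTED_ENV_NAME
}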
94 changes: 93 additions & 1 deletion tests/unit/test_estimator.py
@@ -169,7 +169,20 @@
"mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
}
DISTRIBUTION_SM_DDP_ENABLED = {
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}}
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
"torch_distributed": {"enabled": False},
}
DISTRIBUTION_SM_DDP_DISABLED = {
"smdistributed": {"enabled": True},
"torch_distributed": {"enabled": False},
}
DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED = {
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
"torch_distributed": {"enabled": True},
}
DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED = {
"smdistributed": {"enabled": True},
"torch_distributed": {"enabled": True},
}
MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir"
_DEFINITION_CONFIG = PipelineDefinitionConfig(use_custom_job_prefix=False)
@@ -310,6 +323,85 @@ def training_job_description(sagemaker_session):
return returned_job_description


def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
# Test unsupported image raises error.
for unsupported_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
# Fail due to unsupported CUDA12 DLC image.
f = DummyFramework(
"some_script.py",
role="DummyRole",
instance_type="ml.p4d.24xlarge",
sagemaker_session=sagemaker_session,
output_path="outputpath",
image_uri=unsupported_image,
)
with pytest.raises(ValueError):
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
with pytest.raises(ValueError):
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)

# Test unsupported image with suffix raises error.
for unsupported_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
# Fail due to unsupported CUDA12 DLC image.
f = DummyFramework(
"some_script.py",
role="DummyRole",
instance_type="ml.p4d.24xlarge",
sagemaker_session=sagemaker_session,
output_path="outputpath",
image_uri=unsupported_image + "-ubuntu20.04-sagemaker-pr-3303",
)
with pytest.raises(ValueError):
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
with pytest.raises(ValueError):
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)


def test_validate_smdistributed_p5_raises(sagemaker_session):
# Supported DLC image.
f = DummyFramework(
"some_script.py",
role="DummyRole",
instance_type="ml.p5.48xlarge",
sagemaker_session=sagemaker_session,
output_path="outputpath",
image_uri="some_acceptable_image",
)
# Both fail because instance type is p5 and torch_distributed is off.
with pytest.raises(ValueError):
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
with pytest.raises(ValueError):
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)


def test_validate_smdistributed_p5_not_raises(sagemaker_session):
f = DummyFramework(
"some_script.py",
role="DummyRole",
instance_type="ml.p5.48xlarge",
sagemaker_session=sagemaker_session,
output_path="outputpath",
image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303",
)
# Testing with p5 instance and torch_distributed enabled.
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)


def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session):
f = DummyFramework(
"some_script.py",
role="DummyRole",
instance_type="ml.p4d.24xlarge",
sagemaker_session=sagemaker_session,
output_path="outputpath",
image_uri="some_acceptable_image",
)
# Testing backwards compatibility with p4d instances.
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)


def test_framework_all_init_args(sagemaker_session):
f = DummyFramework(
"my_script.py",
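For context, an end-to-end sketch of how this change is exercised through the public PyTorch estimator; the entry point, role ARN, versions, and data location below are placeholders:

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",  # placeholder training script
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    framework_version="2.0.1",
    py_version="py310",
    instance_count=2,
    instance_type="ml.p5.48xlarge",
    # Combining torch_distributed with smdistributed is exactly what this PR enables.
    distribution={
        "torch_distributed": {"enabled": True},
        "smdistributed": {"dataparallel": {"enabled": True}},
    },
)
# estimator.fit("s3://my-bucket/training-data")  # placeholder data location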