Skip to content

Commit 08b3093

Browse files
Yongyan Raonavinsoni
authored andcommitted
change: add a check to prevent launching a modelparallel job on CPU only instances.
1 parent 60872f3 commit 08b3093

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/sagemaker/fw_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,52 @@ def validate_distribution(
753753
return distribution
754754

755755

756+
def validate_distribution_instance(sagemaker_session, distribution, instance_type):
757+
"""Check to prevent launching a modelparallel job on CPU only instances.
758+
759+
Args:
760+
sagemaker_session (sagemaker.session.Session): Session object which
761+
manages interactions with Amazon SageMaker APIs and any other
762+
AWS services needed.
763+
distribution (dict): A dictionary with information to enable distributed training.
764+
distribution = {
765+
"smdistributed": {
766+
"modelparallel": {
767+
"enabled": True,
768+
"parameters": {
769+
...
770+
},
771+
},
772+
},
773+
...
774+
}
775+
instance_type (str): A string representing the type of training instance selected.
776+
777+
Raises:
778+
ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
779+
"""
780+
if "smdistributed" not in distribution:
781+
# Distribution strategy other than smdistributed is selected
782+
return
783+
784+
if "modelparallel" not in distribution["smdistributed"]:
785+
# Strategy other than modelparallel is selected
786+
return
787+
788+
if not distribution["smdistributed"]["modelparallel"]["enabled"]:
789+
# Strategy modelparallel is not enabled
790+
return
791+
792+
instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
793+
InstanceTypes=[f"{instance_type}"]
794+
)
795+
if "GpuInfo" not in instance_desc["InstanceTypes"][0]:
796+
raise ValueError(
797+
f"modelparallel only runs on GPU-enabled instances. "
798+
f"{instance_type} does not support GPU."
799+
)
800+
801+
756802
def python_deprecation_warning(framework, latest_supported_version):
757803
"""Placeholder docstring"""
758804
return PYTHON_2_DEPRECATION_WARNING.format(

src/sagemaker/pytorch/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
python_deprecation_warning,
2626
validate_version_or_image_args,
2727
validate_distribution,
28+
validate_distribution_instance,
2829
)
2930
from sagemaker.pytorch import defaults
3031
from sagemaker.pytorch.model import PyTorchModel
@@ -205,6 +206,12 @@ def __init__(
205206
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
206207
)
207208
if distribution is not None:
209+
instance_type = self._get_instance_type()
210+
# remove "ml." prefix
211+
if instance_type[:3] == "ml.":
212+
instance_type = instance_type[3:]
213+
validate_distribution_instance(self.sagemaker_session, distribution, instance_type)
214+
208215
distribution = validate_distribution(
209216
distribution,
210217
self.instance_groups,

0 commit comments

Comments
 (0)