@@ -700,6 +700,50 @@ def validate_distribution(
700
700
return distribution
701
701
702
702
703
def validate_distribution_instance(sagemaker_session, distribution, instance_type):
    """Check to prevent launching a modelparallel job on CPU only instances.

    The check is a no-op unless ``distribution`` enables
    ``smdistributed.modelparallel``. When it does, the instance type is looked
    up through the EC2 ``describe_instance_types`` API and rejected if the
    returned description carries no ``GpuInfo`` entry.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. Its ``boto_session`` is used to create the
            EC2 client; it is only touched when modelparallel is enabled.
        distribution (dict): A dictionary with information to enable distributed training.
            ``None`` or an empty dict means no distribution strategy is selected.
            distribution = {
                "smdistributed": {
                    "modelparallel": {
                        "enabled": True,
                        "parameters": {
                            ...
                        },
                    },
                },
                ...
            }
        instance_type (str): A string representing the type of training instance selected.

    Raises:
        ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
    """
    if not distribution:
        # No distribution strategy configured at all
        return

    # Tolerate missing / None sub-sections: anything short of an explicit,
    # truthy "enabled" flag means modelparallel is not in use.
    smdistributed = distribution.get("smdistributed") or {}
    modelparallel = smdistributed.get("modelparallel") or {}
    if not modelparallel.get("enabled", False):
        return

    # SageMaker instance types are spelled "ml.<family>.<size>", while the EC2
    # API only knows the bare "<family>.<size>" name, so drop the prefix
    # before the lookup.
    ec2_instance_type = str(instance_type)
    sagemaker_prefix = "ml."
    if ec2_instance_type.startswith(sagemaker_prefix):
        ec2_instance_type = ec2_instance_type[len(sagemaker_prefix):]

    ec2_client = sagemaker_session.boto_session.client("ec2")
    instance_desc = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance_type])

    # The GPU check relies on "GpuInfo" being absent from CPU-only descriptions.
    if "GpuInfo" not in instance_desc["InstanceTypes"][0]:
        raise ValueError(
            "modelparallel only runs on GPU-enabled instances. "
            f"{instance_type} does not support GPU."
        )
745
+
746
+
703
747
def python_deprecation_warning (framework , latest_supported_version ):
704
748
"""Placeholder docstring"""
705
749
return PYTHON_2_DEPRECATION_WARNING .format (
0 commit comments