@@ -21,10 +21,12 @@
 import shutil
 import tempfile
 from collections import namedtuple
-from typing import Optional, Union, Dict
+from copy import deepcopy
+from typing import List, Optional, Union, Dict
 from packaging import version
 
 import sagemaker.image_uris
+from sagemaker.instance_group import InstanceGroup
 from sagemaker.s3_utils import s3_path_join
 from sagemaker.session_settings import SessionSettings
 import sagemaker.utils
@@ -828,14 +830,14 @@ def _validate_smdataparallel_args(
 
 
 def validate_distribution(
-    distribution,
-    instance_groups,
-    framework_name,
-    framework_version,
-    py_version,
-    image_uri,
-    kwargs,
-):
+    distribution: Dict,
+    instance_groups: List[InstanceGroup],
+    framework_name: str,
+    framework_version: str,
+    py_version: str,
+    image_uri: str,
+    kwargs: Dict,
+) -> Dict:
     """Check if distribution strategy is correctly invoked by the user.
 
     Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up.
@@ -872,7 +874,9 @@ def validate_distribution(
             strategy-specific inputs are incorrect/unsupported or
             heterogeneous cluster set up is incorrect
     """
-    train_instance_groups = distribution.get("instance_groups", [])
+    validated_distribution = deepcopy(distribution)
+
+    train_instance_groups = validated_distribution.get("instance_groups", [])
     if instance_groups is None:
         if len(train_instance_groups) >= 1:
             # if estimator's instance_groups is not defined but
@@ -902,77 +906,77 @@ def validate_distribution(
             instance_type = train_instance_group.instance_type
             validate_distribution_for_instance_type(
                 instance_type=instance_type,
-                distribution=distribution,
+                distribution=validated_distribution,
             )
             validate_smdistributed(
                 instance_type=instance_type,
                 framework_name=framework_name,
                 framework_version=framework_version,
                 py_version=py_version,
-                distribution=distribution,
+                distribution=validated_distribution,
                 image_uri=image_uri,
             )
             if framework_name and framework_name == "pytorch":
                 # We need to validate only for PyTorch framework
                 validate_pytorch_distribution(
-                    distribution=distribution,
+                    distribution=validated_distribution,
                     framework_name=framework_name,
                     framework_version=framework_version,
                     py_version=py_version,
                     image_uri=image_uri,
                 )
                 validate_torch_distributed_distribution(
                     instance_type=instance_type,
-                    distribution=distribution,
+                    distribution=validated_distribution,
                     framework_version=framework_version,
                     py_version=py_version,
                     image_uri=image_uri,
                     entry_point=kwargs["entry_point"],
                 )
             warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=instance_type, distribution=distribution
+                training_instance_type=instance_type, distribution=validated_distribution
             )
             # get instance group names
             instance_group_names.append(train_instance_group.instance_group_name)
-        distribution["instance_groups"] = instance_group_names
+        validated_distribution["instance_groups"] = instance_group_names
     else:
         # in this case, we are handling a normal training job (without heterogeneous cluster)
         instance_type = renamed_kwargs(
             "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
         )
         validate_distribution_for_instance_type(
             instance_type=instance_type,
-            distribution=distribution,
+            distribution=validated_distribution,
         )
         validate_smdistributed(
             instance_type=instance_type,
             framework_name=framework_name,
             framework_version=framework_version,
             py_version=py_version,
-            distribution=distribution,
+            distribution=validated_distribution,
             image_uri=image_uri,
         )
         if framework_name and framework_name == "pytorch":
             # We need to validate only for PyTorch framework
             validate_pytorch_distribution(
-                distribution=distribution,
+                distribution=validated_distribution,
                 framework_name=framework_name,
                 framework_version=framework_version,
                 py_version=py_version,
                 image_uri=image_uri,
             )
             validate_torch_distributed_distribution(
                 instance_type=instance_type,
-                distribution=distribution,
+                distribution=validated_distribution,
                 framework_version=framework_version,
                 py_version=py_version,
                 image_uri=image_uri,
                 entry_point=kwargs["entry_point"],
             )
         warn_if_parameter_server_with_multi_gpu(
-            training_instance_type=instance_type, distribution=distribution
+            training_instance_type=instance_type, distribution=validated_distribution
         )
-    return distribution
+    return validated_distribution
 
 
 def validate_distribution_for_instance_type(instance_type, distribution):
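
Note: the substantive change in this diff is that validation now runs on a deep copy of the caller's `distribution` dict, so rewriting `instance_groups` from InstanceGroup objects to their names no longer mutates the argument the user passed to the estimator. A minimal sketch of that pattern, using a local stand-in instead of the SDK's InstanceGroup class (illustrative only, not the SDK API):

    from collections import namedtuple
    from copy import deepcopy

    # Stand-in exposing the two attributes the validator reads from an instance group.
    FakeGroup = namedtuple("FakeGroup", ["instance_group_name", "instance_type"])

    def sketch_validate(distribution):
        """Validate a copy of the dict; return it with group objects replaced by names."""
        validated = deepcopy(distribution)
        groups = validated.get("instance_groups", [])
        validated["instance_groups"] = [group.instance_group_name for group in groups]
        return validated

    original = {"instance_groups": [FakeGroup("group_1", "ml.p4d.24xlarge")]}
    result = sketch_validate(original)
    assert result["instance_groups"] == ["group_1"]              # returned copy holds names
    assert isinstance(original["instance_groups"][0], FakeGroup)  # caller's dict untouched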