Skip to content

Commit 73877ba

Browse files
committed
fix: deepcopy distribution in validate_distribution
1 parent bfc63d2 commit 73877ba

File tree

3 files changed: 60 additions (+60) and 22 deletions (-22).

src/sagemaker/fw_utils.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,12 @@
2121
import shutil
2222
import tempfile
2323
from collections import namedtuple
24-
from typing import Optional, Union, Dict
24+
from copy import deepcopy
25+
from typing import List, Optional, Union, Dict
2526
from packaging import version
2627

2728
import sagemaker.image_uris
29+
from sagemaker.instance_group import InstanceGroup
2830
from sagemaker.s3_utils import s3_path_join
2931
from sagemaker.session_settings import SessionSettings
3032
import sagemaker.utils
@@ -828,14 +830,14 @@ def _validate_smdataparallel_args(
828830

829831

830832
def validate_distribution(
831-
distribution,
832-
instance_groups,
833-
framework_name,
834-
framework_version,
835-
py_version,
836-
image_uri,
837-
kwargs,
838-
):
833+
distribution: Dict,
834+
instance_groups: List[InstanceGroup],
835+
framework_name: str,
836+
framework_version: str,
837+
py_version: str,
838+
image_uri: str,
839+
kwargs: Dict,
840+
) -> Dict:
839841
"""Check if distribution strategy is correctly invoked by the user.
840842
841843
Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up.
@@ -872,7 +874,9 @@ def validate_distribution(
872874
strategy-specific inputs are incorrect/unsupported or
873875
heterogeneous cluster set up is incorrect
874876
"""
875-
train_instance_groups = distribution.get("instance_groups", [])
877+
validated_distribution = deepcopy(distribution)
878+
879+
train_instance_groups = validated_distribution.get("instance_groups", [])
876880
if instance_groups is None:
877881
if len(train_instance_groups) >= 1:
878882
# if estimator's instance_groups is not defined but
@@ -902,77 +906,77 @@ def validate_distribution(
902906
instance_type = train_instance_group.instance_type
903907
validate_distribution_for_instance_type(
904908
instance_type=instance_type,
905-
distribution=distribution,
909+
distribution=validated_distribution,
906910
)
907911
validate_smdistributed(
908912
instance_type=instance_type,
909913
framework_name=framework_name,
910914
framework_version=framework_version,
911915
py_version=py_version,
912-
distribution=distribution,
916+
distribution=validated_distribution,
913917
image_uri=image_uri,
914918
)
915919
if framework_name and framework_name == "pytorch":
916920
# We need to validate only for PyTorch framework
917921
validate_pytorch_distribution(
918-
distribution=distribution,
922+
distribution=validated_distribution,
919923
framework_name=framework_name,
920924
framework_version=framework_version,
921925
py_version=py_version,
922926
image_uri=image_uri,
923927
)
924928
validate_torch_distributed_distribution(
925929
instance_type=instance_type,
926-
distribution=distribution,
930+
distribution=validated_distribution,
927931
framework_version=framework_version,
928932
py_version=py_version,
929933
image_uri=image_uri,
930934
entry_point=kwargs["entry_point"],
931935
)
932936
warn_if_parameter_server_with_multi_gpu(
933-
training_instance_type=instance_type, distribution=distribution
937+
training_instance_type=instance_type, distribution=validated_distribution
934938
)
935939
# get instance group names
936940
instance_group_names.append(train_instance_group.instance_group_name)
937-
distribution["instance_groups"] = instance_group_names
941+
validated_distribution["instance_groups"] = instance_group_names
938942
else:
939943
# in this case, we are handling a normal training job (without heterogeneous cluster)
940944
instance_type = renamed_kwargs(
941945
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
942946
)
943947
validate_distribution_for_instance_type(
944948
instance_type=instance_type,
945-
distribution=distribution,
949+
distribution=validated_distribution,
946950
)
947951
validate_smdistributed(
948952
instance_type=instance_type,
949953
framework_name=framework_name,
950954
framework_version=framework_version,
951955
py_version=py_version,
952-
distribution=distribution,
956+
distribution=validated_distribution,
953957
image_uri=image_uri,
954958
)
955959
if framework_name and framework_name == "pytorch":
956960
# We need to validate only for PyTorch framework
957961
validate_pytorch_distribution(
958-
distribution=distribution,
962+
distribution=validated_distribution,
959963
framework_name=framework_name,
960964
framework_version=framework_version,
961965
py_version=py_version,
962966
image_uri=image_uri,
963967
)
964968
validate_torch_distributed_distribution(
965969
instance_type=instance_type,
966-
distribution=distribution,
970+
distribution=validated_distribution,
967971
framework_version=framework_version,
968972
py_version=py_version,
969973
image_uri=image_uri,
970974
entry_point=kwargs["entry_point"],
971975
)
972976
warn_if_parameter_server_with_multi_gpu(
973-
training_instance_type=instance_type, distribution=distribution
977+
training_instance_type=instance_type, distribution=validated_distribution
974978
)
975-
return distribution
979+
return validated_distribution
976980

977981

978982
def validate_distribution_for_instance_type(instance_type, distribution):

tests/unit/test_fw_utils.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -784,6 +784,28 @@ def test_validate_distribution_raises():
784784
)
785785

786786

787+
def test_validate_distribution_deepcopy():
788+
train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1)
789+
instance_groups = [train_group]
790+
framework = "tensorflow"
791+
distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
792+
validated = fw_utils.validate_distribution(
793+
distribution,
794+
instance_groups,
795+
framework,
796+
None,
797+
None,
798+
"custom-container",
799+
{"entry_point": "train.py"},
800+
)
801+
802+
assert validated == {
803+
"instance_groups": ["train_group"],
804+
"smdistributed": {"dataparallel": {"enabled": True}},
805+
}
806+
assert validated is not distribution
807+
808+
787809
def test_validate_smdistributed_not_raises():
788810
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
789811
smdataparallel_enabled_custom_mpi = {

tests/unit/test_processing.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -853,6 +853,18 @@ def test_processor_with_required_parameters(sagemaker_session):
853853
sagemaker_session.process.assert_called_with(**expected_args)
854854

855855

856+
def test_processor_with_underscore_image_name(sagemaker_session):
857+
processor = Processor(
858+
role=ROLE,
859+
image_uri="1234567890.dkr.ecr.eu-west-1.amazonaws.com/my_project/my_image_with_underscores:latest",
860+
instance_count=1,
861+
instance_type="ml.m4.xlarge",
862+
sagemaker_session=sagemaker_session,
863+
)
864+
865+
processor.run()
866+
867+
856868
def test_processor_with_missing_network_config_parameters(sagemaker_session):
857869
processor = Processor(
858870
role=ROLE,

0 commit comments

Comments
 (0)