Skip to content

Commit b5b0166

Browse files
Fix unit test failures
1 parent 76e9d2b commit b5b0166

File tree

4 files changed

+27
-29
lines changed

4 files changed

+27
-29
lines changed

src/sagemaker/fw_utils.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -859,27 +859,28 @@ def validate_distribution_for_instance_type(instance_type, distribution):
859859
instance_type (str): A string representing the type of training instance selected.
860860
distribution (dict): A dictionary with information to enable distributed training.
861861
"""
862-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
863862
err_msg = ""
864-
if match and match[1].startswith("trn"):
865-
keys = list(distribution.keys())
866-
if len(keys) == 0:
867-
return
868-
if len(keys) == 1:
869-
distribution_strategy = keys[0]
870-
if distribution_strategy != "torch_distributed":
863+
if isinstance(instance_type, str):
864+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
865+
if match and match[1].startswith("trn"):
866+
keys = list(distribution.keys())
867+
if len(keys) == 0:
868+
return
869+
if len(keys) == 1:
870+
distribution_strategy = keys[0]
871+
if distribution_strategy != "torch_distributed":
872+
err_msg += (
873+
f"Provided distribution strategy {distribution_strategy} is not supported"
874+
" for Trainium instances.\n"
875+
"Please specify one of the following supported distribution strategies:"
876+
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
877+
)
878+
elif len(keys) > 1:
871879
err_msg += (
872-
f"Provided distribution strategy {distribution_strategy} is not supported for"
873-
" Trainium instances.\n"
880+
"Multiple distribution strategies are not supported for Trainium instances.\n"
874881
"Please specify one of the following supported distribution strategies:"
875-
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
882+
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
876883
)
877-
elif len(keys) > 1:
878-
err_msg += (
879-
"Multiple distribution strategies are not supported for Trainium instances.\n"
880-
"Please specify one of the following supported distribution strategies:"
881-
f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
882-
)
883884

884885
if err_msg:
885886
raise ValueError(err_msg)

src/sagemaker/pytorch/estimator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,13 @@ def __init__(
230230
if self.framework_version and Version(self.framework_version) >= Version("1.3"):
231231
kwargs["enable_sagemaker_metrics"] = True
232232

233-
if "entry_point" not in kwargs:
234-
kwargs["entry_point"] = entry_point
235-
236233
super(PyTorch, self).__init__(
237234
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
238235
)
236+
237+
if "entry_point" not in kwargs:
238+
kwargs["entry_point"] = entry_point
239+
239240
if distribution is not None:
240241
distribution = validate_distribution(
241242
distribution,

tests/integ/test_torch_distributed.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,9 @@
2424

2525

2626
@pytest.mark.skip(
27-
reason="This test is skipped for now due ML capacity error."
27+
reason="Disabling until the launch of SM Trainium containers"
2828
"This test should be re-enabled later."
2929
)
30-
@pytest.mark.skipif(
31-
integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
32-
reason="Only allow this test to run in IAD and CMH to limit usage of ml.trn1.2xlarge",
33-
)
3430
def test_torch_distributed_trn1_pt_mnist(
3531
sagemaker_session,
3632
torch_distributed_framework_version,

tests/unit/test_fw_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def test_validate_distribution_not_raises():
655655
None, # framework_version
656656
None, # py_version
657657
"custom-container",
658-
{"instance_type": instance_type}, # kwargs
658+
{"instance_type": instance_type, "entry_point": "train.py"}, # kwargs
659659
)
660660

661661
for framework in frameworks:
@@ -683,7 +683,7 @@ def test_validate_distribution_not_raises():
683683
None, # framework_version
684684
None, # py_version
685685
"custom-container",
686-
{}, # kwargs
686+
{"entry_point": "train.py"}, # kwargs
687687
)
688688

689689

@@ -723,7 +723,7 @@ def test_validate_distribution_raises():
723723
None, # framework_version
724724
None, # py_version
725725
"custom-container",
726-
{"instance_type": instance_type}, # kwargs
726+
{"instance_type": instance_type, "entry_point": "train.py"}, # kwargs
727727
)
728728

729729
for framework in frameworks:
@@ -952,7 +952,7 @@ def test_validate_torch_distributed_not_raises():
952952

953953
# Case 1: Framework is PyTorch, but distribution is not torch_distributed
954954
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
955-
fw_utils.validate_pytorch_distribution(
955+
fw_utils.validate_torch_distributed_distribution(
956956
instance_type="ml.trn1.2xlarge",
957957
distribution=torch_distributed_disabled,
958958
framework_version="1.11.0",

0 commit comments

Comments
 (0)