Skip to content

Commit a636155

Browse files
authored
change: Allow either instance_type or instance_group to be defined in… (#4232)
1 parent 5ff1fca commit a636155

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

src/sagemaker/estimator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3818,6 +3818,7 @@ def _distribution_configuration(self, distribution):
38183818

38193819
mpi_enabled = False
38203820
smdataparallel_enabled = False
3821+
p5_enabled = False
38213822
if "instance_groups" in distribution:
38223823
distribution_config["sagemaker_distribution_instance_groups"] = distribution[
38233824
"instance_groups"
@@ -3862,10 +3863,11 @@ def _distribution_configuration(self, distribution):
38623863
elif isinstance(self.instance_type, str):
38633864
p5_enabled = "p5.48xlarge" in self.instance_type
38643865
else:
3865-
raise ValueError(
3866-
"Invalid object type for instance_type argument. Expected "
3867-
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868-
)
3866+
for instance in self.instance_groups:
3867+
if "p5.48xlarge" in instance._to_request_dict().get("InstanceType", ()):
3868+
p5_enabled = True
3869+
break
3870+
38693871
img_uri = "" if self.image_uri is None else self.image_uri
38703872
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
38713873
if (

tests/unit/test_estimator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,39 @@ def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session)
402402
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
403403

404404

405+
def test_validate_smdistributed_instance_groups_raises(sagemaker_session):
    """Check that an instance group on ml.p5.48xlarge is rejected.

    The estimator is built with ``instance_groups`` only (no ``instance_type``);
    one group uses ml.p4d.24xlarge and the other ml.p5.48xlarge.
    ``_distribution_configuration`` is expected to raise ``ValueError`` for both
    the DDP-enabled and the DDP-disabled distribution settings.
    """
    mixed_groups = [
        InstanceGroup("train_group", "ml.p4d.24xlarge", 2),
        InstanceGroup("train_group", "ml.p5.48xlarge", 2),
    ]
    estimator = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_groups=mixed_groups,
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="some_acceptable_image",
    )

    # Testing instance_group with p5 raises exception
    for distribution in (DISTRIBUTION_SM_DDP_ENABLED, DISTRIBUTION_SM_DDP_DISABLED):
        with pytest.raises(ValueError):
            estimator._distribution_configuration(distribution)
421+
422+
423+
def test_validate_smdistributed_instance_groups_not_raises(sagemaker_session):
    """Check that instance groups without a p5 instance are accepted.

    The estimator is built with a single ml.p4d.24xlarge instance group and no
    ``instance_type``. ``_distribution_configuration`` must complete without
    raising for both torch-distributed/DDP distribution settings.
    """
    p4_only_groups = [InstanceGroup("train_group", "ml.p4d.24xlarge", 2)]
    estimator = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_groups=p4_only_groups,
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="some_acceptable_image",
    )

    # Testing instance_group without p5 does not raise exception
    for distribution in (
        DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED,
        DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED,
    ):
        estimator._distribution_configuration(distribution)
436+
437+
405438
def test_framework_all_init_args(sagemaker_session):
406439
f = DummyFramework(
407440
"my_script.py",

0 commit comments

Comments
 (0)