Skip to content

Commit 4bb1676

Browse files
committed
change: Allow either instance_type or instance_group to be defined in distributed training config.
1 parent 78c478c commit 4bb1676

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sagemaker/estimator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3818,6 +3818,7 @@ def _distribution_configuration(self, distribution):
38183818

38193819
mpi_enabled = False
38203820
smdataparallel_enabled = False
3821+
p5_enabled = False
38213822
if "instance_groups" in distribution:
38223823
distribution_config["sagemaker_distribution_instance_groups"] = distribution[
38233824
"instance_groups"
@@ -3862,10 +3863,11 @@ def _distribution_configuration(self, distribution):
38623863
elif isinstance(self.instance_type, str):
38633864
p5_enabled = "p5.48xlarge" in self.instance_type
38643865
else:
3865-
raise ValueError(
3866-
"Invalid object type for instance_type argument. Expected "
3867-
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868-
)
3866+
if self.instance_groups is not None:
3867+
for instance in self.instance_groups:
3868+
if "p5.48xlarge" in instance._to_request_dict()["InstanceType"]:
3869+
p5_enabled = True
3870+
38693871
img_uri = "" if self.image_uri is None else self.image_uri
38703872
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
38713873
if (

0 commit comments

Comments
 (0)