Skip to content

Commit 67e8813

Browse files
rohithn1bhupendrasingh
authored andcommitted
fix: Allowing instance_type to be ParameterString object.
1 parent 4b86475 commit 67e8813

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from sagemaker.workflow import is_pipeline_variable
103103
from sagemaker.workflow.entities import PipelineVariable
104+
from sagemaker.workflow.parameters import ParameterString
104105
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
105106

106107
logger = logging.getLogger(__name__)
@@ -3856,7 +3857,15 @@ def _distribution_configuration(self, distribution):
38563857
)
38573858
smdistributed = distribution["smdistributed"]
38583859
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
3859-
p5_enabled = "p5.48xlarge" in self.instance_type
3860+
if isinstance(self.instance_type, ParameterString):
3861+
p5_enabled = "p5.48xlarge" in self.instance_type.default_value
3862+
elif isinstance(self.instance_type, str):
3863+
p5_enabled = "p5.48xlarge" in self.instance_type
3864+
else:
3865+
raise ValueError(
3866+
"Invalid object type for instance_type argument. Expected "
3867+
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868+
)
38603869
img_uri = "" if self.image_uri is None else self.image_uri
38613870
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
38623871
if (

0 commit comments

Comments
 (0)