File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change 101
101
)
102
102
from sagemaker .workflow import is_pipeline_variable
103
103
from sagemaker .workflow .entities import PipelineVariable
104
+ from sagemaker .workflow .parameters import ParameterString
104
105
from sagemaker .workflow .pipeline_context import PipelineSession , runnable_by_pipeline
105
106
106
107
logger = logging .getLogger (__name__ )
@@ -3856,7 +3857,15 @@ def _distribution_configuration(self, distribution):
3856
3857
)
3857
3858
smdistributed = distribution ["smdistributed" ]
3858
3859
smdataparallel_enabled = smdistributed .get ("dataparallel" , {}).get ("enabled" , False )
3859
- p5_enabled = "p5.48xlarge" in self .instance_type
3860
+ if isinstance (self .instance_type , ParameterString ):
3861
+ p5_enabled = "p5.48xlarge" in self .instance_type .default_value
3862
+ elif isinstance (self .instance_type , str ):
3863
+ p5_enabled = "p5.48xlarge" in self .instance_type
3864
+ else :
3865
+ raise ValueError (
3866
+ "Invalid object type for instance_type argument. Expected "
3867
+ f"{ type (str )} or { type (ParameterString )} but got { type (self .instance_type )} ."
3868
+ )
3860
3869
img_uri = "" if self .image_uri is None else self .image_uri
3861
3870
for unsupported_image in Framework .UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM :
3862
3871
if (
You can’t perform that action at this time.
0 commit comments