31
31
from sagemaker .pytorch import defaults
32
32
from sagemaker .pytorch .model import PyTorchModel
33
33
from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
34
+ from sagemaker .workflow import is_pipeline_variable
34
35
from sagemaker .workflow .entities import PipelineVariable
35
36
36
37
logger = logging .getLogger ("sagemaker" )
@@ -51,7 +52,7 @@ def __init__(
51
52
source_dir : Optional [Union [str , PipelineVariable ]] = None ,
52
53
hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
53
54
image_uri : Optional [Union [str , PipelineVariable ]] = None ,
54
- distribution : Dict = None ,
55
+ distribution : Optional [ Dict ] = None ,
55
56
** kwargs
56
57
):
57
58
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -224,7 +225,7 @@ def __init__(
224
225
if distribution is not None :
225
226
instance_type = self ._get_instance_type ()
226
227
# remove "ml." prefix
227
- if instance_type [:3 ] == "ml." :
228
+ if not is_pipeline_variable ( instance_type ) and instance_type [:3 ] == "ml." :
228
229
instance_type = instance_type [3 :]
229
230
validate_distribution_instance (self .sagemaker_session , distribution , instance_type )
230
231
0 commit comments