Skip to content

Commit bcb4e4c

Browse files
author
Dewen Qi
committed
go with estimator subclasses
1 parent dcdf66f commit bcb4e4c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sagemaker.pytorch import defaults
3232
from sagemaker.pytorch.model import PyTorchModel
3333
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
34+
from sagemaker.workflow import is_pipeline_variable
3435
from sagemaker.workflow.entities import PipelineVariable
3536

3637
logger = logging.getLogger("sagemaker")
@@ -51,7 +52,7 @@ def __init__(
5152
source_dir: Optional[Union[str, PipelineVariable]] = None,
5253
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
5354
image_uri: Optional[Union[str, PipelineVariable]] = None,
54-
distribution: Dict = None,
55+
distribution: Optional[Dict] = None,
5556
**kwargs
5657
):
5758
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -224,7 +225,7 @@ def __init__(
224225
if distribution is not None:
225226
instance_type = self._get_instance_type()
226227
# remove "ml." prefix
227-
if instance_type[:3] == "ml.":
228+
if not is_pipeline_variable(instance_type) and instance_type[:3] == "ml.":
228229
instance_type = instance_type[3:]
229230
validate_distribution_instance(self.sagemaker_session, distribution, instance_type)
230231

0 commit comments

Comments
 (0)