Skip to content

Commit e34ad16

Browse files
author
Dewen Qi
committed
change: Add PipelineVariable annotation to framework estimators
1 parent cb5c991 commit e34ad16

File tree

9 files changed

+136
-76
lines changed

9 files changed

+136
-76
lines changed

src/sagemaker/algorithm.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,22 @@
1313
"""Test docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union, Dict, List
17+
1618
import sagemaker
1719
import sagemaker.parameter
1820
from sagemaker import vpc_utils
1921
from sagemaker.deserializers import BytesDeserializer
2022
from sagemaker.deprecations import removed_kwargs
2123
from sagemaker.estimator import EstimatorBase
24+
from sagemaker.inputs import TrainingInput, FileSystemInput
2225
from sagemaker.serializers import IdentitySerializer
2326
from sagemaker.transformer import Transformer
2427
from sagemaker.predictor import Predictor
28+
from sagemaker.session import Session
29+
from sagemaker.workflow.entities import PipelineVariable
30+
31+
from sagemaker.workflow import is_pipeline_variable
2532

2633

2734
class AlgorithmEstimator(EstimatorBase):
@@ -37,28 +44,28 @@ class AlgorithmEstimator(EstimatorBase):
3744

3845
def __init__(
3946
self,
40-
algorithm_arn,
41-
role,
42-
instance_count,
43-
instance_type,
44-
volume_size=30,
45-
volume_kms_key=None,
46-
max_run=24 * 60 * 60,
47-
input_mode="File",
48-
output_path=None,
49-
output_kms_key=None,
50-
base_job_name=None,
51-
sagemaker_session=None,
52-
hyperparameters=None,
53-
tags=None,
54-
subnets=None,
55-
security_group_ids=None,
56-
model_uri=None,
57-
model_channel_name="model",
58-
metric_definitions=None,
59-
encrypt_inter_container_traffic=False,
60-
use_spot_instances=False,
61-
max_wait=None,
47+
algorithm_arn: str,
48+
role: str,
49+
instance_count: Optional[Union[int, PipelineVariable]] = None,
50+
instance_type: Optional[Union[str, PipelineVariable]] = None,
51+
volume_size: Union[int, PipelineVariable] = 30,
52+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
53+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
54+
input_mode: Union[str, PipelineVariable] = "File",
55+
output_path: Optional[Union[str, PipelineVariable]] = None,
56+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
57+
base_job_name: Optional[str] = None,
58+
sagemaker_session: Optional[Session] = None,
59+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
60+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
61+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
62+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
63+
model_uri: Optional[str] = None,
64+
model_channel_name: Union[str, PipelineVariable] = "model",
65+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
66+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
67+
use_spot_instances: Union[bool, PipelineVariable] = False,
68+
max_wait: Union[int, PipelineVariable] = None,
6269
**kwargs # pylint: disable=W0613
6370
):
6471
"""Initialize an ``AlgorithmEstimator`` instance.
@@ -186,22 +193,25 @@ def validate_train_spec(self):
186193
# Check that the input mode provided is compatible with the training input modes for the
187194
# algorithm.
188195
input_modes = self._algorithm_training_input_modes(train_spec["TrainingChannels"])
189-
if self.input_mode not in input_modes:
196+
if not is_pipeline_variable(self.input_mode) and self.input_mode not in input_modes:
190197
raise ValueError(
191198
"Invalid input mode: %s. %s only supports: %s"
192199
% (self.input_mode, algorithm_name, input_modes)
193200
)
194201

195202
# Check that the training instance type is compatible with the algorithm.
196203
supported_instances = train_spec["SupportedTrainingInstanceTypes"]
197-
if self.instance_type not in supported_instances:
204+
if (
205+
not is_pipeline_variable(self.instance_type)
206+
and self.instance_type not in supported_instances
207+
):
198208
raise ValueError(
199209
"Invalid instance_type: %s. %s supports the following instance types: %s"
200210
% (self.instance_type, algorithm_name, supported_instances)
201211
)
202212

203213
# Verify if distributed training is supported by the algorithm
204-
if (
214+
if not is_pipeline_variable(self.instance_count) and (
205215
self.instance_count > 1
206216
and "SupportsDistributedTraining" in train_spec
207217
and not train_spec["SupportsDistributedTraining"]
@@ -414,12 +424,18 @@ def _prepare_for_training(self, job_name=None):
414424

415425
super(AlgorithmEstimator, self)._prepare_for_training(job_name)
416426

417-
def fit(self, inputs=None, wait=True, logs=True, job_name=None):
427+
def fit(
428+
self,
429+
inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
430+
wait: bool = True,
431+
logs: bool = True,
432+
job_name: Optional[str] = None,
433+
):
418434
"""Placeholder docstring"""
419435
if inputs:
420436
self._validate_input_channels(inputs)
421437

422-
super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
438+
return super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
423439

424440
def _validate_input_channels(self, channels):
425441
"""Placeholder docstring"""

src/sagemaker/chainer/estimator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Union, Optional
17+
from typing import Union, Optional, Dict
1818

1919
from sagemaker.estimator import Framework, EstimatorBase
2020
from sagemaker.fw_utils import (
@@ -34,26 +34,26 @@
3434
class Chainer(Framework):
3535
"""Handle end-to-end training and deployment of custom Chainer code."""
3636

37-
_framework_name = "chainer"
37+
_framework_name: str = "chainer"
3838

3939
# Hyperparameters
40-
_use_mpi = "sagemaker_use_mpi"
41-
_num_processes = "sagemaker_num_processes"
42-
_process_slots_per_host = "sagemaker_process_slots_per_host"
43-
_additional_mpi_options = "sagemaker_additional_mpi_options"
40+
_use_mpi: str = "sagemaker_use_mpi"
41+
_num_processes: str = "sagemaker_num_processes"
42+
_process_slots_per_host: str = "sagemaker_process_slots_per_host"
43+
_additional_mpi_options: str = "sagemaker_additional_mpi_options"
4444

4545
def __init__(
4646
self,
4747
entry_point: Union[str, PipelineVariable],
48-
use_mpi=None,
49-
num_processes=None,
50-
process_slots_per_host=None,
51-
additional_mpi_options=None,
48+
use_mpi: Optional[Union[bool, PipelineVariable]] = None,
49+
num_processes: Optional[Union[int, PipelineVariable]] = None,
50+
process_slots_per_host: Optional[Union[int, PipelineVariable]] = None,
51+
additional_mpi_options: Optional[Union[str, PipelineVariable]] = None,
5252
source_dir: Optional[Union[str, PipelineVariable]] = None,
53-
hyperparameters=None,
54-
framework_version=None,
55-
py_version=None,
56-
image_uri=None,
53+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
54+
framework_version: Optional[str] = None,
55+
py_version: Optional[str] = None,
56+
image_uri: Optional[Union[str, PipelineVariable]] = None,
5757
**kwargs
5858
):
5959
"""This ``Estimator`` executes an Chainer script in a managed execution environment.

src/sagemaker/pytorch/estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Union, Optional
17+
from typing import Union, Optional, Dict
1818

1919
from packaging.version import Version
2020

@@ -44,12 +44,12 @@ class PyTorch(Framework):
4444
def __init__(
4545
self,
4646
entry_point: Union[str, PipelineVariable],
47-
framework_version=None,
48-
py_version=None,
47+
framework_version: Optional[str] = None,
48+
py_version: Optional[str] = None,
4949
source_dir: Optional[Union[str, PipelineVariable]] = None,
50-
hyperparameters=None,
51-
image_uri=None,
52-
distribution=None,
50+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
51+
image_uri: Optional[Union[str, PipelineVariable]] = None,
52+
distribution: Optional[Dict] = None,
5353
**kwargs
5454
):
5555
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.

src/sagemaker/rl/estimator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import enum
1717
import logging
1818
import re
19-
from typing import Union, Optional
19+
from typing import Union, Optional, List, Dict
2020

2121
from sagemaker import image_uris, fw_utils
2222
from sagemaker.estimator import Framework, EstimatorBase
@@ -77,13 +77,13 @@ class RLEstimator(Framework):
7777
def __init__(
7878
self,
7979
entry_point: Union[str, PipelineVariable],
80-
toolkit=None,
81-
toolkit_version=None,
82-
framework=None,
80+
toolkit: Optional[RLToolkit] = None,
81+
toolkit_version: Optional[str] = None,
82+
framework: Optional[Framework] = None,
8383
source_dir: Optional[Union[str, PipelineVariable]] = None,
84-
hyperparameters=None,
85-
image_uri=None,
86-
metric_definitions=None,
84+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
85+
image_uri: Optional[Union[str, PipelineVariable]] = None,
86+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
8787
**kwargs
8888
):
8989
"""Creates an RLEstimator for managed Reinforcement Learning (RL).

src/sagemaker/sklearn/estimator.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Union, Optional
17+
from typing import Union, Optional, Dict
1818

1919
from sagemaker import image_uris
2020
from sagemaker.deprecations import renamed_kwargs
@@ -28,6 +28,7 @@
2828
from sagemaker.sklearn.model import SKLearnModel
2929
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3030
from sagemaker.workflow.entities import PipelineVariable
31+
from sagemaker.workflow import is_pipeline_variable
3132

3233
logger = logging.getLogger("sagemaker")
3334

@@ -40,12 +41,12 @@ class SKLearn(Framework):
4041
def __init__(
4142
self,
4243
entry_point: Union[str, PipelineVariable],
43-
framework_version=None,
44-
py_version="py3",
44+
framework_version: Optional[str] = None,
45+
py_version: str = "py3",
4546
source_dir: Optional[Union[str, PipelineVariable]] = None,
46-
hyperparameters=None,
47-
image_uri=None,
48-
image_uri_region=None,
47+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
48+
image_uri: Optional[Union[str, PipelineVariable]] = None,
49+
image_uri_region: Optional[str] = None,
4950
**kwargs
5051
):
5152
"""Creates a SKLearn Estimator for Scikit-learn environment.
@@ -128,11 +129,17 @@ def __init__(
128129
self.framework_version = framework_version
129130
self.py_version = py_version
130131

132+
if instance_type:
133+
if is_pipeline_variable(instance_type):
134+
raise ValueError("instance_count argument cannot be a pipeline variable")
131135
# SciKit-Learn does not support distributed training or training on GPU instance types.
132136
# Fail fast.
133137
_validate_not_gpu_instance_type(instance_type)
134138

135139
if instance_count:
140+
if is_pipeline_variable(instance_count):
141+
raise ValueError("instance_count argument cannot be a pipeline variable")
142+
136143
if instance_count != 1:
137144
raise AttributeError(
138145
"Scikit-Learn does not support distributed training. Please remove the "
@@ -148,6 +155,13 @@ def __init__(
148155
)
149156

150157
if image_uri is None:
158+
159+
if is_pipeline_variable(instance_type):
160+
raise ValueError(
161+
"instance_type argument cannot be a pipeline variable "
162+
"when image_uri is not given."
163+
)
164+
151165
self.image_uri = image_uris.retrieve(
152166
SKLearn._framework_name,
153167
image_uri_region or self.sagemaker_session.boto_region_name,

src/sagemaker/training_compiler/config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from __future__ import absolute_import
1515
import logging
1616

17+
from sagemaker.workflow import is_pipeline_variable
18+
1719
logger = logging.getLogger(__name__)
1820

1921

@@ -132,7 +134,12 @@ def validate(
132134
ValueError: Raised if the requested configuration is not compatible
133135
with SageMaker Training Compiler.
134136
"""
135-
if estimator.instance_type:
137+
if is_pipeline_variable(estimator.instance_type):
138+
logger.warning(
139+
"Estimator instance_type is a PipelineVariable: %s, "
140+
"which has to be one of the [p3, g4dn, p4d, g5] in execution time."
141+
)
142+
elif estimator.instance_type:
136143
if "local" not in estimator.instance_type:
137144
requested_instance_class = estimator.instance_type.split(".")[
138145
1

src/sagemaker/workflow/pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,10 @@ def _generate_step_map(
486486
"""Helper method to create a mapping from Step/Step Collection name to itself."""
487487
for step in steps:
488488
if step.name in step_map:
489-
raise ValueError("Pipeline steps cannot have duplicate names.")
489+
raise ValueError(
490+
"Pipeline steps cannot have duplicate names. In addition, steps added to "
491+
"the ConditionStep cannot be added to the Pipeline steps list."
492+
)
490493
step_map[step.name] = step
491494
if isinstance(step, ConditionStep):
492495
_generate_step_map(step.if_steps + step.else_steps, step_map)

src/sagemaker/xgboost/estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Union, Optional
17+
from typing import Union, Optional, Dict
1818

1919
from sagemaker import image_uris
2020
from sagemaker.deprecations import renamed_kwargs
@@ -45,12 +45,12 @@ class XGBoost(Framework):
4545
def __init__(
4646
self,
4747
entry_point: Union[str, PipelineVariable],
48-
framework_version,
48+
framework_version: str,
4949
source_dir: Optional[Union[str, PipelineVariable]] = None,
50-
hyperparameters=None,
51-
py_version="py3",
52-
image_uri=None,
53-
image_uri_region=None,
50+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
51+
py_version: str = "py3",
52+
image_uri: Optional[Union[str, PipelineVariable]] = None,
53+
image_uri_region: Optional[str] = None,
5454
**kwargs
5555
):
5656
"""An estimator that executes an XGBoost-based SageMaker Training Job.

0 commit comments

Comments
 (0)