Skip to content

Commit 839dc31

Browse files
jerrypeng7773 and Dewen Qi
authored and committed
change: Add PipelineVariable annotation for the rest of processors and estimators
Annotations for processors and estimators; test mechanism change; remove debug prints; reformatting.
1 parent 0acb88c commit 839dc31

36 files changed

+920
-700
lines changed

src/sagemaker/algorithm.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,22 @@
1313
"""Test docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union, Dict, List
17+
1618
import sagemaker
1719
import sagemaker.parameter
1820
from sagemaker import vpc_utils
1921
from sagemaker.deserializers import BytesDeserializer
2022
from sagemaker.deprecations import removed_kwargs
2123
from sagemaker.estimator import EstimatorBase
24+
from sagemaker.inputs import TrainingInput, FileSystemInput
2225
from sagemaker.serializers import IdentitySerializer
2326
from sagemaker.transformer import Transformer
2427
from sagemaker.predictor import Predictor
28+
from sagemaker.session import Session
29+
from sagemaker.workflow.entities import PipelineVariable
30+
31+
from sagemaker.workflow import is_pipeline_variable
2532

2633

2734
class AlgorithmEstimator(EstimatorBase):
@@ -37,28 +44,28 @@ class AlgorithmEstimator(EstimatorBase):
3744

3845
def __init__(
3946
self,
40-
algorithm_arn,
41-
role,
42-
instance_count,
43-
instance_type,
44-
volume_size=30,
45-
volume_kms_key=None,
46-
max_run=24 * 60 * 60,
47-
input_mode="File",
48-
output_path=None,
49-
output_kms_key=None,
50-
base_job_name=None,
51-
sagemaker_session=None,
52-
hyperparameters=None,
53-
tags=None,
54-
subnets=None,
55-
security_group_ids=None,
56-
model_uri=None,
57-
model_channel_name="model",
58-
metric_definitions=None,
59-
encrypt_inter_container_traffic=False,
60-
use_spot_instances=False,
61-
max_wait=None,
47+
algorithm_arn: str,
48+
role: str,
49+
instance_count: Optional[Union[int, PipelineVariable]] = None,
50+
instance_type: Optional[Union[str, PipelineVariable]] = None,
51+
volume_size: Union[int, PipelineVariable] = 30,
52+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
53+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
54+
input_mode: Union[str, PipelineVariable] = "File",
55+
output_path: Optional[Union[str, PipelineVariable]] = None,
56+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
57+
base_job_name: Optional[str] = None,
58+
sagemaker_session: Optional[Session] = None,
59+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
60+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
61+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
62+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
63+
model_uri: Optional[str] = None,
64+
model_channel_name: Union[str, PipelineVariable] = "model",
65+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
66+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
67+
use_spot_instances: Union[bool, PipelineVariable] = False,
68+
max_wait: Union[int, PipelineVariable] = None,
6269
**kwargs # pylint: disable=W0613
6370
):
6471
"""Initialize an ``AlgorithmEstimator`` instance.
@@ -186,22 +193,25 @@ def validate_train_spec(self):
186193
# Check that the input mode provided is compatible with the training input modes for the
187194
# algorithm.
188195
input_modes = self._algorithm_training_input_modes(train_spec["TrainingChannels"])
189-
if self.input_mode not in input_modes:
196+
if not is_pipeline_variable(self.input_mode) and self.input_mode not in input_modes:
190197
raise ValueError(
191198
"Invalid input mode: %s. %s only supports: %s"
192199
% (self.input_mode, algorithm_name, input_modes)
193200
)
194201

195202
# Check that the training instance type is compatible with the algorithm.
196203
supported_instances = train_spec["SupportedTrainingInstanceTypes"]
197-
if self.instance_type not in supported_instances:
204+
if (
205+
not is_pipeline_variable(self.instance_type)
206+
and self.instance_type not in supported_instances
207+
):
198208
raise ValueError(
199209
"Invalid instance_type: %s. %s supports the following instance types: %s"
200210
% (self.instance_type, algorithm_name, supported_instances)
201211
)
202212

203213
# Verify if distributed training is supported by the algorithm
204-
if (
214+
if not is_pipeline_variable(self.instance_count) and (
205215
self.instance_count > 1
206216
and "SupportsDistributedTraining" in train_spec
207217
and not train_spec["SupportsDistributedTraining"]
@@ -414,12 +424,18 @@ def _prepare_for_training(self, job_name=None):
414424

415425
super(AlgorithmEstimator, self)._prepare_for_training(job_name)
416426

417-
def fit(self, inputs=None, wait=True, logs=True, job_name=None):
427+
def fit(
428+
self,
429+
inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
430+
wait: bool = True,
431+
logs: bool = True,
432+
job_name: Optional[str] = None,
433+
):
418434
"""Placeholder docstring"""
419435
if inputs:
420436
self._validate_input_channels(inputs)
421437

422-
super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
438+
return super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
423439

424440
def _validate_input_channels(self, channels):
425441
"""Placeholder docstring"""

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union, Dict
17+
1618
import json
1719
import logging
1820
import tempfile
@@ -28,6 +30,9 @@
2830
from sagemaker.inputs import FileSystemInput, TrainingInput
2931
from sagemaker.utils import sagemaker_timestamp
3032
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
33+
from sagemaker.workflow.entities import PipelineVariable
34+
from sagemaker.workflow.parameters import ParameterBoolean
35+
from sagemaker.workflow import is_pipeline_variable
3136

3237
logger = logging.getLogger(__name__)
3338

@@ -40,16 +45,16 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
4045

4146
feature_dim = hp("feature_dim", validation.gt(0), data_type=int)
4247
mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int)
43-
repo_name = None
44-
repo_version = None
48+
repo_name: Optional[str] = None
49+
repo_version: Optional[str] = None
4550

4651
def __init__(
4752
self,
48-
role,
49-
instance_count=None,
50-
instance_type=None,
51-
data_location=None,
52-
enable_network_isolation=False,
53+
role: str,
54+
instance_count: Optional[Union[int]] = None,
55+
instance_type: Optional[Union[str, PipelineVariable]] = None,
56+
data_location: Optional[str] = None,
57+
enable_network_isolation: Union[bool, ParameterBoolean] = False,
5358
**kwargs
5459
):
5560
"""Initialize an AmazonAlgorithmEstimatorBase.
@@ -113,6 +118,11 @@ def data_location(self):
113118
@data_location.setter
114119
def data_location(self, data_location):
115120
"""Placeholder docstring"""
121+
if is_pipeline_variable(data_location):
122+
raise ValueError(
123+
"data_location argument has to be an integer " + "rather than a pipeline variable"
124+
)
125+
116126
if not data_location.startswith("s3://"):
117127
raise ValueError(
118128
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
@@ -196,12 +206,12 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
196206
@runnable_by_pipeline
197207
def fit(
198208
self,
199-
records,
200-
mini_batch_size=None,
201-
wait=True,
202-
logs=True,
203-
job_name=None,
204-
experiment_config=None,
209+
records: "RecordSet",
210+
mini_batch_size: Optional[int] = None,
211+
wait: bool = True,
212+
logs: bool = True,
213+
job_name: Optional[str] = None,
214+
experiment_config: Optional[Dict[str, str]] = None,
205215
):
206216
"""Fit this Estimator on serialized Record objects, stored in S3.
207217

src/sagemaker/amazon/factorization_machines.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -37,83 +37,83 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase):
3737
sparse datasets economically.
3838
"""
3939

40-
repo_name = "factorization-machines"
41-
repo_version = 1
40+
repo_name: str = "factorization-machines"
41+
repo_version: int = 1
4242

43-
num_factors = hp("num_factors", gt(0), "An integer greater than zero", int)
44-
predictor_type = hp(
43+
num_factors: hp = hp("num_factors", gt(0), "An integer greater than zero", int)
44+
predictor_type: hp = hp(
4545
"predictor_type",
4646
isin("binary_classifier", "regressor"),
4747
'Value "binary_classifier" or "regressor"',
4848
str,
4949
)
50-
epochs = hp("epochs", gt(0), "An integer greater than 0", int)
51-
clip_gradient = hp("clip_gradient", (), "A float value", float)
52-
eps = hp("eps", (), "A float value", float)
53-
rescale_grad = hp("rescale_grad", (), "A float value", float)
54-
bias_lr = hp("bias_lr", ge(0), "A non-negative float", float)
55-
linear_lr = hp("linear_lr", ge(0), "A non-negative float", float)
56-
factors_lr = hp("factors_lr", ge(0), "A non-negative float", float)
57-
bias_wd = hp("bias_wd", ge(0), "A non-negative float", float)
58-
linear_wd = hp("linear_wd", ge(0), "A non-negative float", float)
59-
factors_wd = hp("factors_wd", ge(0), "A non-negative float", float)
60-
bias_init_method = hp(
50+
epochs: hp = hp("epochs", gt(0), "An integer greater than 0", int)
51+
clip_gradient: hp = hp("clip_gradient", (), "A float value", float)
52+
eps: hp = hp("eps", (), "A float value", float)
53+
rescale_grad: hp = hp("rescale_grad", (), "A float value", float)
54+
bias_lr: hp = hp("bias_lr", ge(0), "A non-negative float", float)
55+
linear_lr: hp = hp("linear_lr", ge(0), "A non-negative float", float)
56+
factors_lr: hp = hp("factors_lr", ge(0), "A non-negative float", float)
57+
bias_wd: hp = hp("bias_wd", ge(0), "A non-negative float", float)
58+
linear_wd: hp = hp("linear_wd", ge(0), "A non-negative float", float)
59+
factors_wd: hp = hp("factors_wd", ge(0), "A non-negative float", float)
60+
bias_init_method: hp = hp(
6161
"bias_init_method",
6262
isin("normal", "uniform", "constant"),
6363
'Value "normal", "uniform" or "constant"',
6464
str,
6565
)
66-
bias_init_scale = hp("bias_init_scale", ge(0), "A non-negative float", float)
67-
bias_init_sigma = hp("bias_init_sigma", ge(0), "A non-negative float", float)
68-
bias_init_value = hp("bias_init_value", (), "A float value", float)
69-
linear_init_method = hp(
66+
bias_init_scale: hp = hp("bias_init_scale", ge(0), "A non-negative float", float)
67+
bias_init_sigma: hp = hp("bias_init_sigma", ge(0), "A non-negative float", float)
68+
bias_init_value: hp = hp("bias_init_value", (), "A float value", float)
69+
linear_init_method: hp = hp(
7070
"linear_init_method",
7171
isin("normal", "uniform", "constant"),
7272
'Value "normal", "uniform" or "constant"',
7373
str,
7474
)
75-
linear_init_scale = hp("linear_init_scale", ge(0), "A non-negative float", float)
76-
linear_init_sigma = hp("linear_init_sigma", ge(0), "A non-negative float", float)
77-
linear_init_value = hp("linear_init_value", (), "A float value", float)
78-
factors_init_method = hp(
75+
linear_init_scale: hp = hp("linear_init_scale", ge(0), "A non-negative float", float)
76+
linear_init_sigma: hp = hp("linear_init_sigma", ge(0), "A non-negative float", float)
77+
linear_init_value: hp = hp("linear_init_value", (), "A float value", float)
78+
factors_init_method: hp = hp(
7979
"factors_init_method",
8080
isin("normal", "uniform", "constant"),
8181
'Value "normal", "uniform" or "constant"',
8282
str,
8383
)
84-
factors_init_scale = hp("factors_init_scale", ge(0), "A non-negative float", float)
85-
factors_init_sigma = hp("factors_init_sigma", ge(0), "A non-negative float", float)
86-
factors_init_value = hp("factors_init_value", (), "A float value", float)
84+
factors_init_scale: hp = hp("factors_init_scale", ge(0), "A non-negative float", float)
85+
factors_init_sigma: hp = hp("factors_init_sigma", ge(0), "A non-negative float", float)
86+
factors_init_value: hp = hp("factors_init_value", (), "A float value", float)
8787

8888
def __init__(
8989
self,
90-
role,
91-
instance_count=None,
92-
instance_type=None,
93-
num_factors=None,
94-
predictor_type=None,
95-
epochs=None,
96-
clip_gradient=None,
97-
eps=None,
98-
rescale_grad=None,
99-
bias_lr=None,
100-
linear_lr=None,
101-
factors_lr=None,
102-
bias_wd=None,
103-
linear_wd=None,
104-
factors_wd=None,
105-
bias_init_method=None,
106-
bias_init_scale=None,
107-
bias_init_sigma=None,
108-
bias_init_value=None,
109-
linear_init_method=None,
110-
linear_init_scale=None,
111-
linear_init_sigma=None,
112-
linear_init_value=None,
113-
factors_init_method=None,
114-
factors_init_scale=None,
115-
factors_init_sigma=None,
116-
factors_init_value=None,
90+
role: str,
91+
instance_count: Optional[Union[int, PipelineVariable]] = None,
92+
instance_type: Optional[Union[str, PipelineVariable]] = None,
93+
num_factors: Optional[int] = None,
94+
predictor_type: Optional[str] = None,
95+
epochs: Optional[int] = None,
96+
clip_gradient: Optional[float] = None,
97+
eps: Optional[float] = None,
98+
rescale_grad: Optional[float] = None,
99+
bias_lr: Optional[float] = None,
100+
linear_lr: Optional[float] = None,
101+
factors_lr: Optional[float] = None,
102+
bias_wd: Optional[float] = None,
103+
linear_wd: Optional[float] = None,
104+
factors_wd: Optional[float] = None,
105+
bias_init_method: Optional[str] = None,
106+
bias_init_scale: Optional[float] = None,
107+
bias_init_sigma: Optional[float] = None,
108+
bias_init_value: Optional[float] = None,
109+
linear_init_method: Optional[str] = None,
110+
linear_init_scale: Optional[float] = None,
111+
linear_init_sigma: Optional[float] = None,
112+
linear_init_value: Optional[float] = None,
113+
factors_init_method: Optional[str] = None,
114+
factors_init_scale: Optional[float] = None,
115+
factors_init_sigma: Optional[float] = None,
116+
factors_init_value: Optional[float] = None,
117117
**kwargs
118118
):
119119
"""Factorization Machines is :class:`Estimator` for general-purpose supervised learning.

src/sagemaker/amazon/hyperparameter.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import json
17+
from sagemaker.workflow import is_pipeline_variable
1718

1819

1920
class Hyperparameter(object):
@@ -98,8 +99,14 @@ def serialize_all(obj):
9899
"""
99100
if "_hyperparameters" not in dir(obj):
100101
return {}
101-
return {
102-
k: json.dumps(v) if isinstance(v, list) else str(v)
103-
for k, v in obj._hyperparameters.items()
104-
if v is not None
105-
}
102+
hps = {}
103+
for k, v in obj._hyperparameters.items():
104+
if v is not None:
105+
if isinstance(v, list):
106+
v = json.dumps(v)
107+
elif is_pipeline_variable(v):
108+
v = v.to_string()
109+
else:
110+
v = str(v)
111+
hps[k] = v
112+
return hps

0 commit comments

Comments (0)