Skip to content

Commit ee23b12

Browse files
author
Dewen Qi
committed
change: Add Pipeline annotation in model base class and tensorflow estimator
Model annotate update
1 parent 341f345 commit ee23b12

File tree

9 files changed

+128
-96
lines changed

9 files changed

+128
-96
lines changed

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/estimator.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
109109
def __init__(
110110
self,
111111
role: str,
112-
instance_count: Optional[int] = None,
112+
instance_count: Optional[Union[int, PipelineVariable]] = None,
113113
instance_type: Optional[Union[str, PipelineVariable]] = None,
114114
volume_size: Union[int, PipelineVariable] = 30,
115115
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
@@ -2743,15 +2743,24 @@ def _validate_and_set_debugger_configs(self):
27432743
# Disable debugger if checkpointing is enabled by the customer
27442744
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
27452745
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
2746-
if self.instance_count > 1 or (
2747-
hasattr(self, "distribution")
2748-
and self.distribution is not None # pylint: disable=no-member
2749-
):
2746+
disable_debugger_hook_config = False
2747+
if not is_pipeline_variable(self.instance_count) and self.instance_count > 1:
2748+
disable_debugger_hook_config = True
2749+
if (
2750+
hasattr(self, "distribution") and self.distribution is not None
2751+
): # pylint: disable=no-member
2752+
disable_debugger_hook_config = True
2753+
if disable_debugger_hook_config:
27502754
logger.info(
27512755
"SMDebug Does Not Currently Support \
27522756
Distributed Training Jobs With Checkpointing Enabled"
27532757
)
27542758
self.debugger_hook_config = False
2759+
elif is_pipeline_variable(self.instance_count):
2760+
logger.warning(
2761+
"SMDebug does not currently support when parameterized "
2762+
"instance_count with value > 1 in pipeline execution."
2763+
)
27552764

27562765
if self.debugger_hook_config is False:
27572766
if self.environment is None:

src/sagemaker/metadata_properties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
"""This file contains code related to metadata properties."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class MetadataProperties(object):
1822
"""Accepts metadata properties parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
commit_id=None,
23-
repository=None,
24-
generated_by=None,
25-
project_id=None,
26+
commit_id: Optional[Union[str, PipelineVariable]] = None,
27+
repository: Optional[Union[str, PipelineVariable]] = None,
28+
generated_by: Optional[Union[str, PipelineVariable]] = None,
29+
project_id: Optional[Union[str, PipelineVariable]] = None,
2630
):
2731
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
2832

src/sagemaker/model.py

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import re
2121
import copy
22-
from typing import List, Dict
22+
from typing import List, Dict, Optional, Union
2323

2424
import sagemaker
2525
from sagemaker import (
@@ -30,15 +30,20 @@
3030
utils,
3131
git_utils,
3232
)
33+
from sagemaker.session import Session
34+
from sagemaker.model_metrics import ModelMetrics
3335
from sagemaker.deprecations import removed_kwargs
36+
from sagemaker.drift_check_baselines import DriftCheckBaselines
37+
from sagemaker.metadata_properties import MetadataProperties
3438
from sagemaker.predictor import PredictorBase
3539
from sagemaker.serverless import ServerlessInferenceConfig
3640
from sagemaker.transformer import Transformer
3741
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
38-
from sagemaker.utils import unique_name_from_base
42+
from sagemaker.utils import unique_name_from_base, to_string
3943
from sagemaker.async_inference import AsyncInferenceConfig
4044
from sagemaker.predictor_async import AsyncPredictor
4145
from sagemaker.workflow import is_pipeline_variable
46+
from sagemaker.workflow.entities import PipelineVariable
4247
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4348

4449
LOGGER = logging.getLogger("sagemaker")
@@ -78,23 +83,23 @@ class Model(ModelBase):
7883

7984
def __init__(
8085
self,
81-
image_uri,
82-
model_data=None,
83-
role=None,
84-
predictor_cls=None,
85-
env=None,
86-
name=None,
87-
vpc_config=None,
88-
sagemaker_session=None,
89-
enable_network_isolation=False,
90-
model_kms_key=None,
91-
image_config=None,
92-
source_dir=None,
93-
code_location=None,
94-
entry_point=None,
95-
container_log_level=logging.INFO,
96-
dependencies=None,
97-
git_config=None,
86+
image_uri: Union[str, PipelineVariable],
87+
model_data: Optional[Union[str, PipelineVariable]] = None,
88+
role: Optional[str] = None,
89+
predictor_cls: Optional[callable] = None,
90+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
91+
name: Optional[str] = None,
92+
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
93+
sagemaker_session: Optional[Session] = None,
94+
enable_network_isolation: Union[bool, PipelineVariable] = False,
95+
model_kms_key: Optional[str] = None,
96+
image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
97+
source_dir: Optional[str] = None,
98+
code_location: Optional[str] = None,
99+
entry_point: Optional[str] = None,
100+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
101+
dependencies: Optional[List[str]] = None,
102+
git_config: Optional[Dict[str, str]] = None,
98103
):
99104
"""Initialize an SageMaker ``Model``.
100105
@@ -294,22 +299,22 @@ def __init__(
294299
@runnable_by_pipeline
295300
def register(
296301
self,
297-
content_types,
298-
response_types,
299-
inference_instances=None,
300-
transform_instances=None,
301-
model_package_name=None,
302-
model_package_group_name=None,
303-
image_uri=None,
304-
model_metrics=None,
305-
metadata_properties=None,
306-
marketplace_cert=False,
307-
approval_status=None,
308-
description=None,
309-
drift_check_baselines=None,
310-
customer_metadata_properties=None,
311-
validation_specification=None,
312-
domain=None,
302+
content_types: List[Union[str, PipelineVariable]],
303+
response_types: List[Union[str, PipelineVariable]],
304+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
305+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
306+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
307+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
308+
image_uri: Optional[Union[str, PipelineVariable]] = None,
309+
model_metrics: Optional[ModelMetrics] = None,
310+
metadata_properties: Optional[MetadataProperties] = None,
311+
marketplace_cert: bool = False,
312+
approval_status: Optional[Union[str, PipelineVariable]] = None,
313+
description: Optional[str] = None,
314+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
315+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
316+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
317+
domain: Optional[Union[str, PipelineVariable]] = None,
313318
):
314319
"""Creates a model package for creating SageMaker models or listing on Marketplace.
315320
@@ -385,10 +390,10 @@ def register(
385390
@runnable_by_pipeline
386391
def create(
387392
self,
388-
instance_type: str = None,
389-
accelerator_type: str = None,
390-
serverless_inference_config: ServerlessInferenceConfig = None,
391-
tags: List[Dict[str, str]] = None,
393+
instance_type: Optional[str] = None,
394+
accelerator_type: Optional[str] = None,
395+
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
396+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
392397
):
393398
"""Create a SageMaker Model Entity
394399
@@ -570,7 +575,7 @@ def _script_mode_env_vars(self):
570575
return {
571576
SCRIPT_PARAM_NAME.upper(): script_name or str(),
572577
DIR_PARAM_NAME.upper(): dir_name or str(),
573-
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
578+
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
574579
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
575580
}
576581

@@ -1239,19 +1244,19 @@ class FrameworkModel(Model):
12391244

12401245
def __init__(
12411246
self,
1242-
model_data,
1243-
image_uri,
1244-
role,
1245-
entry_point,
1246-
source_dir=None,
1247-
predictor_cls=None,
1248-
env=None,
1249-
name=None,
1250-
container_log_level=logging.INFO,
1251-
code_location=None,
1252-
sagemaker_session=None,
1253-
dependencies=None,
1254-
git_config=None,
1247+
model_data: Union[str, PipelineVariable],
1248+
image_uri: Union[str, PipelineVariable],
1249+
role: str,
1250+
entry_point: str,
1251+
source_dir: Optional[str] = None,
1252+
predictor_cls: Optional[callable] = None,
1253+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
1254+
name: Optional[str] = None,
1255+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
1256+
code_location: Optional[str] = None,
1257+
sagemaker_session: Optional[Session] = None,
1258+
dependencies: Optional[List[str]] = None,
1259+
git_config: Optional[Dict[str, str]] = None,
12551260
**kwargs,
12561261
):
12571262
"""Initialize a ``FrameworkModel``.

src/sagemaker/model_metrics.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,24 @@
1313
"""This file contains code related to model metrics, including metric source and file source."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class ModelMetrics(object):
1822
"""Accepts model metrics parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias=None,
27-
explainability=None,
28-
bias_pre_training=None,
29-
bias_post_training=None,
26+
model_statistics: Optional["MetricsSource"] = None,
27+
model_constraints: Optional["MetricsSource"] = None,
28+
model_data_statistics: Optional["MetricsSource"] = None,
29+
model_data_constraints: Optional["MetricsSource"] = None,
30+
bias: Optional["MetricsSource"] = None,
31+
explainability: Optional["MetricsSource"] = None,
32+
bias_pre_training: Optional["MetricsSource"] = None,
33+
bias_post_training: Optional["MetricsSource"] = None,
3034
):
3135
"""Initialize a ``ModelMetrics`` instance and turn parameters into dict.
3236
@@ -99,9 +103,9 @@ class MetricsSource(object):
99103

100104
def __init__(
101105
self,
102-
content_type,
103-
s3_uri,
104-
content_digest=None,
106+
content_type: Union[str, PipelineVariable],
107+
s3_uri: Union[str, PipelineVariable],
108+
content_digest: Optional[Union[str, PipelineVariable]] = None,
105109
):
106110
"""Initialize a ``MetricsSource`` instance and turn parameters into dict.
107111
@@ -127,9 +131,9 @@ class FileSource(object):
127131

128132
def __init__(
129133
self,
130-
s3_uri,
131-
content_digest=None,
132-
content_type=None,
134+
s3_uri: Union[str, PipelineVariable],
135+
content_digest: Optional[Union[str, PipelineVariable]] = None,
136+
content_type: Optional[Union[str, PipelineVariable]] = None,
133137
):
134138
"""Initialize a ``FileSource`` instance and turn parameters into dict.
135139

src/sagemaker/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def deploy(
229229
return None
230230

231231
@runnable_by_pipeline
232-
def create(self, instance_type: Union[str, PipelineVariable]):
232+
def create(self, instance_type: str):
233233
"""Create a SageMaker Model Entity
234234
235235
Args:

src/sagemaker/serverless/serverless_inference_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object):
2727

2828
def __init__(
2929
self,
30-
memory_size_in_mb=2048,
31-
max_concurrency=5,
30+
memory_size_in_mb: int = 2048,
31+
max_concurrency: int = 5,
3232
):
3333
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.
3434

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2633,7 +2633,9 @@ def _create_model_request(
26332633
request["VpcConfig"] = vpc_config
26342634

26352635
if enable_network_isolation:
2636-
request["EnableNetworkIsolation"] = True
2636+
# enable_network_isolation may be a pipeline variable which is
2637+
# parsed in execution time
2638+
request["EnableNetworkIsolation"] = enable_network_isolation
26372639

26382640
return request
26392641

0 commit comments

Comments
 (0)