Skip to content

Commit 3aa6ebf

Browse files
author
Dewen Qi
committed
change: Implement test mechanism for Pipeline variables
change: test mechanism improvement tmp
1 parent 325214f commit 3aa6ebf

25 files changed

+2009
-369
lines changed

src/sagemaker/debugger/debugger.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
import time
2424

2525
from abc import ABC
26+
from typing import Union, Optional
2627

2728
import attr
2829

2930
import smdebug_rulesconfig as rule_configs
3031

3132
from sagemaker import image_uris
3233
from sagemaker.utils import build_dict
34+
from sagemaker.workflow.entities import PipelineVariable
3335

3436
framework_name = "debugger"
3537
DEBUGGER_FLAG = "USE_SMDEBUG"
@@ -311,10 +313,10 @@ def sagemaker(
311313
@classmethod
312314
def custom(
313315
cls,
314-
name,
315-
image_uri,
316-
instance_type,
317-
volume_size_in_gb,
316+
name: str,
317+
image_uri: Union[str, PipelineVariable],
318+
instance_type: Union[str, PipelineVariable],
319+
volume_size_in_gb: Union[int, PipelineVariable],
318320
source=None,
319321
rule_to_invoke=None,
320322
container_local_output_path=None,
@@ -610,7 +612,7 @@ class DebuggerHookConfig(object):
610612

611613
def __init__(
612614
self,
613-
s3_output_path=None,
615+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
614616
container_local_output_path=None,
615617
hook_parameters=None,
616618
collection_configs=None,
@@ -679,7 +681,9 @@ def _to_request_dict(self):
679681
class TensorBoardOutputConfig(object):
680682
"""Create a tensor ouput configuration object for debugging visualizations on TensorBoard."""
681683

682-
def __init__(self, s3_output_path, container_local_output_path=None):
684+
def __init__(
685+
self, s3_output_path: Union[str, PipelineVariable], container_local_output_path=None
686+
):
683687
"""Initialize the TensorBoardOutputConfig instance.
684688
685689
Args:

src/sagemaker/debugger/profiler_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker.debugger.framework_profile import FrameworkProfile
19+
from sagemaker.workflow.entities import PipelineVariable
1720

1821

1922
class ProfilerConfig(object):
@@ -27,7 +30,7 @@ class ProfilerConfig(object):
2730
def __init__(
2831
self,
2932
s3_output_path=None,
30-
system_monitor_interval_millis=None,
33+
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
3134
framework_profile_params=None,
3235
):
3336
"""Initialize a ``ProfilerConfig`` instance.

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/estimator.py

Lines changed: 109 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import uuid
2020
from abc import ABCMeta, abstractmethod
21-
from typing import Any, Dict
21+
from typing import Any, Dict, Union, Optional, List
2222

2323
from six import string_types, with_metaclass
2424
from six.moves.urllib.parse import urlparse
@@ -36,6 +36,7 @@
3636
TensorBoardOutputConfig,
3737
get_default_profiler_rule,
3838
get_rule_container_image_uri,
39+
RuleBase,
3940
)
4041
from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs
4142
from sagemaker.fw_utils import (
@@ -46,7 +47,7 @@
4647
tar_and_upload_dir,
4748
validate_source_dir,
4849
)
49-
from sagemaker.inputs import TrainingInput
50+
from sagemaker.inputs import TrainingInput, FileSystemInput
5051
from sagemaker.job import _Job
5152
from sagemaker.jumpstart.utils import (
5253
add_jumpstart_tags,
@@ -75,6 +76,7 @@
7576
name_from_base,
7677
)
7778
from sagemaker.workflow import is_pipeline_variable
79+
from sagemaker.workflow.entities import PipelineVariable
7880
from sagemaker.workflow.pipeline_context import (
7981
PipelineSession,
8082
runnable_by_pipeline,
@@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
105107

106108
def __init__(
107109
self,
108-
role,
109-
instance_count=None,
110-
instance_type=None,
111-
volume_size=30,
112-
volume_kms_key=None,
113-
max_run=24 * 60 * 60,
114-
input_mode="File",
115-
output_path=None,
116-
output_kms_key=None,
117-
base_job_name=None,
118-
sagemaker_session=None,
119-
tags=None,
120-
subnets=None,
121-
security_group_ids=None,
122-
model_uri=None,
123-
model_channel_name="model",
124-
metric_definitions=None,
125-
encrypt_inter_container_traffic=False,
126-
use_spot_instances=False,
127-
max_wait=None,
128-
checkpoint_s3_uri=None,
129-
checkpoint_local_path=None,
130-
rules=None,
131-
debugger_hook_config=None,
132-
tensorboard_output_config=None,
133-
enable_sagemaker_metrics=None,
134-
enable_network_isolation=False,
135-
profiler_config=None,
136-
disable_profiler=False,
137-
environment=None,
138-
max_retry_attempts=None,
139-
source_dir=None,
140-
git_config=None,
141-
hyperparameters=None,
142-
container_log_level=logging.INFO,
143-
code_location=None,
144-
entry_point=None,
145-
dependencies=None,
110+
role: str,
111+
instance_count: Optional[int] = None,
112+
instance_type: Optional[Union[str, PipelineVariable]] = None,
113+
volume_size: Union[int, PipelineVariable] = 30,
114+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
115+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
116+
input_mode: Union[str, PipelineVariable] = "File",
117+
output_path: Optional[Union[str, PipelineVariable]] = None,
118+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
119+
base_job_name: Optional[str] = None,
120+
sagemaker_session: Optional[Session] = None,
121+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
122+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
123+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
124+
model_uri: Optional[str] = None,
125+
model_channel_name: Union[str, PipelineVariable] = "model",
126+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
127+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
128+
use_spot_instances: Union[bool, PipelineVariable] = False,
129+
max_wait: Union[int, PipelineVariable] = None,
130+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
131+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
132+
rules: Optional[List[RuleBase]] = None,
133+
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
134+
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
135+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
136+
enable_network_isolation: Union[bool, PipelineVariable] = False,
137+
profiler_config: Optional[ProfilerConfig] = None,
138+
disable_profiler: bool = False,
139+
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
140+
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
141+
source_dir: Optional[str] = None,
142+
git_config: Optional[Dict[str, str]] = None,
143+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
144+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
145+
code_location: Optional[str] = None,
146+
entry_point: Optional[str] = None,
147+
dependencies: Optional[List[Union[str]]] = None,
146148
**kwargs,
147149
):
148150
"""Initialize an ``EstimatorBase`` instance.
@@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self):
922924
return None
923925

924926
@runnable_by_pipeline
925-
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
927+
def fit(
928+
self,
929+
inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
930+
wait: bool = True,
931+
logs: str = "All",
932+
job_name: Optional[str] = None,
933+
experiment_config: Optional[Dict[str, str]] = None,
934+
):
926935
"""Train a model using the input training dataset.
927936
928937
The API calls the Amazon SageMaker CreateTrainingJob API to start
@@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18701879
)
18711880
train_args["input_mode"] = inputs.config["InputMode"]
18721881

1882+
# enable_network_isolation may be a pipeline variable place holder object
1883+
# which is parsed in execution time
18731884
if estimator.enable_network_isolation():
1874-
train_args["enable_network_isolation"] = True
1885+
train_args["enable_network_isolation"] = estimator.enable_network_isolation()
18751886

18761887
if estimator.max_retry_attempts is not None:
18771888
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
18781889
else:
18791890
train_args["retry_strategy"] = None
18801891

1892+
# encrypt_inter_container_traffic may be a pipeline variable place holder object
1893+
# which is parsed in execution time
18811894
if estimator.encrypt_inter_container_traffic:
1882-
train_args["encrypt_inter_container_traffic"] = True
1895+
train_args[
1896+
"encrypt_inter_container_traffic"
1897+
] = estimator.encrypt_inter_container_traffic
18831898

18841899
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
18851900
train_args["algorithm_arn"] = estimator.algorithm_arn
@@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase):
20252040

20262041
def __init__(
20272042
self,
2028-
image_uri,
2029-
role,
2030-
instance_count=None,
2031-
instance_type=None,
2032-
volume_size=30,
2033-
volume_kms_key=None,
2034-
max_run=24 * 60 * 60,
2035-
input_mode="File",
2036-
output_path=None,
2037-
output_kms_key=None,
2038-
base_job_name=None,
2039-
sagemaker_session=None,
2040-
hyperparameters=None,
2041-
tags=None,
2042-
subnets=None,
2043-
security_group_ids=None,
2044-
model_uri=None,
2045-
model_channel_name="model",
2046-
metric_definitions=None,
2047-
encrypt_inter_container_traffic=False,
2048-
use_spot_instances=False,
2049-
max_wait=None,
2050-
checkpoint_s3_uri=None,
2051-
checkpoint_local_path=None,
2052-
enable_network_isolation=False,
2053-
rules=None,
2054-
debugger_hook_config=None,
2055-
tensorboard_output_config=None,
2056-
enable_sagemaker_metrics=None,
2057-
profiler_config=None,
2058-
disable_profiler=False,
2059-
environment=None,
2060-
max_retry_attempts=None,
2061-
source_dir=None,
2062-
git_config=None,
2063-
container_log_level=logging.INFO,
2064-
code_location=None,
2065-
entry_point=None,
2066-
dependencies=None,
2043+
image_uri: Union[str, PipelineVariable],
2044+
role: str,
2045+
instance_count: Optional[Union[int, PipelineVariable]] = None,
2046+
instance_type: Optional[Union[str, PipelineVariable]] = None,
2047+
volume_size: Union[int, PipelineVariable] = 30,
2048+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
2049+
max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
2050+
input_mode: Union[str, PipelineVariable] = "File",
2051+
output_path: Optional[Union[str, PipelineVariable]] = None,
2052+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
2053+
base_job_name: Optional[str] = None,
2054+
sagemaker_session: Optional[Session] = None,
2055+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2056+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2057+
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
2058+
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
2059+
model_uri: Optional[str] = None,
2060+
model_channel_name: Union[str, PipelineVariable] = "model",
2061+
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2062+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
2063+
use_spot_instances: Union[bool, PipelineVariable] = False,
2064+
max_wait: Union[int, PipelineVariable] = None,
2065+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
2066+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
2067+
enable_network_isolation: Union[bool, PipelineVariable] = False,
2068+
rules: Optional[List[RuleBase]] = None,
2069+
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None,
2070+
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
2071+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
2072+
profiler_config: Optional[ProfilerConfig] = None,
2073+
disable_profiler: bool = False,
2074+
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2075+
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
2076+
source_dir: Optional[str] = None,
2077+
git_config: Optional[Dict[str, str]] = None,
2078+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
2079+
code_location: Optional[str] = None,
2080+
entry_point: Optional[str] = None,
2081+
dependencies: Optional[List[str]] = None,
20672082
**kwargs,
20682083
):
20692084
"""Initialize an ``Estimator`` instance.
@@ -2488,18 +2503,18 @@ class Framework(EstimatorBase):
24882503

24892504
def __init__(
24902505
self,
2491-
entry_point,
2492-
source_dir=None,
2493-
hyperparameters=None,
2494-
container_log_level=logging.INFO,
2495-
code_location=None,
2496-
image_uri=None,
2497-
dependencies=None,
2498-
enable_network_isolation=False,
2499-
git_config=None,
2500-
checkpoint_s3_uri=None,
2501-
checkpoint_local_path=None,
2502-
enable_sagemaker_metrics=None,
2506+
entry_point: str,
2507+
source_dir: Optional[str] = None,
2508+
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2509+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
2510+
code_location: Optional[str] = None,
2511+
image_uri: Optional[Union[str, PipelineVariable]] = None,
2512+
dependencies: Optional[List[str]] = None,
2513+
enable_network_isolation: Union[bool, PipelineVariable] = False,
2514+
git_config: Optional[Dict[str, str]] = None,
2515+
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
2516+
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
2517+
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
25032518
**kwargs,
25042519
):
25052520
"""Base class initializer.

0 commit comments

Comments
 (0)