Skip to content

Commit f9f02d3

Browse files
author
Dewen Qi
committed
change: Implement test mechanism for Pipeline variables
1 parent 67de784 commit f9f02d3

File tree

13 files changed

+2024
-11
lines changed

13 files changed

+2024
-11
lines changed

src/sagemaker/debugger/debugger.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
import time
2424

2525
from abc import ABC
26+
from typing import Union, Optional
2627

2728
import attr
2829

2930
import smdebug_rulesconfig as rule_configs
3031

3132
from sagemaker import image_uris
3233
from sagemaker.utils import build_dict
34+
from sagemaker.workflow.entities import PipelineVariable
3335

3436
framework_name = "debugger"
3537
DEBUGGER_FLAG = "USE_SMDEBUG"
@@ -311,10 +313,10 @@ def sagemaker(
311313
@classmethod
312314
def custom(
313315
cls,
314-
name,
315-
image_uri,
316-
instance_type,
317-
volume_size_in_gb,
316+
name: str,
317+
image_uri: Union[str, PipelineVariable],
318+
instance_type: Union[str, PipelineVariable],
319+
volume_size_in_gb: Union[int, PipelineVariable],
318320
source=None,
319321
rule_to_invoke=None,
320322
container_local_output_path=None,
@@ -610,7 +612,7 @@ class DebuggerHookConfig(object):
610612

611613
def __init__(
612614
self,
613-
s3_output_path=None,
615+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
614616
container_local_output_path=None,
615617
hook_parameters=None,
616618
collection_configs=None,
@@ -679,7 +681,9 @@ def _to_request_dict(self):
679681
class TensorBoardOutputConfig(object):
680682
"""Create a tensor ouput configuration object for debugging visualizations on TensorBoard."""
681683

682-
def __init__(self, s3_output_path, container_local_output_path=None):
684+
def __init__(
685+
self, s3_output_path: Union[str, PipelineVariable], container_local_output_path=None
686+
):
683687
"""Initialize the TensorBoardOutputConfig instance.
684688
685689
Args:

src/sagemaker/debugger/profiler_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker.debugger.framework_profile import FrameworkProfile
19+
from sagemaker.workflow.entities import PipelineVariable
1720

1821

1922
class ProfilerConfig(object):
@@ -27,7 +30,7 @@ class ProfilerConfig(object):
2730
def __init__(
2831
self,
2932
s3_output_path=None,
30-
system_monitor_interval_millis=None,
33+
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
3134
framework_profile_params=None,
3235
):
3336
"""Initialize a ``ProfilerConfig`` instance.

src/sagemaker/inputs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
1414
from __future__ import absolute_import, print_function
1515

16+
from typing import Union, Optional
17+
1618
import attr
1719

20+
from sagemaker.workflow.entities import PipelineVariable
21+
1822
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
1923
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]
2024

@@ -29,10 +33,10 @@ class TrainingInput(object):
2933

3034
def __init__(
3135
self,
32-
s3_data,
36+
s3_data: Union[str, PipelineVariable],
3337
distribution=None,
3438
compression=None,
35-
content_type=None,
39+
content_type: Optional[Union[str, PipelineVariable]] = None,
3640
record_wrapping=None,
3741
s3_data_type="S3Prefix",
3842
input_mode=None,

src/sagemaker/serverless/serverless_inference_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
"""
1818
from __future__ import print_function, absolute_import
1919

20+
from typing import Union
21+
22+
from sagemaker.workflow.entities import PipelineVariable
23+
2024

2125
class ServerlessInferenceConfig(object):
2226
"""Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

src/sagemaker/tensorflow/training_compiler/config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
"""Configuration for the SageMaker Training Compiler."""
1414
from __future__ import absolute_import
1515
import logging
16+
from typing import Union
17+
1618
from packaging.specifiers import SpecifierSet
1719
from packaging.version import Version
1820

1921
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
22+
from sagemaker.workflow.entities import PipelineVariable
2023

2124
logger = logging.getLogger(__name__)
2225

@@ -29,8 +32,8 @@ class TrainingCompilerConfig(BaseConfig):
2932

3033
def __init__(
3134
self,
32-
enabled=True,
33-
debug=False,
35+
enabled: Union[bool, PipelineVariable] = True,
36+
debug: Union[bool, PipelineVariable] = False,
3437
):
3538
"""This class initializes a ``TrainingCompilerConfig`` instance.
3639

0 commit comments

Comments
 (0)