Skip to content

Commit 0acb88c

Browse files
author
Dewen Qi
committed
change: Add PipelineVariable annotation to composite argument of training
1 parent ee23b12 commit 0acb88c

File tree

6 files changed

+58
-36
lines changed

6 files changed

+58
-36
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ def fit(
243243
* `TrialComponentDisplayName` is used for display in Studio.
244244
"""
245245
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
246-
247246
self.latest_training_job = _TrainingJob.start_new(
248247
self, records, experiment_config=experiment_config
249248
)
@@ -304,7 +303,12 @@ class RecordSet(object):
304303
"""Placeholder docstring"""
305304

306305
def __init__(
307-
self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train"
306+
self,
307+
s3_data: Union[str, PipelineVariable],
308+
num_records: int,
309+
feature_dim: int,
310+
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
311+
channel: Union[str, PipelineVariable] = "train",
308312
):
309313
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.
310314

src/sagemaker/debugger/debugger.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424

2525
from abc import ABC
2626

27+
from sagemaker.workflow.entities import PipelineVariable
28+
from typing import Union, Optional, List, Dict
29+
2730
import attr
2831

2932
import smdebug_rulesconfig as rule_configs
@@ -311,17 +314,17 @@ def sagemaker(
311314
@classmethod
312315
def custom(
313316
cls,
314-
name,
315-
image_uri,
316-
instance_type,
317-
volume_size_in_gb,
318-
source=None,
319-
rule_to_invoke=None,
320-
container_local_output_path=None,
321-
s3_output_path=None,
322-
other_trials_s3_input_paths=None,
323-
rule_parameters=None,
324-
collections_to_save=None,
317+
name: str,
318+
image_uri: Union[str, PipelineVariable],
319+
instance_type: Union[str, PipelineVariable],
320+
volume_size_in_gb: Union[int, PipelineVariable],
321+
source: Optional[str] = None,
322+
rule_to_invoke: Optional[Union[str, PipelineVariable]] = None,
323+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
324+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
325+
other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None,
326+
rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
327+
collections_to_save: Optional[List["CollectionConfig"]] = None,
325328
actions=None,
326329
):
327330
"""Initialize a ``Rule`` object for a *custom* debugging rule.
@@ -610,10 +613,10 @@ class DebuggerHookConfig(object):
610613

611614
def __init__(
612615
self,
613-
s3_output_path=None,
614-
container_local_output_path=None,
615-
hook_parameters=None,
616-
collection_configs=None,
616+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
617+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
618+
hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
619+
collection_configs: Optional[List["CollectionConfig"]] = None,
617620
):
618621
"""Initialize the DebuggerHookConfig instance.
619622
@@ -679,7 +682,11 @@ def _to_request_dict(self):
679682
class TensorBoardOutputConfig(object):
680683
"""Create a tensor ouput configuration object for debugging visualizations on TensorBoard."""
681684

682-
def __init__(self, s3_output_path, container_local_output_path=None):
685+
def __init__(
686+
self,
687+
s3_output_path: Union[str, PipelineVariable],
688+
container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
689+
):
683690
"""Initialize the TensorBoardOutputConfig instance.
684691
685692
Args:
@@ -708,7 +715,11 @@ def _to_request_dict(self):
708715
class CollectionConfig(object):
709716
"""Creates tensor collections for SageMaker Debugger."""
710717

711-
def __init__(self, name, parameters=None):
718+
def __init__(
719+
self,
720+
name: Union[str, PipelineVariable],
721+
parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
722+
):
712723
"""Constructor for collection configuration.
713724
714725
Args:

src/sagemaker/debugger/profiler_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
"""Configuration for collecting system and framework metrics in SageMaker training jobs."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
1618
from sagemaker.debugger.framework_profile import FrameworkProfile
19+
from sagemaker.workflow.entities import PipelineVariable
1720

1821

1922
class ProfilerConfig(object):
@@ -26,9 +29,9 @@ class ProfilerConfig(object):
2629

2730
def __init__(
2831
self,
29-
s3_output_path=None,
30-
system_monitor_interval_millis=None,
31-
framework_profile_params=None,
32+
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
33+
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
34+
framework_profile_params: Optional[FrameworkProfile] = None,
3235
):
3336
"""Initialize a ``ProfilerConfig`` instance.
3437

src/sagemaker/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2747,8 +2747,9 @@ def _validate_and_set_debugger_configs(self):
27472747
if not is_pipeline_variable(self.instance_count) and self.instance_count > 1:
27482748
disable_debugger_hook_config = True
27492749
if (
2750-
hasattr(self, "distribution") and self.distribution is not None
2751-
): # pylint: disable=no-member
2750+
hasattr(self, "distribution")
2751+
and self.distribution is not None # pylint: disable=no-member
2752+
):
27522753
disable_debugger_hook_config = True
27532754
if disable_debugger_hook_config:
27542755
logger.info(

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
"""Configuration for the SageMaker Training Compiler."""
1414
from __future__ import absolute_import
1515
import logging
16+
from typing import Union
1617

1718
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
19+
from sagemaker.workflow.entities import PipelineVariable
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig):
2628

2729
def __init__(
2830
self,
29-
enabled=True,
30-
debug=False,
31+
enabled: Union[bool, PipelineVariable] = True,
32+
debug: Union[bool, PipelineVariable] = False,
3133
):
3234
"""This class initializes a ``TrainingCompilerConfig`` instance.
3335

src/sagemaker/inputs.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
1414
from __future__ import absolute_import, print_function
1515

16+
from typing import Union, Optional, List
1617
import attr
1718

1819
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
@@ -29,16 +30,16 @@ class TrainingInput(object):
2930

3031
def __init__(
3132
self,
32-
s3_data,
33-
distribution=None,
34-
compression=None,
35-
content_type=None,
36-
record_wrapping=None,
37-
s3_data_type="S3Prefix",
38-
input_mode=None,
39-
attribute_names=None,
40-
target_attribute_name=None,
41-
shuffle_config=None,
33+
s3_data: Union[str, PipelineVariable],
34+
distribution: Optional[Union[str, PipelineVariable]] = None,
35+
compression: Optional[Union[str, PipelineVariable]] = None,
36+
content_type: Optional[Union[str, PipelineVariable]] = None,
37+
record_wrapping: Optional[Union[str, PipelineVariable]] = None,
38+
s3_data_type: Union[str, PipelineVariable] = "S3Prefix",
39+
input_mode: Optional[Union[str, PipelineVariable]] = None,
40+
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
41+
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
42+
shuffle_config: Optional["ShuffleConfig"] = None,
4243
):
4344
"""Create a definition for input data used by an SageMaker training job.
4445

0 commit comments

Comments
 (0)