Skip to content

Commit 4dd85da

Browse files
author
Dewen Qi
committed
change: Implement test mechanism for Pipeline variables
annotations for processors + estimators / test mechanism change annotations for processors + estimators / test mechanism change remove debug print reformatting update TM and resolve all AIs and all untested subclasses Add ppl var annotation to all composite object for training
1 parent 839dc31 commit 4dd85da

23 files changed

+3437
-39
lines changed

src/sagemaker/amazon/kmeans.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828
from sagemaker.workflow.entities import PipelineVariable
2929

30-
from sagemaker.workflow.entities import PipelineVariable
31-
3230

3331
class KMeans(AmazonAlgorithmEstimatorBase):
3432
"""An unsupervised learning algorithm that attempts to find discrete groupings within data.
@@ -76,7 +74,7 @@ def __init__(
7674
half_life_time_size: Optional[int] = None,
7775
epochs: Optional[int] = None,
7876
center_factor: Optional[int] = None,
79-
eval_metrics: Optional[List[Union[str, PipelineVariable]]] = None,
77+
eval_metrics: Optional[List[str]] = None,
8078
**kwargs
8179
):
8280
"""A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`.

src/sagemaker/amazon/linear_learner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
from typing import Union, Optional
1718

1819
from sagemaker import image_uris
@@ -28,6 +29,8 @@
2829
from sagemaker.workflow.entities import PipelineVariable
2930
from sagemaker.workflow import is_pipeline_variable
3031

32+
logger = logging.getLogger(__name__)
33+
3134

3235
class LinearLearner(AmazonAlgorithmEstimatorBase):
3336
"""A supervised learning algorithms used for solving classification or regression problems.
@@ -437,9 +440,14 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
437440
# mini_batch_size can't be greater than number of records or training job fails
438441
if not mini_batch_size:
439442
if is_pipeline_variable(self.instance_count):
440-
raise ValueError(
441-
"instance_count can not be a pipeline variable when mini_batch_size is not given."
443+
logger.warning(
444+
"mini_batch_size is not given in .fit() and instance_count is a "
445+
"pipeline variable (%s) which is only parsed in execution time. "
446+
"Thus setting mini_batch_size to 1, as it can't be greater than "
447+
"number of records per instance_count, otherwise the training job fails.",
448+
type(self.instance_count),
442449
)
450+
mini_batch_size = 1
443451
else:
444452
mini_batch_size = min(
445453
self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count))

src/sagemaker/debugger/debugger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from abc import ABC
2626

27-
from sagemaker.workflow.entities import PipelineVariable
2827
from typing import Union, Optional, List, Dict
2928

3029
import attr
@@ -33,6 +32,7 @@
3332

3433
from sagemaker import image_uris
3534
from sagemaker.utils import build_dict
35+
from sagemaker.workflow.entities import PipelineVariable
3636

3737
framework_name = "debugger"
3838
DEBUGGER_FLAG = "USE_SMDEBUG"

src/sagemaker/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
get_jumpstart_base_name_if_jumpstart_model,
5555
update_inference_tags_with_jumpstart_training_tags,
5656
)
57-
from sagemaker.debugger import RuleBase
5857
from sagemaker.local import LocalSession
5958
from sagemaker.model import (
6059
CONTAINER_LOG_LEVEL_PARAM_NAME,
@@ -2503,7 +2502,7 @@ class Framework(EstimatorBase):
25032502
def __init__(
25042503
self,
25052504
entry_point: str,
2506-
source_dir: Optional[Union[str, PipelineVariable]]= None,
2505+
source_dir: Optional[Union[str, PipelineVariable]] = None,
25072506
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
25082507
container_log_level: Union[int, PipelineVariable] = logging.INFO,
25092508
code_location: Optional[str] = None,

src/sagemaker/huggingface/estimator.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sagemaker.huggingface.model import HuggingFaceModel
2929
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3030

31-
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig # TODO: add annotations
31+
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
3232
from sagemaker.workflow.entities import PipelineVariable
3333

3434
logger = logging.getLogger("sagemaker")
@@ -42,11 +42,11 @@ class HuggingFace(Framework):
4242
def __init__(
4343
self,
4444
py_version: str,
45-
entry_point: str,
46-
transformers_version: Optional[int] = None,
47-
tensorflow_version: Optional[int] = None,
48-
pytorch_version: Optional[int] = None,
49-
source_dir: Optional[str] = None,
45+
entry_point: Union[str, PipelineVariable],
46+
transformers_version: Optional[str] = None,
47+
tensorflow_version: Optional[str] = None,
48+
pytorch_version: Optional[str] = None,
49+
source_dir: Optional[Union[str, PipelineVariable]] = None,
5050
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
5151
image_uri: Optional[Union[str, PipelineVariable]] = None,
5252
distribution: Optional[Dict] = None,
@@ -203,14 +203,8 @@ def __init__(
203203
f"Instead got {type(compiler_config)}"
204204
)
205205
raise ValueError(error_string)
206-
207-
compiler_config.validate(
208-
image_uri=image_uri,
209-
instance_type=instance_type,
210-
distribution=distribution,
211-
)
212-
213-
self.distribution = distribution or {}
206+
if compiler_config:
207+
compiler_config.validate(self)
214208
self.compiler_config = compiler_config
215209

216210
def _validate_args(self, image_uri):

src/sagemaker/inputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
from __future__ import absolute_import, print_function
1515

1616
from typing import Union, Optional, List
17+
1718
import attr
1819

20+
from sagemaker.workflow.entities import PipelineVariable
21+
1922
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
2023
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]
2124

src/sagemaker/processing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
4141
from sagemaker.apiutils._base_types import ApiObject
4242
from sagemaker.s3 import S3Uploader
43-
from sagemaker.workflow.entities import PipelineVariable
4443

4544

4645
logger = logging.getLogger(__name__)

src/sagemaker/sklearn/estimator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,6 @@ def __init__(
156156
)
157157

158158
if image_uri is None:
159-
160-
if is_pipeline_variable(instance_type):
161-
raise ValueError(
162-
"instance_type argument cannot be a pipeline variable when image_uri is not given."
163-
)
164-
165159
self.image_uri = image_uris.retrieve(
166160
SKLearn._framework_name,
167161
image_uri_region or self.sagemaker_session.boto_region_name,

src/sagemaker/tensorflow/training_compiler/config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
"""Configuration for the SageMaker Training Compiler."""
1414
from __future__ import absolute_import
1515
import logging
16+
from typing import Union
17+
1618
from packaging.specifiers import SpecifierSet
1719
from packaging.version import Version
1820

1921
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
22+
from sagemaker.workflow.entities import PipelineVariable
2023

2124
logger = logging.getLogger(__name__)
2225

@@ -29,8 +32,8 @@ class TrainingCompilerConfig(BaseConfig):
2932

3033
def __init__(
3134
self,
32-
enabled=True,
33-
debug=False,
35+
enabled: Union[bool, PipelineVariable] = True,
36+
debug: Union[bool, PipelineVariable] = False,
3437
):
3538
"""This class initializes a ``TrainingCompilerConfig`` instance.
3639

src/sagemaker/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def unique_name_from_base(base, max_length=63):
9393

9494
def base_name_from_image(image, default_base_name=None):
9595
"""Extract the base name of the image to use as the 'algorithm name' for the job.
96+
9697
Args:
9798
image (str): Image name.
9899
default_base_name (str): The default base name

src/sagemaker/workflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def is_pipeline_variable(var: object) -> bool:
3434

3535
def is_pipeline_parameter_string(var: object) -> bool:
3636
"""Check if the variable is a pipeline parameter string
37+
3738
Args:
3839
var (object): The variable to be verified.
3940
Returns:

src/sagemaker/workflow/step_collections.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from __future__ import absolute_import
1515

1616
import warnings
17-
from sagemaker.deprecations import deprecated_class
1817
from typing import List, Union, Optional
1918

2019
import attr

src/sagemaker/xgboost/processing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@ class XGBoostProcessor(FrameworkProcessor):
3333

3434
def __init__(
3535
self,
36-
framework_version: str, # New arg
36+
framework_version: str,
3737
role: str,
3838
instance_count: Union[int, PipelineVariable],
3939
instance_type: Union[str, PipelineVariable],
40-
py_version: str = "py3", # New kwarg
40+
py_version: str = "py3",
4141
image_uri: Optional[Union[str, PipelineVariable]] = None,
4242
command: Optional[List[str]] = None,
4343
volume_size_in_gb: Union[int, PipelineVariable] = 30,
4444
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
4545
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
46-
code_location: Optional[str] = None, # New arg
46+
code_location: Optional[str] = None,
4747
max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
4848
base_job_name: Optional[str] = None,
4949
sagemaker_session: Optional[Session] = None,

0 commit comments

Comments
 (0)