Skip to content

Commit ddf5bae

Browse files
committed
feat: jumpstart estimator enable infra check flag
1 parent 41feb4c commit ddf5bae

File tree

5 files changed

+12
-1
lines changed

5 files changed

+12
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3082,6 +3082,7 @@ def __init__(
30823082
hyperparameters=hyperparameters,
30833083
instance_groups=instance_groups,
30843084
training_repository_access_mode=training_repository_access_mode,
3085+
enable_infra_check=enable_infra_check,
30853086
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
30863087
container_entry_point=container_entry_point,
30873088
container_arguments=container_arguments,

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
instance_groups: Optional[List[InstanceGroup]] = None,
103103
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
104104
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
105+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
105106
container_entry_point: Optional[List[str]] = None,
106107
container_arguments: Optional[List[str]] = None,
107108
disable_output_compression: Optional[bool] = None,
@@ -485,6 +486,8 @@ def __init__(
485486
private Docker registry where your training image is hosted (Default: None).
486487
When it's set to None, SageMaker will not do authentication before pulling the image
487488
in the private Docker registry. (Default: None).
489+
enable_infra_check (Optional[Union[bool, PipelineVariable]]):
490+
Specifies whether it is running Sagemaker built-in infra check jobs.
488491
container_entry_point (Optional[List[str]]): The entrypoint script for a Docker
489492
container used to run a training job. This script takes precedence over
490493
the default train processing instructions.
@@ -565,6 +568,7 @@ def _is_valid_model_id_hook():
565568
container_entry_point=container_entry_point,
566569
container_arguments=container_arguments,
567570
disable_output_compression=disable_output_compression,
571+
enable_infra_check=enable_infra_check,
568572
)
569573

570574
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def get_init_kwargs(
123123
container_entry_point: Optional[List[str]] = None,
124124
container_arguments: Optional[List[str]] = None,
125125
disable_output_compression: Optional[bool] = None,
126+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
126127
) -> JumpStartEstimatorInitKwargs:
127128
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
128129

@@ -178,6 +179,7 @@ def get_init_kwargs(
178179
container_entry_point=container_entry_point,
179180
container_arguments=container_arguments,
180181
disable_output_compression=disable_output_compression,
182+
enable_infra_check=enable_infra_check,
181183
)
182184

183185
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.utils import get_instance_type_family
1919

2020
from sagemaker.session import Session
21+
from sagemaker.workflow.entities import PipelineVariable
2122

2223

2324
class JumpStartDataHolderType:
@@ -946,6 +947,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
946947
"container_entry_point",
947948
"container_arguments",
948949
"disable_output_compression",
950+
"enable_infra_check",
949951
]
950952

951953
SERIALIZATION_EXCLUSION_SET = {
@@ -1009,6 +1011,7 @@ def __init__(
10091011
container_entry_point: Optional[List[str]] = None,
10101012
container_arguments: Optional[List[str]] = None,
10111013
disable_output_compression: Optional[bool] = None,
1014+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
10121015
) -> None:
10131016
"""Instantiates JumpStartEstimatorInitKwargs object."""
10141017

@@ -1065,6 +1068,7 @@ def __init__(
10651068
self.container_entry_point = container_entry_point
10661069
self.container_arguments = container_arguments
10671070
self.disable_output_compression = disable_output_compression
1071+
self.enable_infra_check = enable_infra_check
10681072

10691073

10701074
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
677677
Please add the new argument to the skip set below,
678678
and reach out to JumpStart team."""
679679

680-
init_args_to_skip: Set[str] = set(["kwargs", "enable_infra_check"])
680+
init_args_to_skip: Set[str] = set(["kwargs"])
681681
fit_args_to_skip: Set[str] = set()
682682
deploy_args_to_skip: Set[str] = set(["kwargs"])
683683

0 commit comments

Comments
 (0)