Skip to content

Commit 78e7b3b

Browse files
feat: jumpstart estimator enable infra check flag (#4154)
Co-authored-by: martinRenou <[email protected]>
1 parent 858a965 commit 78e7b3b

File tree

6 files changed

+13
-2
lines changed

6 files changed

+13
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3082,6 +3082,7 @@ def __init__(
30823082
hyperparameters=hyperparameters,
30833083
instance_groups=instance_groups,
30843084
training_repository_access_mode=training_repository_access_mode,
3085+
enable_infra_check=enable_infra_check,
30853086
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
30863087
container_entry_point=container_entry_point,
30873088
container_arguments=container_arguments,

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
instance_groups: Optional[List[InstanceGroup]] = None,
103103
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
104104
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
105+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
105106
container_entry_point: Optional[List[str]] = None,
106107
container_arguments: Optional[List[str]] = None,
107108
disable_output_compression: Optional[bool] = None,
@@ -485,6 +486,8 @@ def __init__(
485486
private Docker registry where your training image is hosted (Default: None).
486487
When it's set to None, SageMaker will not do authentication before pulling the image
487488
in the private Docker registry. (Default: None).
489+
enable_infra_check (Optional[Union[bool, PipelineVariable]]):
490+
Specifies whether it is running Sagemaker built-in infra check jobs.
488491
container_entry_point (Optional[List[str]]): The entrypoint script for a Docker
489492
container used to run a training job. This script takes precedence over
490493
the default train processing instructions.
@@ -565,6 +568,7 @@ def _is_valid_model_id_hook():
565568
container_entry_point=container_entry_point,
566569
container_arguments=container_arguments,
567570
disable_output_compression=disable_output_compression,
571+
enable_infra_check=enable_infra_check,
568572
)
569573

570574
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def get_init_kwargs(
123123
container_entry_point: Optional[List[str]] = None,
124124
container_arguments: Optional[List[str]] = None,
125125
disable_output_compression: Optional[bool] = None,
126+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
126127
) -> JumpStartEstimatorInitKwargs:
127128
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
128129

@@ -178,6 +179,7 @@ def get_init_kwargs(
178179
container_entry_point=container_entry_point,
179180
container_arguments=container_arguments,
180181
disable_output_compression=disable_output_compression,
182+
enable_infra_check=enable_infra_check,
181183
)
182184

183185
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.utils import get_instance_type_family
1919

2020
from sagemaker.session import Session
21+
from sagemaker.workflow.entities import PipelineVariable
2122

2223

2324
class JumpStartDataHolderType:
@@ -1011,6 +1012,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
10111012
"container_entry_point",
10121013
"container_arguments",
10131014
"disable_output_compression",
1015+
"enable_infra_check",
10141016
]
10151017

10161018
SERIALIZATION_EXCLUSION_SET = {
@@ -1074,6 +1076,7 @@ def __init__(
10741076
container_entry_point: Optional[List[str]] = None,
10751077
container_arguments: Optional[List[str]] = None,
10761078
disable_output_compression: Optional[bool] = None,
1079+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
10771080
) -> None:
10781081
"""Instantiates JumpStartEstimatorInitKwargs object."""
10791082

@@ -1130,6 +1133,7 @@ def __init__(
11301133
self.container_entry_point = container_entry_point
11311134
self.container_arguments = container_arguments
11321135
self.disable_output_compression = disable_output_compression
1136+
self.enable_infra_check = enable_infra_check
11331137

11341138

11351139
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
677677
Please add the new argument to the skip set below,
678678
and reach out to JumpStart team."""
679679

680-
init_args_to_skip: Set[str] = set(["kwargs", "enable_infra_check"])
680+
init_args_to_skip: Set[str] = set(["kwargs"])
681681
fit_args_to_skip: Set[str] = set()
682682
deploy_args_to_skip: Set[str] = set(["kwargs"])
683683

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
331331

332332
"""If you add arguments to <Model constructor>, this test will fail.
333333
Please add the new argument to the skip set below,
334-
and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
334+
and reach out to JumpStart team."""
335335

336336
init_args_to_skip: Set[str] = set()
337337
deploy_args_to_skip: Set[str] = set(["kwargs"])

0 commit comments

Comments
 (0)