Skip to content

Commit f5a392e

Browse files
Merge branch 'master' into neo-input-shape-derivation-comment
2 parents 714caec + 78e7b3b commit f5a392e

File tree

9 files changed

+51
-6
lines changed

9 files changed

+51
-6
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3082,6 +3082,7 @@ def __init__(
30823082
hyperparameters=hyperparameters,
30833083
instance_groups=instance_groups,
30843084
training_repository_access_mode=training_repository_access_mode,
3085+
enable_infra_check=enable_infra_check,
30853086
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
30863087
container_entry_point=container_entry_point,
30873088
container_arguments=container_arguments,

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
instance_groups: Optional[List[InstanceGroup]] = None,
103103
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
104104
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
105+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
105106
container_entry_point: Optional[List[str]] = None,
106107
container_arguments: Optional[List[str]] = None,
107108
disable_output_compression: Optional[bool] = None,
@@ -485,6 +486,8 @@ def __init__(
485486
private Docker registry where your training image is hosted (Default: None).
486487
When it's set to None, SageMaker will not do authentication before pulling the image
487488
in the private Docker registry. (Default: None).
489+
enable_infra_check (Optional[Union[bool, PipelineVariable]]):
490+
Specifies whether it is running Sagemaker built-in infra check jobs.
488491
container_entry_point (Optional[List[str]]): The entrypoint script for a Docker
489492
container used to run a training job. This script takes precedence over
490493
the default train processing instructions.
@@ -565,6 +568,7 @@ def _is_valid_model_id_hook():
565568
container_entry_point=container_entry_point,
566569
container_arguments=container_arguments,
567570
disable_output_compression=disable_output_compression,
571+
enable_infra_check=enable_infra_check,
568572
)
569573

570574
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def get_init_kwargs(
123123
container_entry_point: Optional[List[str]] = None,
124124
container_arguments: Optional[List[str]] = None,
125125
disable_output_compression: Optional[bool] = None,
126+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
126127
) -> JumpStartEstimatorInitKwargs:
127128
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
128129

@@ -178,6 +179,7 @@ def get_init_kwargs(
178179
container_entry_point=container_entry_point,
179180
container_arguments=container_arguments,
180181
disable_output_compression=disable_output_compression,
182+
enable_infra_check=enable_infra_check,
181183
)
182184

183185
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.utils import get_instance_type_family
1919

2020
from sagemaker.session import Session
21+
from sagemaker.workflow.entities import PipelineVariable
2122

2223

2324
class JumpStartDataHolderType:
@@ -1011,6 +1012,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
10111012
"container_entry_point",
10121013
"container_arguments",
10131014
"disable_output_compression",
1015+
"enable_infra_check",
10141016
]
10151017

10161018
SERIALIZATION_EXCLUSION_SET = {
@@ -1074,6 +1076,7 @@ def __init__(
10741076
container_entry_point: Optional[List[str]] = None,
10751077
container_arguments: Optional[List[str]] = None,
10761078
disable_output_compression: Optional[bool] = None,
1079+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
10771080
) -> None:
10781081
"""Instantiates JumpStartEstimatorInitKwargs object."""
10791082

@@ -1130,6 +1133,7 @@ def __init__(
11301133
self.container_entry_point = container_entry_point
11311134
self.container_arguments = container_arguments
11321135
self.disable_output_compression = disable_output_compression
1136+
self.enable_infra_check = enable_infra_check
11331137

11341138

11351139
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

src/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17-
from typing import Any, Dict, List, Optional
17+
from typing import Any, Dict, List, Optional, Union
1818
from urllib.parse import urlparse
1919
from packaging.version import Version
2020
import sagemaker
@@ -277,7 +277,7 @@ def get_jumpstart_base_name_if_jumpstart_model(
277277

278278
def add_jumpstart_tags(
279279
tags: Optional[List[Dict[str, str]]] = None,
280-
inference_model_uri: Optional[str] = None,
280+
inference_model_uri: Optional[Union[str, dict]] = None,
281281
inference_script_uri: Optional[str] = None,
282282
training_model_uri: Optional[str] = None,
283283
training_script_uri: Optional[str] = None,
@@ -289,7 +289,7 @@ def add_jumpstart_tags(
289289
Args:
290290
tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
291291
or training job. (Default: None).
292-
inference_model_uri (Optional[str]): S3 URI for inference model artifact.
292+
inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact.
293293
(Default: None).
294294
inference_script_uri (Optional[str]): S3 URI for inference script tarball.
295295
(Default: None).
@@ -302,6 +302,10 @@ def add_jumpstart_tags(
302302
"The URI (%s) is a pipeline variable which is only interpreted at execution time. "
303303
"As a result, the JumpStart resources will not be tagged."
304304
)
305+
306+
if isinstance(inference_model_uri, dict):
307+
inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None)
308+
305309
if inference_model_uri:
306310
if is_pipeline_variable(inference_model_uri):
307311
logging.warning(warn_msg, "inference_model_uri")

src/sagemaker/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1347,7 +1347,9 @@ def deploy(
13471347

13481348
tags = add_jumpstart_tags(
13491349
tags=tags,
1350-
inference_model_uri=self.model_data if isinstance(self.model_data, str) else None,
1350+
inference_model_uri=self.model_data
1351+
if isinstance(self.model_data, (str, dict))
1352+
else None,
13511353
inference_script_uri=self.source_dir,
13521354
)
13531355

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
677677
Please add the new argument to the skip set below,
678678
and reach out to JumpStart team."""
679679

680-
init_args_to_skip: Set[str] = set(["kwargs", "enable_infra_check"])
680+
init_args_to_skip: Set[str] = set(["kwargs"])
681681
fit_args_to_skip: Set[str] = set()
682682
deploy_args_to_skip: Set[str] = set(["kwargs"])
683683

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
331331

332332
"""If you add arguments to <Model constructor>, this test will fail.
333333
Please add the new argument to the skip set below,
334-
and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
334+
and reach out to JumpStart team."""
335335

336336
init_args_to_skip: Set[str] = set()
337337
deploy_args_to_skip: Set[str] = set(["kwargs"])

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,34 @@ def test_add_jumpstart_tags_inference():
216216
inference_script_uri=inference_script_uri,
217217
) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}]
218218

219+
tags = []
220+
inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key")}}
221+
inference_script_uri = "dfsdfs"
222+
assert utils.add_jumpstart_tags(
223+
tags=tags,
224+
inference_model_uri=inference_model_uri,
225+
inference_script_uri=inference_script_uri,
226+
) == [
227+
{
228+
"Key": JumpStartTag.INFERENCE_MODEL_URI.value,
229+
"Value": inference_model_uri["S3DataSource"]["S3Uri"],
230+
}
231+
]
232+
233+
tags = []
234+
inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key/prefix/")}}
235+
inference_script_uri = "dfsdfs"
236+
assert utils.add_jumpstart_tags(
237+
tags=tags,
238+
inference_model_uri=inference_model_uri,
239+
inference_script_uri=inference_script_uri,
240+
) == [
241+
{
242+
"Key": JumpStartTag.INFERENCE_MODEL_URI.value,
243+
"Value": inference_model_uri["S3DataSource"]["S3Uri"],
244+
}
245+
]
246+
219247
tags = [{"Key": "some", "Value": "tag"}]
220248
inference_model_uri = random_jumpstart_s3_uri("random_key")
221249
inference_script_uri = "dfsdfs"

0 commit comments

Comments
 (0)