Skip to content

Commit b376a67

Browse files
authored
Merge branch 'master' into feat/inference-instance-type-conditioned-on-training-instance-type
2 parents 727989a + 866a2d9 commit b376a67

File tree

15 files changed

+717
-103
lines changed

15 files changed

+717
-103
lines changed

src/sagemaker/estimator.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from sagemaker.utils import instance_supports_kms
7272
from sagemaker.job import _Job
7373
from sagemaker.jumpstart.utils import (
74-
add_jumpstart_tags,
74+
add_jumpstart_uri_tags,
7575
get_jumpstart_base_name_if_jumpstart_model,
7676
update_inference_tags_with_jumpstart_training_tags,
7777
)
@@ -577,9 +577,7 @@ def __init__(
577577
self.entry_point = entry_point
578578
self.dependencies = dependencies or []
579579
self.uploaded_code: Optional[UploadedCode] = None
580-
self.tags = add_jumpstart_tags(
581-
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
582-
)
580+
583581
if self.instance_type in ("local", "local_gpu"):
584582
if self.instance_type == "local_gpu" and self.instance_count > 1:
585583
raise RuntimeError("Distributed Training in Local GPU is not supported")
@@ -592,6 +590,15 @@ def __init__(
592590
else:
593591
self.sagemaker_session = sagemaker_session or Session()
594592

593+
self.tags = (
594+
add_jumpstart_uri_tags(
595+
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
596+
)
597+
if getattr(self.sagemaker_session, "settings", None) is not None
598+
and self.sagemaker_session.settings.include_jumpstart_tags
599+
else tags
600+
)
601+
595602
self.base_job_name = base_job_name
596603
self._current_job_name = None
597604
if (

src/sagemaker/jumpstart/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@
154154
if region.gated_content_bucket is not None
155155
}
156156

157+
JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET = JUMPSTART_BUCKET_NAME_SET.union(
158+
JUMPSTART_GATED_BUCKET_NAME_SET
159+
)
160+
157161
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
158162

159163
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

src/sagemaker/jumpstart/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class JumpStartTag(str, Enum):
7676
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
7777
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"
7878

79+
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
80+
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
81+
7982

8083
class SerializerType(str, Enum):
8184
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@
6060
JumpStartModelInitKwargs,
6161
)
6262
from sagemaker.jumpstart.utils import (
63+
add_jumpstart_model_id_version_tags,
6364
update_dict_if_key_not_present,
6465
resolve_estimator_sagemaker_config_field,
66+
verify_model_region_and_return_specs,
6567
)
6668

6769

@@ -196,6 +198,7 @@ def get_init_kwargs(
196198
estimator_init_kwargs = _add_estimator_extra_kwargs(estimator_init_kwargs)
197199
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
198200
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
201+
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
199202

200203
return estimator_init_kwargs
201204

@@ -441,6 +444,26 @@ def _add_instance_type_and_count_to_kwargs(
441444
return kwargs
442445

443446

447+
def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
    """Attach JumpStart model ID and model version tags to the estimator kwargs.

    Tags are added only when the session settings enable
    ``include_jumpstart_tags``. The version tag uses the ``version`` attribute of
    the specs returned by ``verify_model_region_and_return_specs`` rather than
    the raw ``kwargs.model_version``.

    Args:
        kwargs (JumpStartEstimatorInitKwargs): Estimator init kwargs. Reads
            ``model_id``, ``model_version``, ``region``, ``sagemaker_session``,
            ``tags`` and the ``tolerate_*`` flags.

    Returns:
        JumpStartEstimatorInitKwargs: The same ``kwargs`` object, with ``tags``
        updated when JumpStart tagging is enabled.
    """
    # Use the same defensive lookup as the ``getattr(..., "settings", None)``
    # guards elsewhere in the SDK: a session object may not expose ``settings``.
    settings = getattr(kwargs.sagemaker_session, "settings", None)
    if settings is None or not settings.include_jumpstart_tags:
        # The specs are only needed to build the version tag, so skip the
        # lookup entirely when no tags are going to be added.
        return kwargs

    full_model_version = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
    ).version

    kwargs.tags = add_jumpstart_model_id_version_tags(
        kwargs.tags, kwargs.model_id, full_model_version
    )
    return kwargs
465+
466+
444467
def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
445468
"""Sets image uri in kwargs based on default or override, returns full kwargs."""
446469

src/sagemaker/jumpstart/factory/model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
JumpStartModelRegisterKwargs,
4545
)
4646
from sagemaker.jumpstart.utils import (
47+
add_jumpstart_model_id_version_tags,
4748
update_dict_if_key_not_present,
4849
resolve_model_sagemaker_config_field,
4950
verify_model_region_and_return_specs,
@@ -423,6 +424,27 @@ def _add_model_name_to_kwargs(
423424
return kwargs
424425

425426

427+
def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> JumpStartModelDeployKwargs:
    """Attach JumpStart model ID and model version tags to the deploy kwargs.

    Tags are added only when the session settings enable
    ``include_jumpstart_tags``. The version tag uses the ``version`` attribute of
    the specs returned by ``verify_model_region_and_return_specs`` rather than
    the raw ``kwargs.model_version``.

    Args:
        kwargs (JumpStartModelDeployKwargs): Deploy kwargs. Reads ``model_id``,
            ``model_version``, ``region``, ``sagemaker_session``, ``tags`` and
            the ``tolerate_*`` flags.

    Returns:
        JumpStartModelDeployKwargs: The same ``kwargs`` object, with ``tags``
        updated when JumpStart tagging is enabled.
    """
    # Use the same defensive lookup as the ``getattr(..., "settings", None)``
    # guards elsewhere in the SDK: a session object may not expose ``settings``.
    settings = getattr(kwargs.sagemaker_session, "settings", None)
    if settings is None or not settings.include_jumpstart_tags:
        # The specs are only needed to build the version tag, so skip the
        # lookup entirely when no tags are going to be added.
        return kwargs

    full_model_version = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
    ).version

    kwargs.tags = add_jumpstart_model_id_version_tags(
        kwargs.tags, kwargs.model_id, full_model_version
    )
    return kwargs
446+
447+
426448
def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any]:
427449
"""Sets extra kwargs based on default or override, returns full kwargs."""
428450

@@ -510,6 +532,8 @@ def get_deploy_kwargs(
510532

511533
deploy_kwargs = _add_deploy_extra_kwargs(kwargs=deploy_kwargs)
512534

535+
deploy_kwargs = _add_tags_to_kwargs(kwargs=deploy_kwargs)
536+
513537
return deploy_kwargs
514538

515539

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ def deploy(
552552
container_startup_health_check_timeout=container_startup_health_check_timeout,
553553
inference_recommendation_id=inference_recommendation_id,
554554
explainer_config=explainer_config,
555+
sagemaker_session=self.sagemaker_session,
555556
)
556557

557558
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/utils.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
250250
if urlparse(uri).scheme == "s3":
251251
bucket, _ = parse_s3_url(uri)
252252

253-
return bucket in constants.JUMPSTART_BUCKET_NAME_SET
253+
return bucket in constants.JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET
254254

255255

256256
def tag_key_in_array(tag_key: str, tag_array: List[Dict[str, str]]) -> bool:
@@ -287,7 +287,10 @@ def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
287287

288288

289289
def add_single_jumpstart_tag(
290-
uri: str, tag_key: enums.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
290+
tag_value: str,
291+
tag_key: enums.JumpStartTag,
292+
curr_tags: Optional[List[Dict[str, str]]],
293+
is_uri=False,
291294
) -> Optional[List]:
292295
"""Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.
293296
@@ -296,17 +299,28 @@ def add_single_jumpstart_tag(
296299
tag_key (enums.JumpStartTag): Custom tag to apply to current tags if the URI
297300
corresponds to a JumpStart model.
298301
curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
302+
is_uri (boolean): Set to True to indicate a s3 uri is to be tagged. Set to False to indicate
303+
tags for JumpStart model id / version are being added. (Default: False).
299304
"""
300-
if is_jumpstart_model_uri(uri):
305+
if not is_uri or is_jumpstart_model_uri(tag_value):
301306
if curr_tags is None:
302307
curr_tags = []
303308
if not tag_key_in_array(tag_key, curr_tags):
304-
curr_tags.append(
305-
{
306-
"Key": tag_key,
307-
"Value": uri,
308-
}
309+
skip_adding_tag = (
310+
(
311+
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
312+
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
313+
)
314+
if is_uri
315+
else False
309316
)
317+
if not skip_adding_tag:
318+
curr_tags.append(
319+
{
320+
"Key": tag_key,
321+
"Value": tag_value,
322+
}
323+
)
310324
return curr_tags
311325

312326

@@ -326,14 +340,37 @@ def get_jumpstart_base_name_if_jumpstart_model(
326340
return None
327341

328342

329-
def add_jumpstart_tags(
343+
def add_jumpstart_model_id_version_tags(
    tags: Optional[List[Dict[str, str]]],
    model_id: Optional[str],
    model_version: Optional[str],
) -> Optional[List[Dict[str, str]]]:
    """Add custom model ID and version tags to JumpStart related resources.

    No-op when either ``model_id`` or ``model_version`` is ``None``; ``tags`` is
    then returned untouched, which means ``None`` is returned if ``tags`` was
    ``None``.

    Args:
        tags (Optional[List[Dict[str, str]]]): Current tags, or ``None``.
        model_id (Optional[str]): JumpStart model ID to record.
        model_version (Optional[str]): JumpStart model version to record.

    Returns:
        Optional[List[Dict[str, str]]]: The tags as returned by
        ``add_single_jumpstart_tag`` after the model ID and model version
        entries have been requested, or the original ``tags`` value when
        nothing was added.
    """
    if model_id is None or model_version is None:
        return tags

    # ID first, then version, matching the order in which the tags are applied.
    for tag_value, tag_key in (
        (model_id, enums.JumpStartTag.MODEL_ID),
        (model_version, enums.JumpStartTag.MODEL_VERSION),
    ):
        tags = add_single_jumpstart_tag(tag_value, tag_key, tags, is_uri=False)
    return tags
364+
365+
366+
def add_jumpstart_uri_tags(
330367
tags: Optional[List[Dict[str, str]]] = None,
331368
inference_model_uri: Optional[Union[str, dict]] = None,
332369
inference_script_uri: Optional[str] = None,
333370
training_model_uri: Optional[str] = None,
334371
training_script_uri: Optional[str] = None,
335372
) -> Optional[List[Dict[str, str]]]:
336-
"""Add custom tags to JumpStart models, return the updated tags.
373+
"""Add custom uri tags to JumpStart models, return the updated tags.
337374
338375
No-op if this is not a JumpStart model related resource.
339376
@@ -362,31 +399,43 @@ def add_jumpstart_tags(
362399
logging.warning(warn_msg, "inference_model_uri")
363400
else:
364401
tags = add_single_jumpstart_tag(
365-
inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags
402+
inference_model_uri,
403+
enums.JumpStartTag.INFERENCE_MODEL_URI,
404+
tags,
405+
is_uri=True,
366406
)
367407

368408
if inference_script_uri:
369409
if is_pipeline_variable(inference_script_uri):
370410
logging.warning(warn_msg, "inference_script_uri")
371411
else:
372412
tags = add_single_jumpstart_tag(
373-
inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags
413+
inference_script_uri,
414+
enums.JumpStartTag.INFERENCE_SCRIPT_URI,
415+
tags,
416+
is_uri=True,
374417
)
375418

376419
if training_model_uri:
377420
if is_pipeline_variable(training_model_uri):
378421
logging.warning(warn_msg, "training_model_uri")
379422
else:
380423
tags = add_single_jumpstart_tag(
381-
training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags
424+
training_model_uri,
425+
enums.JumpStartTag.TRAINING_MODEL_URI,
426+
tags,
427+
is_uri=True,
382428
)
383429

384430
if training_script_uri:
385431
if is_pipeline_variable(training_script_uri):
386432
logging.warning(warn_msg, "training_script_uri")
387433
else:
388434
tags = add_single_jumpstart_tag(
389-
training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags
435+
training_script_uri,
436+
enums.JumpStartTag.TRAINING_SCRIPT_URI,
437+
tags,
438+
is_uri=True,
390439
)
391440

392441
return tags

src/sagemaker/model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from sagemaker.serverless import ServerlessInferenceConfig
5555
from sagemaker.transformer import Transformer
5656
from sagemaker.jumpstart.utils import (
57-
add_jumpstart_tags,
57+
add_jumpstart_uri_tags,
5858
get_jumpstart_base_name_if_jumpstart_model,
5959
)
6060
from sagemaker.utils import (
@@ -1375,13 +1375,17 @@ def deploy(
13751375
sagemaker_session=self.sagemaker_session,
13761376
)
13771377

1378-
tags = add_jumpstart_tags(
1379-
tags=tags,
1380-
inference_model_uri=self.model_data
1381-
if isinstance(self.model_data, (str, dict))
1382-
else None,
1383-
inference_script_uri=self.source_dir,
1384-
)
1378+
if (
1379+
getattr(self.sagemaker_session, "settings", None) is not None
1380+
and self.sagemaker_session.settings.include_jumpstart_tags
1381+
):
1382+
tags = add_jumpstart_uri_tags(
1383+
tags=tags,
1384+
inference_model_uri=self.model_data
1385+
if isinstance(self.model_data, (str, dict))
1386+
else None,
1387+
inference_script_uri=self.source_dir,
1388+
)
13851389

13861390
if self.role is None:
13871391
raise ValueError("Role can not be null for deploying a model")

src/sagemaker/session_settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
self,
2323
encrypt_repacked_artifacts=True,
2424
local_download_dir=None,
25+
include_jumpstart_tags=True,
2526
) -> None:
2627
"""Initialize the ``SessionSettings`` of a SageMaker ``Session``.
2728
@@ -31,9 +32,12 @@ def __init__(
3132
is not provided (Default: True).
3233
local_download_dir (str): Optional. A path specifying the local directory
3334
for downloading artifacts. (Default: None).
35+
include_jumpstart_tags (bool): Optional. By default, if a JumpStart model is identified,
36+
it will receive special tags describing its properties.
3437
"""
3538
self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
3639
self._local_download_dir = local_download_dir
40+
self._include_jumpstart_tags = include_jumpstart_tags
3741

3842
@property
3943
def encrypt_repacked_artifacts(self) -> bool:
@@ -44,3 +48,8 @@ def encrypt_repacked_artifacts(self) -> bool:
4448
def local_download_dir(self) -> str:
4549
"""Return path specifying the local directory for downloading artifacts."""
4650
return self._local_download_dir
51+
52+
@property
53+
def include_jumpstart_tags(self) -> bool:
54+
"""Return True if JumpStart tags should be attached to models with JumpStart artifacts."""
55+
return self._include_jumpstart_tags

src/sagemaker/tuner.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from sagemaker.inputs import TrainingInput, FileSystemInput
3535
from sagemaker.job import _Job
3636
from sagemaker.jumpstart.utils import (
37-
add_jumpstart_tags,
37+
add_jumpstart_uri_tags,
3838
get_jumpstart_base_name_if_jumpstart_model,
3939
)
4040
from sagemaker.parameter import (
@@ -816,11 +816,12 @@ def _prepare_tags_for_tuning(self):
816816
if tag not in self.tags:
817817
self.tags.append(tag)
818818

819-
self.tags = add_jumpstart_tags(
820-
tags=self.tags,
821-
training_script_uri=getattr(estimator, "source_dir", None),
822-
training_model_uri=self._get_model_uri(estimator),
823-
)
819+
if self.sagemaker_session.settings.include_jumpstart_tags:
820+
self.tags = add_jumpstart_uri_tags(
821+
tags=self.tags,
822+
training_script_uri=getattr(estimator, "source_dir", None),
823+
training_model_uri=self._get_model_uri(estimator),
824+
)
824825

825826
def _prepare_job_name_for_tuning(self, job_name=None):
826827
"""Set current job name before starting tuning."""

0 commit comments

Comments
 (0)