Skip to content

Commit 70f6589

Browse files
committed
feat: Adding Marketplace model support for HubService (aws#1539)
* feat: Adding BRS support
1 parent 74c624a commit 70f6589

File tree

13 files changed

+3532
-35
lines changed

13 files changed

+3532
-35
lines changed

src/sagemaker/jumpstart/enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ class JumpStartTag(str, Enum):
9393
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9494
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
9595

96+
INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
97+
TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"
98+
99+
HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"
100+
96101

97102
class SerializerType(str, Enum):
98103
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/factory/model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,18 @@
3737
from sagemaker.model_metrics import ModelMetrics
3838
from sagemaker.metadata_properties import MetadataProperties
3939
from sagemaker.drift_check_baselines import DriftCheckBaselines
40-
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
40+
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
4141
from sagemaker.jumpstart.types import (
4242
JumpStartModelDeployKwargs,
4343
JumpStartModelInitKwargs,
4444
JumpStartModelRegisterKwargs,
4545
)
4646
from sagemaker.jumpstart.utils import (
47-
add_jumpstart_model_id_version_tags,
47+
add_hub_content_arn_tags,
48+
add_jumpstart_model_info_tags,
49+
get_default_jumpstart_session_with_user_agent_suffix,
50+
get_neo_content_bucket,
51+
get_top_ranked_config_name,
4852
update_dict_if_key_not_present,
4953
resolve_model_sagemaker_config_field,
5054
verify_model_region_and_return_specs,
@@ -483,6 +487,17 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
483487
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
484488
)
485489

490+
if kwargs.hub_arn:
491+
if kwargs.model_reference_arn:
492+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
493+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
494+
)
495+
else:
496+
hub_content_arn = construct_hub_model_arn_from_inputs(
497+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
498+
)
499+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
500+
486501
return kwargs
487502

488503

0 commit comments

Comments
 (0)