|
37 | 37 | from sagemaker.model_metrics import ModelMetrics
|
38 | 38 | from sagemaker.metadata_properties import MetadataProperties
|
39 | 39 | from sagemaker.drift_check_baselines import DriftCheckBaselines
|
40 |
| -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType |
| 40 | +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability |
41 | 41 | from sagemaker.jumpstart.types import (
|
42 | 42 | JumpStartModelDeployKwargs,
|
43 | 43 | JumpStartModelInitKwargs,
|
44 | 44 | JumpStartModelRegisterKwargs,
|
45 | 45 | )
|
46 | 46 | from sagemaker.jumpstart.utils import (
|
47 |
| - add_jumpstart_model_id_version_tags, |
| 47 | + add_hub_content_arn_tags, |
| 48 | + add_jumpstart_model_info_tags, |
| 49 | + get_default_jumpstart_session_with_user_agent_suffix, |
| 50 | + get_neo_content_bucket, |
| 51 | + get_top_ranked_config_name, |
48 | 52 | update_dict_if_key_not_present,
|
49 | 53 | resolve_model_sagemaker_config_field,
|
50 | 54 | verify_model_region_and_return_specs,
|
@@ -483,6 +487,17 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
|
483 | 487 | kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
|
484 | 488 | )
|
485 | 489 |
|
| 490 | + if kwargs.hub_arn: |
| 491 | + if kwargs.model_reference_arn: |
| 492 | + hub_content_arn = construct_hub_model_reference_arn_from_inputs( |
| 493 | + kwargs.hub_arn, kwargs.model_id, kwargs.model_version |
| 494 | + ) |
| 495 | + else: |
| 496 | + hub_content_arn = construct_hub_model_arn_from_inputs( |
| 497 | + kwargs.hub_arn, kwargs.model_id, kwargs.model_version |
| 498 | + ) |
| 499 | + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) |
| 500 | + |
486 | 501 | return kwargs
|
487 | 502 |
|
488 | 503 |
|
|
0 commit comments