Skip to content

fix: js tagging s3 prefix #4167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 10, 2023
10 changes: 7 additions & 3 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
from packaging.version import Version
import sagemaker
Expand Down Expand Up @@ -277,7 +277,7 @@ def get_jumpstart_base_name_if_jumpstart_model(

def add_jumpstart_tags(
tags: Optional[List[Dict[str, str]]] = None,
inference_model_uri: Optional[str] = None,
inference_model_uri: Optional[Union[str, dict]] = None,
inference_script_uri: Optional[str] = None,
training_model_uri: Optional[str] = None,
training_script_uri: Optional[str] = None,
Expand All @@ -289,7 +289,7 @@ def add_jumpstart_tags(
Args:
tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
or training job. (Default: None).
inference_model_uri (Optional[str]): S3 URI for inference model artifact.
inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact.
(Default: None).
inference_script_uri (Optional[str]): S3 URI for inference script tarball.
(Default: None).
Expand All @@ -302,6 +302,10 @@ def add_jumpstart_tags(
"The URI (%s) is a pipeline variable which is only interpreted at execution time. "
"As a result, the JumpStart resources will not be tagged."
)

if isinstance(inference_model_uri, dict):
inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None)

if inference_model_uri:
if is_pipeline_variable(inference_model_uri):
logging.warning(warn_msg, "inference_model_uri")
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,9 @@ def deploy(

tags = add_jumpstart_tags(
tags=tags,
inference_model_uri=self.model_data if isinstance(self.model_data, str) else None,
inference_model_uri=self.model_data
if isinstance(self.model_data, (str, dict))
else None,
inference_script_uri=self.source_dir,
)

Expand Down
28 changes: 28 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,34 @@ def test_add_jumpstart_tags_inference():
inference_script_uri=inference_script_uri,
) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}]

tags = []
inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key")}}
inference_script_uri = "dfsdfs"
assert utils.add_jumpstart_tags(
tags=tags,
inference_model_uri=inference_model_uri,
inference_script_uri=inference_script_uri,
) == [
{
"Key": JumpStartTag.INFERENCE_MODEL_URI.value,
"Value": inference_model_uri["S3DataSource"]["S3Uri"],
}
]

tags = []
inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key/prefix/")}}
inference_script_uri = "dfsdfs"
assert utils.add_jumpstart_tags(
tags=tags,
inference_model_uri=inference_model_uri,
inference_script_uri=inference_script_uri,
) == [
{
"Key": JumpStartTag.INFERENCE_MODEL_URI.value,
"Value": inference_model_uri["S3DataSource"]["S3Uri"],
}
]

tags = [{"Key": "some", "Value": "tag"}]
inference_model_uri = random_jumpstart_s3_uri("random_key")
inference_script_uri = "dfsdfs"
Expand Down