Skip to content

Commit e3f0312

Browse files
committed
fix: js tagging s3 prefix
1 parent f631e41 commit e3f0312

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17-
from typing import Any, Dict, List, Optional
17+
from typing import Any, Dict, List, Optional, Union
1818
from urllib.parse import urlparse
1919
from packaging.version import Version
2020
import sagemaker
@@ -277,7 +277,7 @@ def get_jumpstart_base_name_if_jumpstart_model(
277277

278278
def add_jumpstart_tags(
279279
tags: Optional[List[Dict[str, str]]] = None,
280-
inference_model_uri: Optional[str] = None,
280+
inference_model_uri: Optional[Union[str, dict]] = None,
281281
inference_script_uri: Optional[str] = None,
282282
training_model_uri: Optional[str] = None,
283283
training_script_uri: Optional[str] = None,
@@ -289,7 +289,7 @@ def add_jumpstart_tags(
289289
Args:
290290
tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
291291
or training job. (Default: None).
292-
inference_model_uri (Optional[str]): S3 URI for inference model artifact.
292+
inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact.
293293
(Default: None).
294294
inference_script_uri (Optional[str]): S3 URI for inference script tarball.
295295
(Default: None).
@@ -302,6 +302,10 @@ def add_jumpstart_tags(
302302
"The URI (%s) is a pipeline variable which is only interpreted at execution time. "
303303
"As a result, the JumpStart resources will not be tagged."
304304
)
305+
306+
if isinstance(inference_model_uri, dict):
307+
inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None)
308+
305309
if inference_model_uri:
306310
if is_pipeline_variable(inference_model_uri):
307311
logging.warning(warn_msg, "inference_model_uri")

src/sagemaker/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1345,7 +1345,9 @@ def deploy(
13451345

13461346
tags = add_jumpstart_tags(
13471347
tags=tags,
1348-
inference_model_uri=self.model_data if isinstance(self.model_data, str) else None,
1348+
inference_model_uri=self.model_data
1349+
if isinstance(self.model_data, (str, dict))
1350+
else None,
13491351
inference_script_uri=self.source_dir,
13501352
)
13511353

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,16 @@ def test_add_jumpstart_tags_inference():
216216
inference_script_uri=inference_script_uri,
217217
) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}]
218218

219+
220+
tags = []
221+
inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key")}}
222+
inference_script_uri = "dfsdfs"
223+
assert utils.add_jumpstart_tags(
224+
tags=tags,
225+
inference_model_uri=inference_model_uri,
226+
inference_script_uri=inference_script_uri,
227+
) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri["S3DataSource"]["S3Uri"]}]
228+
219229
tags = [{"Key": "some", "Value": "tag"}]
220230
inference_model_uri = random_jumpstart_s3_uri("random_key")
221231
inference_script_uri = "dfsdfs"

0 commit comments

Comments
 (0)