Skip to content

Commit 2973f23

Browse files
committed
feat: add hub and hubcontent support in retrieval function for jumpstart model cache (aws#4438)
1 parent 347b599 commit 2973f23

File tree

6 files changed

+54
-1
lines changed

6 files changed

+54
-1
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
DescribeHubContentsResponse,
5959
HubType,
6060
HubContentType,
61+
HubDataType,
6162
)
6263
from sagemaker.jumpstart.curated_hub import utils as hub_utils
6364
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -429,6 +430,7 @@ def _retrieval_function(
429430
"""
430431

431432
data_type, id_info = key.data_type, key.id_info
433+
432434
if data_type in {
433435
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
434436
JumpStartS3FileType.PROPRIETARY_MANIFEST,

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
175+
# works cross-partition
176+
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
176177
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
177178

178179
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,25 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
972972
"""
973973
# TODO: Implement
974974

975+
def to_json(self) -> Dict[str, Any]:
976+
"""Returns json representation of JumpStartModelSpecs object."""
977+
json_obj = {}
978+
for att in self.__slots__:
979+
if hasattr(self, att):
980+
cur_val = getattr(self, att)
981+
if issubclass(type(cur_val), JumpStartDataHolderType):
982+
json_obj[att] = cur_val.to_json()
983+
elif isinstance(cur_val, list):
984+
json_obj[att] = []
985+
for obj in cur_val:
986+
if issubclass(type(obj), JumpStartDataHolderType):
987+
json_obj[att].append(obj.to_json())
988+
else:
989+
json_obj[att].append(obj)
990+
else:
991+
json_obj[att] = cur_val
992+
return json_obj
993+
975994
def supports_prepacked_inference(self) -> bool:
976995
"""Returns True if the model has a prepacked inference artifact."""
977996
return getattr(self, "hosting_prepacked_artifact_key", None) is not None

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
from typing import Any, Dict, List, Set, Optional, Tuple, Union
19+
import re
1920
from urllib.parse import urlparse
2021
import boto3
2122
from packaging.version import Version

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,35 @@ def test_mime_type_enum_from_str():
12141214
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12151215

12161216

1217+
def test_extract_info_from_hub_content_arn():
1218+
model_arn = (
1219+
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1220+
)
1221+
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1222+
"MockHub",
1223+
"us-west-2",
1224+
"my-mock-model",
1225+
"1.0.2",
1226+
)
1227+
1228+
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1229+
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1230+
1231+
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1232+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1233+
1234+
invalid_arn = "nonsense-string"
1235+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236+
1237+
invalid_arn = ""
1238+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1239+
1240+
invalid_arn = (
1241+
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1242+
)
1243+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1244+
1245+
12171246
class TestIsValidModelId(TestCase):
12181247
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12191248
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25+
HubDataType,
2526
JumpStartCachedContentKey,
2627
JumpStartCachedContentValue,
2728
JumpStartModelSpecs,

0 commit comments

Comments
 (0)