Skip to content

Commit 4ae201c

Browse files
committed
add hub and hubcontent support in retrieval function for jumpstart model cache
1 parent 1e26760 commit 4ae201c

File tree

4 files changed

+46
-5
lines changed

4 files changed

+46
-5
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,26 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
465465
)
466466
)
467467
return specs.formatted_content
468+
469+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
470+
"""Return JumpStart-compatible specs for a given Hub model
471+
472+
Args:
473+
hub_model_arn (str): Arn for the Hub model to get specs for
474+
"""
475+
476+
specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn))
477+
return specs.formatted_content
478+
479+
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
480+
"""Return descriptive info for a given Hub
481+
482+
Args:
483+
hub_arn (str): Arn for the Hub to get info for
484+
"""
485+
486+
manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
487+
return manifest.formatted_content
468488

469489
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
470490
"""Return JumpStart-compatible specs for a given Hub model
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from botocore.config import Config
14+
15+
DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"})

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from sagemaker.session import Session
2020

21+
from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG
22+
2123

2224
class CuratedHub:
2325
"""Class for creating and managing a curated JumpStart hub"""

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -860,12 +860,16 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
860860
def test_jumpstart_cache_get_hub_model():
861861
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
862862

863-
model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3"
863+
model_arn = (
864+
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3"
865+
)
864866
assert get_spec_from_base_spec(
865867
model_id="huggingface-mock-model-123", version="1.2.3"
866868
) == cache.get_hub_model(hub_model_arn=model_arn)
867869

868-
model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*"
869-
assert get_spec_from_base_spec(
870-
model_id="pytorch-mock-model-123", version="*"
871-
) == cache.get_hub_model(hub_model_arn=model_arn)
870+
model_arn = (
871+
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*"
872+
)
873+
assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model(
874+
hub_model_arn=model_arn
875+
)

0 commit comments

Comments
 (0)