Skip to content

Commit a4c67e8

Browse files
committed
Use CuratedHub class instead of using a middleware utilities
1 parent 0435ae9 commit a4c67e8

File tree

2 files changed

+6
-35
lines changed

2 files changed

+6
-35
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from sagemaker.jumpstart import utils
5353
from sagemaker.jumpstart.curated_hub import utils as hub_utils
54+
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
5455
from sagemaker.utilities.cache import LRUCache
5556

5657

@@ -346,9 +347,8 @@ def _retrieval_function(
346347
hub_name, region, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
347348
id_info
348349
)
349-
hub_model_description: DescribeHubContentsResponse = hub_utils.describe_model(
350-
hub_name=hub_name,
351-
region=region,
350+
hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region)
351+
hub_model_description: DescribeHubContentsResponse = hub.describe_model(
352352
model_name=model_name,
353353
model_version=model_version
354354
)
@@ -364,10 +364,8 @@ def _retrieval_function(
364364
)
365365
if data_type == HubContentType.HUB:
366366
hub_name, region, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
367-
hub_description: DescribeHubResponse = hub_utils.describe(
368-
hub_name=hub_name,
369-
region=region
370-
)
367+
hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region)
368+
hub_description: DescribeHubResponse = hub.describe()
371369
return JumpStartCachedContentValue(formatted_content=hub_description)
372370
raise ValueError(
373371
f"Bad value for key '{key}': must be in",

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,43 +13,16 @@
1313
"""This module contains utilities related to SageMaker JumpStart CuratedHub."""
1414
from __future__ import absolute_import
1515
import re
16-
from typing import Any, Dict, Optional
17-
import boto3
16+
from typing import Optional
1817
from sagemaker.session import Session
1918
from sagemaker.jumpstart import constants
2019
from sagemaker.utils import aws_partition
2120
from sagemaker.jumpstart.curated_hub.types import (
22-
DescribeHubResponse,
2321
HubContentType,
24-
DescribeHubContentsResponse,
2522
HubArnExtractedInfo,
2623
)
2724

2825

29-
def describe(hub_name: str, region: str) -> DescribeHubResponse:
30-
"""Returns descriptive information about the Hub."""
31-
32-
sagemaker_session = Session(boto3.Session(region_name=region))
33-
hub_description = sagemaker_session.describe_hub(hub_name=hub_name)
34-
return DescribeHubResponse(hub_description)
35-
36-
37-
def describe_model(
38-
hub_name: str, region: str, model_name: str, model_version: str = "*"
39-
) -> DescribeHubContentsResponse:
40-
"""Returns descriptive information about the Hub model."""
41-
42-
sagemaker_session = Session(boto3.Session(region_name=region))
43-
hub_content_description: Dict[str, Any] = sagemaker_session.describe_hub_content(
44-
hub_name=hub_name,
45-
hub_content_name=model_name,
46-
hub_content_version=model_version,
47-
hub_content_type=HubContentType.MODEL,
48-
)
49-
50-
return DescribeHubContentsResponse(hub_content_description)
51-
52-
5326
def get_info_from_hub_resource_arn(
5427
arn: str,
5528
) -> HubArnExtractedInfo:

0 commit comments

Comments
 (0)