Skip to content

Commit 45a4a4d

Browse files
committed
delete hubutils, refactor
1 parent 4922511 commit 45a4a4d

File tree

5 files changed

+92
-225
lines changed

5 files changed

+92
-225
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
JumpStartModelSpecs,
4545
JumpStartS3FileType,
4646
JumpStartVersionedModelId,
47-
HubDataType,
47+
HubContentType,
4848
)
4949
from sagemaker.jumpstart import utils
5050
from sagemaker.utilities.cache import LRUCache
@@ -338,7 +338,7 @@ def _retrieval_function(
338338
return JumpStartCachedContentValue(
339339
formatted_content=model_specs
340340
)
341-
if data_type == HubDataType.MODEL:
341+
if data_type == HubContentType.MODEL:
342342
hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
343343
id_info
344344
)
@@ -353,14 +353,14 @@ def _retrieval_function(
353353
return JumpStartCachedContentValue(
354354
formatted_content=model_specs
355355
)
356-
if data_type == HubDataType.HUB:
356+
if data_type == HubContentType.HUB:
357357
hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
358358
hub = CuratedHub(hub_name=hub_name, region=region)
359359
hub_info = hub.describe()
360360
return JumpStartCachedContentValue(formatted_content=hub_info)
361361
raise ValueError(
362362
f"Bad value for key '{key}': must be in",
363-
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
363+
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
364364
)
365365

366366
def get_manifest(self) -> List[JumpStartModelHeader]:
@@ -474,7 +474,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
474474
"""
475475

476476
details, _ = self._content_cache.get(
477-
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
477+
JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
478478
)
479479
return details.formatted_content
480480

@@ -485,7 +485,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
485485
hub_arn (str): Arn for the Hub to get info for
486486
"""
487487

488-
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
488+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
489489
return details.formatted_content
490490

491491
def clear(self) -> None:

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
JUMPSTART_DEFAULT_REGION_NAME,
2121
)
2222

23-
from sagemaker.jumpstart.types import HubDataType
24-
import sagemaker.jumpstart.curated_hub.utils as hubutils
23+
from sagemaker.jumpstart.types import HubContentType
24+
import sagemaker.jumpstart.session_utils as session_utils
2525

2626

2727
class CuratedHub:
@@ -50,32 +50,60 @@ def create(
5050
) -> Dict[str, str]:
5151
"""Creates a hub with the given description"""
5252

53-
return hubutils.create_hub(
53+
bucket_name = session_utils.create_hub_bucket_if_it_does_not_exist(
54+
bucket_name, self._session
55+
)
56+
57+
return self._session.create_hub(
5458
hub_name=self.name,
5559
hub_description=description,
5660
hub_display_name=display_name,
5761
hub_search_keywords=search_keywords,
5862
hub_bucket_name=bucket_name,
5963
tags=tags,
60-
sagemaker_session=self._session,
6164
)
6265

66+
def describe(self) -> Dict[str, Any]:
67+
"""Returns descriptive information about the Hub"""
68+
69+
hub_info = self._session.describe_hub(hub_name=self.name)
70+
71+
return hub_info
72+
73+
def list_models(self, **kwargs) -> Dict[str, Any]:
74+
"""Lists the models in this Curated Hub
75+
76+
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
77+
"""
78+
# TODO: Validate kwargs and fast-fail?
79+
80+
hub_content_summaries = self._session.list_hub_contents(
81+
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
82+
)
83+
# TODO: Handle pagination
84+
return hub_content_summaries
85+
6386
def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
6487
"""Returns descriptive information about the Hub Model"""
6588

66-
hub_content = hubutils.describe_hub_content(
89+
hub_content = self._session.describe_hub_content(
6790
hub_name=self.name,
68-
content_name=model_name,
69-
content_type=HubDataType.MODEL,
70-
content_version=model_version,
71-
sagemaker_session=self._session,
91+
hub_content_name=model_name,
92+
hub_content_version=model_version,
93+
hub_content_type=HubContentType.MODEL,
7294
)
7395

7496
return hub_content
7597

76-
def describe(self) -> Dict[str, Any]:
77-
"""Returns descriptive information about the Hub"""
78-
79-
hub_info = hubutils.describe_hub(hub_name=self.name, sagemaker_session=self._session)
98+
def delete_model(self, model_name: str, model_version: str = "*") -> None:
99+
"""Deletes a model from this CuratedHub."""
100+
return self._session.delete_hub_content(
101+
hub_content_name=model_name,
102+
hub_content_version=model_version,
103+
hub_content_type=HubContentType.MODEL,
104+
hub_name=self.hub_name,
105+
)
80106

81-
return hub_info
107+
def delete(self) -> None:
108+
"""Deletes this Curated Hub"""
109+
return self._session.delete_hub(self.hub_name)

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 0 additions & 203 deletions
This file was deleted.

src/sagemaker/jumpstart/session_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,45 @@ def get_model_id_version_from_training_job(
208208
)
209209

210210
return model_id, model_version
211+
212+
213+
def generate_default_hub_bucket_name(
214+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
215+
) -> str:
216+
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
217+
218+
Returns:
219+
str: The name of the default bucket. If the name was not explicitly specified through
220+
the Session or sagemaker_config, the bucket will take the form:
221+
``sagemaker-hubs-{region}-{AWS account ID}``.
222+
"""
223+
224+
region: str = sagemaker_session.boto_region_name
225+
account_id: str = sagemaker_session.account_id()
226+
227+
# TODO: Validate and fast fail
228+
229+
return f"sagemaker-hubs-{region}-{account_id}"
230+
231+
232+
def create_hub_bucket_if_it_does_not_exist(
233+
bucket_name: str = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
234+
):
235+
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
236+
This function will create the s3 bucket if it does not exist.
237+
238+
Returns:
239+
str: The name of the default bucket. Takes the form:
240+
``sagemaker-hubs-{region}-{AWS account ID}``.
241+
"""
242+
243+
region: str = sagemaker_session.boto_region_name
244+
if bucket_name is None:
245+
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)
246+
247+
sagemaker_session._create_s3_bucket_if_it_does_not_exist(
248+
bucket_name=bucket_name,
249+
region=region,
250+
)
251+
252+
return bucket_name

0 commit comments

Comments
 (0)