Skip to content

Commit 4799ccb

Browse files
committed
Refactor hub related stuff from session utils
1 parent a4c67e8 commit 4799ccb

File tree

3 files changed

+45
-47
lines changed

3 files changed

+45
-47
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,14 @@
1515

1616

1717
from typing import Any, Dict, Optional
18-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19-
2018
from sagemaker.session import Session
21-
19+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
20+
from sagemaker.jumpstart.curated_hub import utils as hub_utils
2221
from sagemaker.jumpstart.curated_hub.types import (
2322
DescribeHubResponse,
2423
HubContentType,
2524
DescribeHubContentsResponse,
2625
)
27-
import sagemaker.jumpstart.session_utils as session_utils
2826

2927

3028
class CuratedHub:
@@ -53,7 +51,7 @@ def create(
5351
) -> Dict[str, str]:
5452
"""Creates a hub with the given description"""
5553

56-
bucket_name = session_utils.create_hub_bucket_if_it_does_not_exist(
54+
bucket_name = hub_utils.create_hub_bucket_if_it_does_not_exist(
5755
bucket_name, self._sagemaker_session
5856
)
5957

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,45 @@ def generate_hub_arn_for_estimator_init_kwargs(
110110
else:
111111
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
112112
return hub_arn
113+
114+
115+
def generate_default_hub_bucket_name(
116+
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
117+
) -> str:
118+
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
119+
120+
Returns:
121+
str: The name of the default bucket. If the name was not explicitly specified through
122+
the Session or sagemaker_config, the bucket will take the form:
123+
``sagemaker-hubs-{region}-{AWS account ID}``.
124+
"""
125+
126+
region: str = sagemaker_session.boto_region_name
127+
account_id: str = sagemaker_session.account_id()
128+
129+
# TODO: Validate and fast fail
130+
131+
return f"sagemaker-hubs-{region}-{account_id}"
132+
133+
134+
def create_hub_bucket_if_it_does_not_exist(
135+
bucket_name: Optional[str] = None,
136+
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
137+
) -> str:
138+
"""Creates the default SageMaker Hub bucket if it does not exist.
139+
140+
Returns:
141+
str: The name of the default bucket. Takes the form:
142+
``sagemaker-hubs-{region}-{AWS account ID}``.
143+
"""
144+
145+
region: str = sagemaker_session.boto_region_name
146+
if bucket_name is None:
147+
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)
148+
149+
sagemaker_session._create_s3_bucket_if_it_does_not_exist(
150+
bucket_name=bucket_name,
151+
region=region,
152+
)
153+
154+
return bucket_name

src/sagemaker/jumpstart/session_utils.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -208,45 +208,3 @@ def get_model_id_version_from_training_job(
208208
)
209209

210210
return model_id, model_version
211-
212-
213-
def generate_default_hub_bucket_name(
214-
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
215-
) -> str:
216-
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
217-
218-
Returns:
219-
str: The name of the default bucket. If the name was not explicitly specified through
220-
the Session or sagemaker_config, the bucket will take the form:
221-
``sagemaker-hubs-{region}-{AWS account ID}``.
222-
"""
223-
224-
region: str = sagemaker_session.boto_region_name
225-
account_id: str = sagemaker_session.account_id()
226-
227-
# TODO: Validate and fast fail
228-
229-
return f"sagemaker-hubs-{region}-{account_id}"
230-
231-
232-
def create_hub_bucket_if_it_does_not_exist(
233-
bucket_name: Optional[str] = None,
234-
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
235-
) -> str:
236-
"""Creates the default SageMaker Hub bucket if it does not exist.
237-
238-
Returns:
239-
str: The name of the default bucket. Takes the form:
240-
``sagemaker-hubs-{region}-{AWS account ID}``.
241-
"""
242-
243-
region: str = sagemaker_session.boto_region_name
244-
if bucket_name is None:
245-
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)
246-
247-
sagemaker_session._create_s3_bucket_if_it_does_not_exist(
248-
bucket_name=bucket_name,
249-
region=region,
250-
)
251-
252-
return bucket_name

0 commit comments

Comments
 (0)