Skip to content

Commit b0ce624

Browse files
committed
remove region dependency for curatedhub
1 parent 5f24036 commit b0ce624

File tree

4 files changed

+13
-26
lines changed

4 files changed

+13
-26
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,17 @@ class CuratedHub:
3131
def __init__(
3232
self,
3333
hub_name: str,
34-
region: str,
3534
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3635
):
3736
"""Instantiates a SageMaker ``CuratedHub``.
3837
3938
Args:
4039
hub_name (str): The name of the Hub to create.
41-
region (str): The region in which the CuratedHub is in.
4240
sagemaker_session (sagemaker.session.Session): A SageMaker Session
4341
object, used for SageMaker interactions.
4442
"""
4543
self.hub_name = hub_name
46-
if sagemaker_session.boto_region_name != region:
47-
raise ValueError(
48-
f"Cannot have conflicting regions for region=[{region}] and ",
49-
f"sagemaker_session region=[{str(sagemaker_session.boto_region_name)}].",
50-
)
51-
self.region = region
44+
self.region = sagemaker_session.boto_region_name
5245
self._sagemaker_session = sagemaker_session
5346

5447
def create(
@@ -75,9 +68,11 @@ def create(
7568
def describe(self) -> DescribeHubResponse:
7669
"""Returns descriptive information about the Hub"""
7770

78-
hub_description = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
71+
hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub(
72+
hub_name=self.hub_name
73+
)
7974

80-
return DescribeHubResponse(hub_description)
75+
return hub_description
8176

8277
def list_models(self, **kwargs) -> Dict[str, Any]:
8378
"""Lists the models in this Curated Hub

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,12 @@ def sagemaker_session():
3434

3535

3636
def test_instantiates(sagemaker_session):
37-
hub = CuratedHub(hub_name=HUB_NAME, region=REGION, sagemaker_session=sagemaker_session)
37+
hub = CuratedHub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session)
3838
assert hub.hub_name == HUB_NAME
3939
assert hub.region == "us-east-1"
4040
assert hub._sagemaker_session == sagemaker_session
4141

4242

43-
def test_instantiates_handles_conflicting_regions(sagemaker_session):
44-
conflicting_region = "us-east-2"
45-
46-
with pytest.raises(ValueError):
47-
CuratedHub(
48-
hub_name=HUB_NAME, region=conflicting_region, sagemaker_session=sagemaker_session
49-
)
50-
51-
5243
@pytest.mark.parametrize(
5344
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"),
5445
[
@@ -74,7 +65,7 @@ def test_create_with_no_bucket_name(
7465
):
7566
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
7667
sagemaker_session.create_hub = Mock(return_value=create_hub)
77-
hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session)
68+
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
7869
request = {
7970
"hub_name": hub_name,
8071
"hub_description": hub_description,
@@ -119,7 +110,7 @@ def test_create_with_bucket_name(
119110
):
120111
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
121112
sagemaker_session.create_hub = Mock(return_value=create_hub)
122-
hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session)
113+
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
123114
request = {
124115
"hub_name": hub_name,
125116
"hub_description": hub_description,

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,14 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client):
252252
@patch("boto3.client")
253253
def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
254254
cache = JumpStartModelsCache(
255-
s3_bucket_name="some_bucket", region="some_region", manifest_file_s3_key="some_key"
255+
s3_bucket_name="some_bucket", region="us-west-2", manifest_file_s3_key="some_key"
256256
)
257257

258258
cache.clear = MagicMock()
259259
cache.set_s3_bucket_name("some_bucket")
260260
cache.clear.assert_not_called()
261261
cache.clear.reset_mock()
262-
cache.set_region("some_region")
262+
cache.set_region("us-west-2")
263263
cache.clear.assert_not_called()
264264
cache.clear.reset_mock()
265265
cache.set_manifest_file_s3_key("some_key")
@@ -270,7 +270,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
270270
cache.set_s3_bucket_name("some_bucket1")
271271
cache.clear.assert_called_once()
272272
cache.clear.reset_mock()
273-
cache.set_region("some_region1")
273+
cache.set_region("us-east-1")
274274
cache.clear.assert_called_once()
275275
cache.clear.reset_mock()
276276
cache.set_manifest_file_s3_key("some_key1")

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
JumpStartModelSpecs,
2828
JumpStartS3FileType,
2929
JumpStartModelHeader,
30+
HubType,
3031
HubContentType,
3132
)
3233

@@ -211,7 +212,7 @@ def patched_retrieval_function(
211212
)
212213

213214
# TODO: Implement
214-
if datatype == HubContentType.HUB:
215+
if datatype == HubType.HUB:
215216
return None
216217

217218
raise ValueError(f"Bad value for filetype: {datatype}")

0 commit comments

Comments
 (0)