Skip to content

Commit b600dd1

Browse files
committed
Add curated_hub utils unit tests
1 parent 086bf92 commit b600dd1

File tree

8 files changed

+53
-19
lines changed

8 files changed

+53
-19
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import botocore
2222
from packaging.version import Version
2323
from packaging.specifiers import SpecifierSet, InvalidSpecifier
24+
from sagemaker.utilities.cache import LRUCache
2425
from sagemaker.jumpstart.constants import (
2526
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
2627
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -36,6 +37,7 @@
3637
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
3738
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
3839
)
40+
from sagemaker.jumpstart import utils
3941
from sagemaker.jumpstart.types import (
4042
JumpStartCachedContentKey,
4143
JumpStartCachedContentValue,
@@ -45,14 +47,12 @@
4547
JumpStartVersionedModelId,
4648
)
4749
from sagemaker.jumpstart.curated_hub.types import (
48-
HubContentType,
4950
DescribeHubResponse,
5051
DescribeHubContentsResponse,
52+
HubContentType,
5153
)
52-
from sagemaker.jumpstart import utils
5354
from sagemaker.jumpstart.curated_hub import utils as hub_utils
5455
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
55-
from sagemaker.utilities.cache import LRUCache
5656

5757

5858
class JumpStartModelsCache:

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from typing import Any, Dict, Optional
1818
from sagemaker.session import Session
1919
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
20-
from sagemaker.jumpstart.curated_hub import utils as hub_utils
20+
from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist
2121
from sagemaker.jumpstart.curated_hub.types import (
2222
DescribeHubResponse,
23-
HubContentType,
2423
DescribeHubContentsResponse,
24+
HubContentType,
2525
)
2626

2727

@@ -51,9 +51,7 @@ def create(
5151
) -> Dict[str, str]:
5252
"""Creates a hub with the given description"""
5353

54-
bucket_name = hub_utils.create_hub_bucket_if_it_does_not_exist(
55-
bucket_name, self._sagemaker_session
56-
)
54+
bucket_name = create_hub_bucket_if_it_does_not_exist(bucket_name, self._sagemaker_session)
5755

5856
return self._sagemaker_session.create_hub(
5957
hub_name=self.hub_name,

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from enum import Enum
1616
from typing import Any, Dict, List, Optional
1717

18+
1819
from sagemaker.jumpstart.types import JumpStartDataHolderType
1920

2021

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import re
1616
from typing import Optional
1717
from sagemaker.session import Session
18-
from sagemaker.jumpstart import constants
1918
from sagemaker.utils import aws_partition
19+
from sagemaker.jumpstart import constants
2020
from sagemaker.jumpstart.curated_hub.types import (
2121
HubContentType,
2222
HubArnExtractedInfo,

src/sagemaker/jumpstart/session_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from __future__ import absolute_import
1616

1717
from typing import Optional, Tuple
18-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1918

20-
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
2119
from sagemaker.session import Session
2220
from sagemaker.utils import aws_partition
21+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
22+
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
2323

2424

2525
def get_model_id_version_from_endpoint(

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,14 @@
1515
from copy import deepcopy
1616
from enum import Enum
1717
from typing import Any, Dict, List, Optional, Set, Union
18+
from sagemaker.session import Session
1819
from sagemaker.utils import get_instance_type_family, format_tags, Tags
20+
from sagemaker.enums import EndpointType
1921
from sagemaker.model_metrics import ModelMetrics
2022
from sagemaker.metadata_properties import MetadataProperties
2123
from sagemaker.drift_check_baselines import DriftCheckBaselines
22-
23-
from sagemaker.session import Session
2424
from sagemaker.workflow.entities import PipelineVariable
2525
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
26-
from sagemaker.enums import EndpointType
27-
from sagemaker.jumpstart.curated_hub import types as hub_types
2826

2927

3028
class JumpStartDataHolderType:
@@ -119,14 +117,17 @@ def to_json(self) -> Dict[str, Any]:
119117
return json_obj
120118

121119

120+
from sagemaker.jumpstart.curated_hub.types import HubContentType # noqa: E402
121+
122+
122123
class JumpStartS3FileType(str, Enum):
123124
"""Type of files published in JumpStart S3 distribution buckets."""
124125

125126
MANIFEST = "manifest"
126127
SPECS = "specs"
127128

128129

129-
JumpStartContentDataType = Union[JumpStartS3FileType, hub_types.HubContentType]
130+
JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType]
130131

131132

132133
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14-
from unittest.mock import Mock
15-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1614

15+
from unittest.mock import Mock
16+
from botocore.exceptions import ClientError
1717
from sagemaker.jumpstart.curated_hub import utils
18+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1819
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo
1920

2021

@@ -151,3 +152,35 @@ def test_generate_hub_arn_for_estimator_init_kwargs():
151152
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
152153
== hub_arn
153154
)
155+
156+
157+
def test_generate_default_hub_bucket_name():
158+
mock_sagemaker_session = Mock()
159+
mock_sagemaker_session.account_id.return_value = "123456789123"
160+
mock_sagemaker_session.boto_region_name = "us-east-1"
161+
162+
assert (
163+
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
164+
== "sagemaker-hubs-us-east-1-123456789123"
165+
)
166+
167+
168+
def test_create_hub_bucket_if_it_does_not_exist():
169+
mock_sagemaker_session = Mock()
170+
mock_sagemaker_session.account_id.return_value = "123456789123"
171+
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
172+
"Account": "123456789123"
173+
}
174+
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
175+
mock_sagemaker_session.boto_region_name = "us-east-1"
176+
error = ClientError(
177+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
178+
operation_name="foo",
179+
)
180+
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
181+
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
182+
sagemaker_session=mock_sagemaker_session
183+
)
184+
185+
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
186+
assert created_hub_bucket_name == bucket_name

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubContentType,
2625
JumpStartCachedContentKey,
2726
JumpStartCachedContentValue,
2827
JumpStartModelSpecs,
2928
JumpStartS3FileType,
3029
JumpStartModelHeader,
3130
)
31+
3232
from sagemaker.jumpstart.utils import get_formatted_manifest
3333
from tests.unit.sagemaker.jumpstart.constants import (
3434
PROTOTYPICAL_MODEL_SPECS_DICT,
@@ -37,6 +37,7 @@
3737
BASE_HEADER,
3838
SPECIAL_MODEL_SPECS_DICT,
3939
)
40+
from sagemaker.jumpstart.curated_hub.types import HubContentType
4041

4142

4243
def get_header_from_base_header(

0 commit comments

Comments
 (0)