Skip to content

Commit d08fa11

Browse files
committed
fix: formatting/style/lint errors
1 parent 3712df6 commit d08fa11

File tree

9 files changed

+193
-128
lines changed

9 files changed

+193
-128
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,23 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
"""This module defines the JumpStartModelsCache class."""
14+
from __future__ import absolute_import
1315
import datetime
1416
from typing import List, Optional
17+
import json
18+
import boto3
19+
import semantic_version
1520
from sagemaker.jumpstart.types import (
1621
JumpStartCachedS3ContentKey,
1722
JumpStartCachedS3ContentValue,
1823
JumpStartModelHeader,
1924
JumpStartModelSpecs,
20-
JumpStartModelSpecs,
2125
JumpStartS3FileType,
2226
JumpStartVersionedModelId,
2327
)
2428
from sagemaker.jumpstart import utils
2529
from sagemaker.utilities.cache import LRUCache
26-
import boto3
27-
import json
28-
import semantic_version
29-
3030

3131
DEFAULT_REGION_NAME = boto3.session.Session().region_name
3232

@@ -41,6 +41,7 @@
4141

4242
class JumpStartModelsCache:
4343
"""Class that implements a cache for JumpStart models manifests and specs.
44+
4445
The manifest and specs associated with JumpStart models provide the information necessary
4546
for launching JumpStart models from the SageMaker SDK.
4647
"""
@@ -62,15 +63,16 @@ def __init__(
6263
Args:
6364
region (Optional[str]): AWS region to associate with cache. Default: region associated
6465
with botocore session.
65-
max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20.
66-
s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in s3
67-
cache before invalidation. Default: 6 hours.
66+
max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache.
67+
Default: 20.
68+
s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in
69+
s3 cache before invalidation. Default: 6 hours.
6870
max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in
6971
semantic version cache. Default: 20.
70-
semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold
71-
items in semantic version cache before invalidation. Default: 6 hours.
72-
bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content
73-
bucket for region.
72+
semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to
73+
hold items in semantic version cache before invalidation. Default: 6 hours.
74+
bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted
75+
content bucket for region.
7476
"""
7577

7678
self._region = region
@@ -120,15 +122,16 @@ def get_bucket(self) -> None:
120122
return self._bucket
121123

122124
def _get_manifest_key_from_model_id_semantic_version(
123-
self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId]
125+
self,
126+
key: JumpStartVersionedModelId,
127+
value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613
124128
) -> JumpStartVersionedModelId:
125-
"""Return model id and version in manifest that matches semantic version/id
126-
from customer request.
129+
"""Return model id and version in manifest that matches semantic version/id.
127130
128131
Args:
129132
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
130-
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached
131-
model id/version.
133+
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
134+
old cached model id/version.
132135
133136
Raises:
134137
KeyError: If the semantic version is not found in the manifest.
@@ -158,42 +161,42 @@ def _get_manifest_key_from_model_id_semantic_version(
158161
sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker)
159162
if sm_compatible_model_version is not None:
160163
return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version))
161-
else:
162-
versions_incompatible_with_sagemaker = [
163-
semantic_version.Version(header.version)
164+
165+
versions_incompatible_with_sagemaker = [
166+
semantic_version.Version(header.version)
167+
for _, header in manifest.items()
168+
if header.model_id == model_id
169+
]
170+
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
171+
if sm_incompatible_model_version is not None:
172+
model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version)
173+
sm_version_to_use = [
174+
header.min_version
164175
for _, header in manifest.items()
165176
if header.model_id == model_id
177+
and header.version == model_version_to_use_incompatible_with_sagemaker
166178
]
167-
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
168-
if sm_incompatible_model_version is not None:
169-
model_version_to_use_incompatible_with_sagemaker = str(
170-
sm_incompatible_model_version
171-
)
172-
sm_version_to_use = [
173-
header.min_version
174-
for _, header in manifest.items()
175-
if header.model_id == model_id
176-
and header.version == model_version_to_use_incompatible_with_sagemaker
177-
]
178-
assert len(sm_version_to_use) == 1
179-
sm_version_to_use = sm_version_to_use[0]
180-
181-
error_msg = (
182-
f"Unable to find model manifest for {model_id} with version {version} compatible with your SageMaker version ({sm_version}). "
183-
f"Consider upgrading your SageMaker library to at least version {sm_version_to_use} so you can use version "
184-
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
185-
)
186-
raise KeyError(error_msg)
187-
else:
188-
error_msg = f"Unable to find model manifest for {model_id} with version {version}"
189-
raise KeyError(error_msg)
179+
assert len(sm_version_to_use) == 1
180+
sm_version_to_use = sm_version_to_use[0]
181+
182+
error_msg = (
183+
f"Unable to find model manifest for {model_id} with version {version} "
184+
f"compatible with your SageMaker version ({sm_version}). "
185+
f"Consider upgrading your SageMaker library to at least version "
186+
f"{sm_version_to_use} so you can use version "
187+
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
188+
)
189+
raise KeyError(error_msg)
190+
error_msg = f"Unable to find model manifest for {model_id} with version {version}"
191+
raise KeyError(error_msg)
190192

191193
def _get_file_from_s3(
192194
self,
193195
key: JumpStartCachedS3ContentKey,
194196
value: Optional[JumpStartCachedS3ContentValue],
195197
) -> JumpStartCachedS3ContentValue:
196198
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
199+
197200
If a manifest file is being fetched, we only download the object if the md5 hash in
198201
``head_object`` does not match the current md5 hash for the stored value. This prevents
199202
unnecessarily downloading the full manifest when it hasn't changed.
@@ -228,18 +231,18 @@ def _get_file_from_s3(
228231
raise RuntimeError(f"Bad value for key: {key}")
229232

230233
def get_header(
231-
self, model_id: str, semantic_version: Optional[str] = None
234+
self, model_id: str, semantic_version_str: Optional[str] = None
232235
) -> List[JumpStartModelHeader]:
233236
"""Return list of headers for a given JumpStart model id and semantic version.
234237
235238
Args:
236239
model_id (str): model id for which to get a header.
237-
semantic_version (Optional[str]): The semantic version for which to get a header.
238-
If None, the highest compatible version is returned.
240+
semantic_version_str (Optional[str]): The semantic version for which to get a
241+
header. If None, the highest compatible version is returned.
239242
"""
240243

241244
versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
242-
JumpStartVersionedModelId(model_id, semantic_version)
245+
JumpStartVersionedModelId(model_id, semantic_version_str)
243246
)
244247
manifest = self._s3_cache.get(
245248
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
@@ -258,16 +261,17 @@ def get_header(
258261
return self.get_header(model_id, semantic_version)
259262

260263
def get_specs(
261-
self, model_id: str, semantic_version: Optional[str] = None
264+
self, model_id: str, semantic_version_str: Optional[str] = None
262265
) -> JumpStartModelSpecs:
263266
"""Return specs for a given JumpStart model id and semantic version.
264267
265268
Args:
266269
model_id (str): model id for which to get specs.
267-
semantic_version (Optional[str]): The semantic version for which to get specs.
268-
If None, the highest compatible version is returned.
270+
semantic_version_str (Optional[str]): The semantic version for which to get
271+
specs. If None, the highest compatible version is returned.
269272
"""
270-
header = self.get_header(model_id, semantic_version)
273+
274+
header = self.get_header(model_id, semantic_version_str)
271275
spec_key = header.spec_key
272276
return self._s3_cache.get(
273277
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
"""This module stores constants related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
1315
from typing import Set
1416
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1517

src/sagemaker/jumpstart/types.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,25 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
"""This module stores types related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
1315
from enum import Enum
1416
from typing import Any, Dict, List, Optional, Union
1517

1618

1719
class JumpStartDataHolderType:
18-
"""Base class for many JumpStart types. Allows objects to be added to dicts and sets,
20+
"""Base class for many JumpStart types.
21+
22+
Allows objects to be added to dicts and sets,
1923
and improves string representation. This class allows different objects with the same
2024
attributes and types to have equality.
2125
"""
2226

27+
__slots__: List[str] = []
28+
2329
def __eq__(self, other: Any) -> bool:
24-
"""Returns True if other object is of the same type
25-
and has all attributes equal."""
30+
"""Returns True if ``other`` is of the same type and has all attributes equal."""
31+
2632
if not isinstance(other, type(self)):
2733
return False
2834
for attribute in self.__slots__:
@@ -31,23 +37,30 @@ def __eq__(self, other: Any) -> bool:
3137
return True
3238

3339
def __hash__(self) -> int:
34-
"""Makes hash of object by first mapping to unique tuple, which then
35-
gets hashed.
40+
"""Makes hash of object.
41+
42+
Maps object to unique tuple, which then gets hashed.
3643
"""
44+
3745
return hash((type(self),) + tuple([getattr(self, att) for att in self.__slots__]))
3846

3947
def __str__(self) -> str:
4048
"""Returns string representation of object. Example:
41-
"JumpStartLaunchedRegionInfo: {'content_bucket': 'jumpstart-bucket-us-west-2', 'region_name': 'us-west-2'}"
49+
50+
"JumpStartLaunchedRegionInfo:
51+
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
4252
"""
53+
4354
att_dict = {att: getattr(self, att) for att in self.__slots__}
4455
return f"{type(self).__name__}: {str(att_dict)}"
4556

4657
def __repr__(self) -> str:
47-
"""This is often called instead of __str__ and is the official string representation
48-
of an object, typicaly used for debugging. Example:
49-
"JumpStartLaunchedRegionInfo at 0x7f664529efa0: {'content_bucket': 'jumpstart-bucket-us-west-2', 'region_name': 'us-west-2'}"
58+
"""Returns ``__repr__`` string of object. Example:
59+
60+
"JumpStartLaunchedRegionInfo at 0x7f664529efa0:
61+
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
5062
"""
63+
5164
att_dict = {att: getattr(self, att) for att in self.__slots__}
5265
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"
5366

src/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
"""This module contains utilities related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
1315
from typing import Dict, List
1416

1517
import semantic_version
@@ -31,11 +33,11 @@ def get_jumpstart_launched_regions_string() -> str:
3133
return f"JumpStart is available in {sorted_regions[0]} and {sorted_regions[1]} regions."
3234

3335
formatted_launched_regions_list = []
34-
for i in range(len(sorted_regions)):
36+
for i, region in enumerate(sorted_regions):
3537
region_prefix = ""
3638
if i == len(sorted_regions) - 1:
3739
region_prefix = "and "
38-
formatted_launched_regions_list.append(region_prefix + sorted_regions[i])
40+
formatted_launched_regions_list.append(region_prefix + region)
3941
formatted_launched_regions_str = ", ".join(formatted_launched_regions_list)
4042
return f"JumpStart is available in {formatted_launched_regions_str} regions."
4143

@@ -59,8 +61,11 @@ def get_jumpstart_content_bucket(region: str) -> str:
5961
def get_formatted_manifest(
6062
manifest: List[Dict],
6163
) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]:
62-
"""Returns formatted manifest dictionary from raw manifest. Keys are JumpStartVersionedModelId objects,
63-
values are JumpStartModelHeader objects."""
64+
"""Returns formatted manifest dictionary from raw manifest.
65+
66+
Keys are JumpStartVersionedModelId objects, values are
67+
``JumpStartModelHeader`` objects.
68+
"""
6469
manifest_dict = {}
6570
for header in manifest:
6671
header_obj = JumpStartModelHeader(header)
@@ -71,8 +76,10 @@ def get_formatted_manifest(
7176

7277

7378
def get_sagemaker_version() -> str:
74-
"""Returns sagemaker library version by reading __version__ variable
75-
in module. In order to maintain compatibility with the ``semantic_version``
79+
"""Returns sagemaker library version.
80+
81+
Function reads ``__version__`` variable in ``sagemaker`` module.
82+
In order to maintain compatibility with the ``semantic_version``
7683
library, versions with fewer than 2, or more than 3, periods are rejected.
7784
All versions that cannot be parsed with ``semantic_version`` are also
7885
rejected.

0 commit comments

Comments
 (0)