Skip to content

Commit 85c65d5

Browse files
committed
change: improve docstrings, typos
1 parent 3bde038 commit 85c65d5

File tree

4 files changed

+77
-20
lines changed

4 files changed

+77
-20
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def _get_manifest_key_from_model_id_semantic_version(
164164

165165
manifest = self._s3_cache.get(
166166
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
167-
).formatted_file_content
167+
).formatted_content
168168

169169
sm_version = utils.get_sagemaker_version()
170170

@@ -199,7 +199,9 @@ def _get_manifest_key_from_model_id_semantic_version(
199199
if header.model_id == model_id
200200
and header.version == model_version_to_use_incompatible_with_sagemaker
201201
]
202-
assert len(sm_version_to_use) == 1 # ``manifest`` dict should already enforce this
202+
if len(sm_version_to_use) != 1:
203+
# ``manifest`` dict should already enforce this
204+
raise RuntimeError("Found more than one incompatible SageMaker version to use.")
203205
sm_version_to_use = sm_version_to_use[0]
204206

205207
error_msg = (
@@ -242,14 +244,14 @@ def _get_file_from_s3(
242244
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
243245
etag = response["ETag"]
244246
return JumpStartCachedS3ContentValue(
245-
formatted_file_content=utils.get_formatted_manifest(formatted_body),
247+
formatted_content=utils.get_formatted_manifest(formatted_body),
246248
md5_hash=etag,
247249
)
248250
if file_type == JumpStartS3FileType.SPECS:
249251
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
250252
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
251253
return JumpStartCachedS3ContentValue(
252-
formatted_file_content=JumpStartModelSpecs(formatted_body)
254+
formatted_content=JumpStartModelSpecs(formatted_body)
253255
)
254256
raise ValueError(
255257
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
@@ -260,7 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
260262

261263
return self._s3_cache.get(
262264
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
263-
).formatted_file_content.values()
265+
).formatted_content.values()
264266

265267
def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
266268
"""Return header for a given JumpStart model id and semantic version.
@@ -295,7 +297,7 @@ def _get_header_impl(
295297
)
296298
manifest = self._s3_cache.get(
297299
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
298-
).formatted_file_content
300+
).formatted_content
299301
try:
300302
return manifest[versioned_model_id]
301303
except KeyError:
@@ -317,7 +319,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
317319
spec_key = header.spec_key
318320
return self._s3_cache.get(
319321
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
320-
).formatted_file_content
322+
).formatted_content
321323

322324
def clear(self) -> None:
323325
"""Clears the model id/version and s3 cache."""

src/sagemaker/jumpstart/types.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ class JumpStartDataHolderType:
2828
__slots__: List[str] = []
2929

3030
def __eq__(self, other: Any) -> bool:
31-
"""Returns True if ``other`` is of the same type and has all attributes equal."""
31+
"""Returns True if ``other`` is of the same type and has all attributes equal.
32+
33+
Args:
34+
other (Any): Other object to which to compare this object.
35+
"""
3236

3337
if not isinstance(other, type(self)):
3438
return False
@@ -83,6 +87,12 @@ class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
8387
__slots__ = ["content_bucket", "region_name"]
8488

8589
def __init__(self, content_bucket: str, region_name: str):
90+
"""Instantiates JumpStartLaunchedRegionInfo object.
91+
92+
Args:
93+
content_bucket (str): Name of JumpStart s3 content bucket associated with region.
94+
region_name (str): Name of JumpStart launched region.
95+
"""
8696
self.content_bucket = content_bucket
8797
self.region_name = region_name
8898

@@ -93,7 +103,11 @@ class JumpStartModelHeader(JumpStartDataHolderType):
93103
__slots__ = ["model_id", "version", "min_version", "spec_key"]
94104

95105
def __init__(self, header: Dict[str, str]):
96-
"""Initializes a JumpStartModelHeader object from its json representation."""
106+
"""Initializes a JumpStartModelHeader object from its json representation.
107+
108+
Args:
109+
header (Dict[str, str]): Dictionary representation of header.
110+
"""
97111
self.from_json(header)
98112

99113
def to_json(self) -> Dict[str, str]:
@@ -102,7 +116,11 @@ def to_json(self) -> Dict[str, str]:
102116
return json_obj
103117

104118
def from_json(self, json_obj: Dict[str, str]) -> None:
105-
"""Sets fields in object based on json of header."""
119+
"""Sets fields in object based on json of header.
120+
121+
Args:
122+
json_obj (Dict[str, str]): Dictionary representation of header.
123+
"""
106124
self.model_id: str = json_obj["model_id"]
107125
self.version: str = json_obj["version"]
108126
self.min_version: str = json_obj["min_version"]
@@ -119,11 +137,19 @@ class JumpStartECRSpecs(JumpStartDataHolderType):
119137
}
120138

121139
def __init__(self, spec: Dict[str, Any]):
122-
"""Initializes a JumpStartECRSpecs object from its json representation."""
140+
"""Initializes a JumpStartECRSpecs object from its json representation.
141+
142+
Args:
143+
spec (Dict[str, Any]): Dictionary representation of spec.
144+
"""
123145
self.from_json(spec)
124146

125147
def from_json(self, json_obj: Dict[str, Any]) -> None:
126-
"""Sets fields in object based on json."""
148+
"""Sets fields in object based on json.
149+
150+
Args:
151+
json_obj (Dict[str, Any]): Dictionary representation of spec.
152+
"""
127153

128154
self.framework = json_obj["framework"]
129155
self.framework_version = json_obj["framework_version"]
@@ -154,11 +180,19 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
154180
]
155181

156182
def __init__(self, spec: Dict[str, Any]):
157-
"""Initializes a JumpStartModelSpecs object from its json representation."""
183+
"""Initializes a JumpStartModelSpecs object from its json representation.
184+
185+
Args:
186+
spec (Dict[str, Any]): Dictionary representation of spec.
187+
"""
158188
self.from_json(spec)
159189

160190
def from_json(self, json_obj: Dict[str, Any]) -> None:
161-
"""Sets fields in object based on json of header."""
191+
"""Sets fields in object based on json of header.
192+
193+
Args:
194+
json_obj (Dict[str, Any]): Dictionary representation of spec.
195+
"""
162196
self.model_id: str = json_obj["model_id"]
163197
self.version: str = json_obj["version"]
164198
self.min_sdk_version: str = json_obj["min_sdk_version"]
@@ -201,6 +235,12 @@ def __init__(
201235
model_id: str,
202236
version: str,
203237
) -> None:
238+
"""Instantiates JumpStartVersionedModelId object.
239+
240+
Args:
241+
model_id (str): JumpStart model id.
242+
version (str): JumpStart model version.
243+
"""
204244
self.model_id = model_id
205245
self.version = version
206246

@@ -215,22 +255,37 @@ def __init__(
215255
file_type: JumpStartS3FileType,
216256
s3_key: str,
217257
) -> None:
258+
"""Instantiates JumpStartCachedS3ContentKey object.
259+
260+
Args:
261+
file_type (JumpStartS3FileType): JumpStart file type.
262+
s3_key (str): object key in s3.
263+
"""
218264
self.file_type = file_type
219265
self.s3_key = s3_key
220266

221267

222268
class JumpStartCachedS3ContentValue(JumpStartDataHolderType):
223269
"""Data class for the s3 cached content values."""
224270

225-
__slots__ = ["formatted_file_content", "md5_hash"]
271+
__slots__ = ["formatted_content", "md5_hash"]
226272

227273
def __init__(
228274
self,
229-
formatted_file_content: Union[
275+
formatted_content: Union[
230276
Dict[JumpStartVersionedModelId, JumpStartModelHeader],
231277
List[JumpStartModelSpecs],
232278
],
233279
md5_hash: Optional[str] = None,
234280
) -> None:
235-
self.formatted_file_content = formatted_file_content
281+
"""Instantiates JumpStartCachedS3ContentValue object.
282+
283+
Args:
284+
formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],
285+
List[JumpStartModelSpecs]]):
286+
Formatted content for model specs and mappings from
287+
versioned model ids to specs.
288+
md5_hash (str): md5_hash for stored file content from s3.
289+
"""
290+
self.formatted_content = formatted_content
236291
self.md5_hash = md5_hash

src/sagemaker/utilities/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
persist before being invalidated.
6262
retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache
6363
keys and current values to new values. This function must have kwarg arguments
64-
``key`` and and ``value``. This function is called as a fallback when the key
64+
``key`` and ``value``. This function is called as a fallback when the key
6565
is not found in the cache, or a key has expired.
6666
6767
"""

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ def patched_get_file_from_s3(
117117
if filetype == JumpStartS3FileType.MANIFEST:
118118

119119
return JumpStartCachedS3ContentValue(
120-
formatted_file_content=get_formatted_manifest(BASE_MANIFEST)
120+
formatted_content=get_formatted_manifest(BASE_MANIFEST)
121121
)
122122

123123
if filetype == JumpStartS3FileType.SPECS:
124124
_, model_id, specs_version = s3_key.split("/")
125125
version = specs_version.replace("specs_v", "").replace(".json", "")
126126
return JumpStartCachedS3ContentValue(
127-
formatted_file_content=get_spec_from_base_spec(model_id, version)
127+
formatted_content=get_spec_from_base_spec(model_id, version)
128128
)
129129

130130
raise ValueError(f"Bad value for filetype: {filetype}")

0 commit comments

Comments
 (0)