Skip to content

Commit 0e5d5b4

Browse files
committed
fix: formatter
1 parent 77d7b50 commit 0e5d5b4

File tree

6 files changed

+184
-100
lines changed

6 files changed

+184
-100
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
HubContentDependencyType,
2222
S3ObjectLocation,
2323
)
24-
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
24+
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import (
25+
PublicModelDataAccessor,
26+
)
2527
from sagemaker.jumpstart.types import JumpStartModelSpecs
2628

2729

@@ -64,7 +66,9 @@ def generate_file_infos_from_model_specs(
6466
)
6567
files = []
6668
for dependency in HubContentDependencyType:
67-
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
69+
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(
70+
dependency
71+
)
6872
location_type = "prefix" if location.key.endswith("/") else "object"
6973

7074
if location_type == "prefix":
@@ -89,5 +93,7 @@ def generate_file_infos_from_model_specs(
8993
response = s3_client.head_object(**parameters)
9094
size = response.get("ContentLength")
9195
last_updated = response.get("LastModified")
92-
files.append(FileInfo(location.bucket, location.key, size, last_updated, dependency))
96+
files.append(
97+
FileInfo(location.bucket, location.key, size, last_updated, dependency)
98+
)
9399
return files

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def training_script_s3_reference(self):
7373
@property
7474
def default_training_dataset_s3_reference(self):
7575
"""Retrieves s3 reference for s3 directory containing model training datasets"""
76-
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())
76+
return S3ObjectLocation(
77+
self._get_bucket_name(), self.__get_training_dataset_prefix()
78+
)
7779

7880
@property
7981
def demo_notebook_s3_reference(self):

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 66 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module provides the JumpStart Curated Hub class."""
1414
from __future__ import absolute_import
1515
from concurrent import futures
16-
from functools import lru_cache
16+
from functools import lru_cache
1717
from datetime import datetime
1818
import json
1919
import traceback
@@ -37,7 +37,10 @@
3737
from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequestFactory
3838
from sagemaker.jumpstart.enums import JumpStartScriptScope
3939
from sagemaker.session import Session
40-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
40+
from sagemaker.jumpstart.constants import (
41+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
42+
JUMPSTART_LOGGER,
43+
)
4144
from sagemaker.jumpstart.types import (
4245
DescribeHubResponse,
4346
DescribeHubContentsResponse,
@@ -61,6 +64,7 @@
6164
)
6265
from sagemaker.utils import TagsDict
6366

67+
6468
class CuratedHub:
6569
"""Class for creating and managing a curated JumpStart hub"""
6670

@@ -98,7 +102,9 @@ def _fetch_hub_bucket_name(self) -> str:
98102
if hub_output_location:
99103
location = create_s3_object_reference_from_uri(hub_output_location)
100104
return location.bucket
101-
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
105+
default_bucket_name = generate_default_hub_bucket_name(
106+
self._sagemaker_session
107+
)
102108
JUMPSTART_LOGGER.warning(
103109
"There is not a Hub bucket associated with %s. Using %s",
104110
self.hub_name,
@@ -118,7 +124,9 @@ def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> N
118124
"""Generates an ``S3ObjectLocation`` given a Hub name."""
119125
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
120126
curr_timestamp = datetime.now().timestamp()
121-
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
127+
return S3ObjectLocation(
128+
bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}"
129+
)
122130

123131
def create(
124132
self,
@@ -151,29 +159,30 @@ def describe(self) -> DescribeHubResponse:
151159

152160
return hub_description
153161

154-
155162
def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]:
156163
"""Lists the models in this Curated Hub.
157164
158-
This function caches the models in local memory
165+
This function caches the models in local memory
159166
160167
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
161168
"""
162169
if clear_cache:
163170
self._list_hubs_cache = None
164171
if self._list_hubs_cache is None:
165-
hub_content_summaries = self._sagemaker_session.list_hub_contents(
166-
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
167-
)
168-
self._list_hubs_cache = hub_content_summaries
172+
hub_content_summaries = self._sagemaker_session.list_hub_contents(
173+
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
174+
)
175+
self._list_hubs_cache = hub_content_summaries
169176
return self._list_hubs_cache
170177

171178
def describe_model(
172179
self, model_name: str, model_version: str = "*"
173180
) -> DescribeHubContentsResponse:
174181
"""Returns descriptive information about the Hub Model"""
175182

176-
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
183+
hub_content_description: Dict[
184+
str, Any
185+
] = self._sagemaker_session.describe_hub_content(
177186
hub_name=self.hub_name,
178187
hub_content_name=model_name,
179188
hub_content_version=model_version,
@@ -219,10 +228,16 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str
219228

220229
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]:
221230
hub_models = summary_list_from_list_api_response(self.list_models())
222-
return [model for model in hub_models if get_jumpstart_model_and_version(model) is not None]
231+
return [
232+
model
233+
for model in hub_models
234+
if get_jumpstart_model_and_version(model) is not None
235+
]
223236

224237
def _determine_models_to_sync(
225-
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, HubContentSummary]
238+
self,
239+
model_list: List[JumpStartModelInfo],
240+
models_in_hub: Dict[str, HubContentSummary],
226241
) -> List[JumpStartModelInfo]:
227242
"""Determines which models from `sync` params to sync into the CuratedHub.
228243
@@ -286,14 +301,22 @@ def sync(self, model_list: List[Dict[str, str]]):
286301
model["model_id"],
287302
model["version"],
288303
)
289-
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"]))
304+
model_version_list.append(
305+
JumpStartModelInfo(model["model_id"], model["version"])
306+
)
290307

291308
js_models_in_hub = self._get_jumpstart_models_in_hub()
292-
mapped_models_in_hub = {model.hub_content_name: model for model in js_models_in_hub}
309+
mapped_models_in_hub = {
310+
model.hub_content_name: model for model in js_models_in_hub
311+
}
293312

294-
models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub)
313+
models_to_sync = self._determine_models_to_sync(
314+
model_version_list, mapped_models_in_hub
315+
)
295316
JUMPSTART_LOGGER.warning(
296-
"Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync
317+
"Syncing the following models into Hub %s: %s",
318+
self.hub_name,
319+
models_to_sync,
297320
)
298321

299322
# Delete old models?
@@ -305,7 +328,9 @@ def sync(self, model_list: List[Dict[str, str]]):
305328
thread_name_prefix="import-models-to-curated-hub",
306329
) as import_executor:
307330
for thread_num, model in enumerate(models_to_sync):
308-
task = import_executor.submit(self._sync_public_model_to_hub, model, thread_num)
331+
task = import_executor.submit(
332+
self._sync_public_model_to_hub, model, thread_num
333+
)
309334
tasks.append(task)
310335

311336
# Handle failed imports
@@ -318,7 +343,9 @@ def sync(self, model_list: List[Dict[str, str]]):
318343
{
319344
"Exception": exception,
320345
"Traceback": "".join(
321-
traceback.TracebackException.from_exception(exception).format()
346+
traceback.TracebackException.from_exception(
347+
exception
348+
).format()
322349
),
323350
}
324351
)
@@ -362,7 +389,9 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
362389
label=dest_location.key,
363390
).execute()
364391
else:
365-
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version)
392+
JUMPSTART_LOGGER.warning(
393+
"Nothing to copy for %s v%s", model.model_id, model.version
394+
)
366395

367396
# TODO: Tag model if specs say it is deprecated or training/inference
368397
# vulnerable. Update tag of HubContent ARN without version.
@@ -403,44 +432,47 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
403432
)
404433
return json.loads(response["Body"].read().decode("utf-8"))
405434

406-
407435
def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
408436
"""Scans the Hub for JumpStart models and tags the HubContent.
409-
437+
410438
If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
411439
The tags that will be added are based off the specifications in the JumpStart public hub:
412440
1. "deprecated_versions" -> If the public hub model is deprecated
413441
2. "inference_vulnerable_versions" -> If the public hub model has inference vulnerabilities
414442
3. "training_vulnerable_versions" -> If the public hub model has training vulnerabilities
415443
416-
The tag value will be a list of versions in the Curated Hub that fall under those keys.
444+
The tag value will be a list of versions in the Curated Hub that fall under those keys.
417445
For example, if model_a version_a is deprecated and inference is vulnerable, the
418-
HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
446+
HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
419447
{"inference_vulnerable_versions": [version_a]}]
420448
421-
If models are passed in,
449+
If models are passed in,
422450
"""
423-
JUMPSTART_LOGGER.info(
424-
"Tagging models in hub: %s", self.hub_name
425-
)
451+
JUMPSTART_LOGGER.info("Tagging models in hub: %s", self.hub_name)
426452
if self._is_invalid_model_list_input(model_list):
427453
raise ValueError(
428454
"Model list should be a list of objects with values 'model_id',",
429455
"and optional 'version'.",
430456
)
431-
457+
432458
models_to_scan = model_list if model_list else self.list_models()
433-
js_models_in_hub = [model for model in models_to_scan if get_jumpstart_model_and_version(model) is not None]
459+
js_models_in_hub = [
460+
model
461+
for model in models_to_scan
462+
if get_jumpstart_model_and_version(model) is not None
463+
]
434464
for model in js_models_in_hub:
435-
tags_to_add: List[TagsDict] = find_unsupported_flags_for_hub_content_versions(
465+
tags_to_add: List[
466+
TagsDict
467+
] = find_unsupported_flags_for_hub_content_versions(
436468
hub_name=self.hub_name,
437469
hub_content_name=model.hub_content_name,
438470
region=self.region,
439-
session=self._sagemaker_session
471+
session=self._sagemaker_session,
440472
)
441473
tag_hub_content(
442474
hub_content_arn=model.hub_content_arn,
443475
tags=tags_to_add,
444-
session=self._sagemaker_session
476+
session=self._sagemaker_session,
445477
)
446-
JUMPSTART_LOGGER.info("Tagging complete!")
478+
JUMPSTART_LOGGER.info("Tagging complete!")

src/sagemaker/jumpstart/curated_hub/sync/request.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ class HubSyncRequest:
3131
destination: S3ObjectLocation
3232

3333
def __init__(
34-
self, files_to_copy: Generator[FileInfo, FileInfo, FileInfo], destination: S3ObjectLocation
34+
self,
35+
files_to_copy: Generator[FileInfo, FileInfo, FileInfo],
36+
destination: S3ObjectLocation,
3537
):
3638
"""Contains information required to sync data into a Hub.
3739
@@ -70,7 +72,9 @@ def __init__(
7072
# Need the file lists to be sorted for comparisons below
7173
self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key)
7274
formatted_dest_files = [self._format_dest_file(file) for file in dest_files]
73-
self.dest_files: List[FileInfo] = sorted(formatted_dest_files, key=lambda x: x.location.key)
75+
self.dest_files: List[FileInfo] = sorted(
76+
formatted_dest_files, key=lambda x: x.location.key
77+
)
7478

7579
def _format_dest_file(self, file: FileInfo) -> FileInfo:
7680
"""Strips HubContent data prefix from dest file name"""
@@ -143,6 +147,8 @@ def _is_same_file_name(self, src_filename: str, dest_filename: str) -> bool:
143147
"""Determines if two files have the same file name."""
144148
return src_filename == dest_filename
145149

146-
def _is_alphabetically_earlier_file_name(self, src_filename: str, dest_filename: str) -> bool:
150+
def _is_alphabetically_earlier_file_name(
151+
self, src_filename: str, dest_filename: str
152+
) -> bool:
147153
"""Determines if one filename is alphabetically earlier than another."""
148154
return src_filename < dest_filename

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,36 @@
1717
from dataclasses import dataclass
1818
from datetime import datetime
1919

20-
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs, HubContentType
20+
from sagemaker.jumpstart.types import (
21+
JumpStartDataHolderType,
22+
JumpStartModelSpecs,
23+
HubContentType,
24+
)
25+
2126

2227
class CuratedHubUnsupportedFlag(str, Enum):
2328
"""Enum class for Curated Hub tag names."""
29+
2430
DEPRECATED_VERSIONS = "deprecated_versions"
2531
TRAINING_VULNERABLE_VERSIONS = "training_vulnerable_versions"
2632
INFERENCE_VULNERABLE_VERSIONS = "inference_vulnerable_versions"
2733

34+
2835
@dataclass
2936
class HubContentSummary:
30-
"""Dataclass to store HubContentSummary from List APIs."""
31-
hub_content_arn: str
32-
hub_content_name: str
33-
hub_content_version: str
34-
hub_content_type: HubContentType
35-
document_schema_version: str
36-
hub_content_status: str
37-
creation_time: str
38-
hub_content_display_name: str = None
39-
hub_content_description: str = None
40-
hub_content_search_keywords: List[str] = None
37+
"""Dataclass to store HubContentSummary from List APIs."""
38+
39+
hub_content_arn: str
40+
hub_content_name: str
41+
hub_content_version: str
42+
hub_content_type: HubContentType
43+
document_schema_version: str
44+
hub_content_status: str
45+
creation_time: str
46+
hub_content_display_name: str = None
47+
hub_content_description: str = None
48+
hub_content_search_keywords: List[str] = None
49+
4150

4251
@dataclass
4352
class S3ObjectLocation:
@@ -57,6 +66,7 @@ def get_uri(self) -> str:
5766
"""Returns the s3 URI"""
5867
return f"s3://{self.bucket}/{self.key}"
5968

69+
6070
@dataclass
6171
class JumpStartModelInfo:
6272
"""Helper class for storing JumpStart model info."""

0 commit comments

Comments
 (0)