Skip to content

Commit 4b82389

Browse files
committed
fix: linter
1 parent 0e5d5b4 commit 4b82389

File tree

7 files changed

+55
-108
lines changed

7 files changed

+55
-108
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ def generate_file_infos_from_model_specs(
6666
)
6767
files = []
6868
for dependency in HubContentDependencyType:
69-
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(
70-
dependency
71-
)
69+
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
7270
location_type = "prefix" if location.key.endswith("/") else "object"
7371

7472
if location_type == "prefix":
@@ -93,7 +91,5 @@ def generate_file_infos_from_model_specs(
9391
response = s3_client.head_object(**parameters)
9492
size = response.get("ContentLength")
9593
last_updated = response.get("LastModified")
96-
files.append(
97-
FileInfo(location.bucket, location.key, size, last_updated, dependency)
98-
)
94+
files.append(FileInfo(location.bucket, location.key, size, last_updated, dependency))
9995
return files

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def training_script_s3_reference(self):
7373
@property
7474
def default_training_dataset_s3_reference(self):
7575
"""Retrieves s3 reference for s3 directory containing model training datasets"""
76-
return S3ObjectLocation(
77-
self._get_bucket_name(), self.__get_training_dataset_prefix()
78-
)
76+
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())
7977

8078
@property
8179
def demo_notebook_s3_reference(self):

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313
"""This module provides the JumpStart Curated Hub class."""
1414
from __future__ import absolute_import
1515
from concurrent import futures
16-
from functools import lru_cache
1716
from datetime import datetime
1817
import json
1918
import traceback
2019
from typing import Optional, Dict, List, Any
21-
2220
import boto3
2321
from botocore import exceptions
2422
from botocore.client import BaseClient
@@ -102,9 +100,7 @@ def _fetch_hub_bucket_name(self) -> str:
102100
if hub_output_location:
103101
location = create_s3_object_reference_from_uri(hub_output_location)
104102
return location.bucket
105-
default_bucket_name = generate_default_hub_bucket_name(
106-
self._sagemaker_session
107-
)
103+
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
108104
JUMPSTART_LOGGER.warning(
109105
"There is not a Hub bucket associated with %s. Using %s",
110106
self.hub_name,
@@ -124,9 +120,7 @@ def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> N
124120
"""Generates an ``S3ObjectLocation`` given a Hub name."""
125121
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
126122
curr_timestamp = datetime.now().timestamp()
127-
return S3ObjectLocation(
128-
bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}"
129-
)
123+
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
130124

131125
def create(
132126
self,
@@ -180,9 +174,7 @@ def describe_model(
180174
) -> DescribeHubContentsResponse:
181175
"""Returns descriptive information about the Hub Model"""
182176

183-
hub_content_description: Dict[
184-
str, Any
185-
] = self._sagemaker_session.describe_hub_content(
177+
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
186178
hub_name=self.hub_name,
187179
hub_content_name=model_name,
188180
hub_content_version=model_version,
@@ -228,11 +220,7 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str
228220

229221
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]:
230222
hub_models = summary_list_from_list_api_response(self.list_models())
231-
return [
232-
model
233-
for model in hub_models
234-
if get_jumpstart_model_and_version(model) is not None
235-
]
223+
return [model for model in hub_models if get_jumpstart_model_and_version(model) is not None]
236224

237225
def _determine_models_to_sync(
238226
self,
@@ -301,18 +289,12 @@ def sync(self, model_list: List[Dict[str, str]]):
301289
model["model_id"],
302290
model["version"],
303291
)
304-
model_version_list.append(
305-
JumpStartModelInfo(model["model_id"], model["version"])
306-
)
292+
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"]))
307293

308294
js_models_in_hub = self._get_jumpstart_models_in_hub()
309-
mapped_models_in_hub = {
310-
model.hub_content_name: model for model in js_models_in_hub
311-
}
295+
mapped_models_in_hub = {model.hub_content_name: model for model in js_models_in_hub}
312296

313-
models_to_sync = self._determine_models_to_sync(
314-
model_version_list, mapped_models_in_hub
315-
)
297+
models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub)
316298
JUMPSTART_LOGGER.warning(
317299
"Syncing the following models into Hub %s: %s",
318300
self.hub_name,
@@ -328,9 +310,7 @@ def sync(self, model_list: List[Dict[str, str]]):
328310
thread_name_prefix="import-models-to-curated-hub",
329311
) as import_executor:
330312
for thread_num, model in enumerate(models_to_sync):
331-
task = import_executor.submit(
332-
self._sync_public_model_to_hub, model, thread_num
333-
)
313+
task = import_executor.submit(self._sync_public_model_to_hub, model, thread_num)
334314
tasks.append(task)
335315

336316
# Handle failed imports
@@ -343,9 +323,7 @@ def sync(self, model_list: List[Dict[str, str]]):
343323
{
344324
"Exception": exception,
345325
"Traceback": "".join(
346-
traceback.TracebackException.from_exception(
347-
exception
348-
).format()
326+
traceback.TracebackException.from_exception(exception).format()
349327
),
350328
}
351329
)
@@ -389,9 +367,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
389367
label=dest_location.key,
390368
).execute()
391369
else:
392-
JUMPSTART_LOGGER.warning(
393-
"Nothing to copy for %s v%s", model.model_id, model.version
394-
)
370+
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version)
395371

396372
# TODO: Tag model if specs say it is deprecated or training/inference
397373
# vulnerable. Update tag of HubContent ARN without version.
@@ -457,14 +433,10 @@ def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
457433

458434
models_to_scan = model_list if model_list else self.list_models()
459435
js_models_in_hub = [
460-
model
461-
for model in models_to_scan
462-
if get_jumpstart_model_and_version(model) is not None
436+
model for model in models_to_scan if get_jumpstart_model_and_version(model) is not None
463437
]
464438
for model in js_models_in_hub:
465-
tags_to_add: List[
466-
TagsDict
467-
] = find_unsupported_flags_for_hub_content_versions(
439+
tags_to_add: List[TagsDict] = find_unsupported_flags_for_hub_content_versions(
468440
hub_name=self.hub_name,
469441
hub_content_name=model.hub_content_name,
470442
region=self.region,

src/sagemaker/jumpstart/curated_hub/sync/request.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ def __init__(
7272
# Need the file lists to be sorted for comparisons below
7373
self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key)
7474
formatted_dest_files = [self._format_dest_file(file) for file in dest_files]
75-
self.dest_files: List[FileInfo] = sorted(
76-
formatted_dest_files, key=lambda x: x.location.key
77-
)
75+
self.dest_files: List[FileInfo] = sorted(formatted_dest_files, key=lambda x: x.location.key)
7876

7977
def _format_dest_file(self, file: FileInfo) -> FileInfo:
8078
"""Strips HubContent data prefix from dest file name"""
@@ -147,8 +145,6 @@ def _is_same_file_name(self, src_filename: str, dest_filename: str) -> bool:
147145
"""Determines if two files have the same file name."""
148146
return src_filename == dest_filename
149147

150-
def _is_alphabetically_earlier_file_name(
151-
self, src_filename: str, dest_filename: str
152-
) -> bool:
148+
def _is_alphabetically_earlier_file_name(self, src_filename: str, dest_filename: str) -> bool:
153149
"""Determines if one filename is alphabetically earlier than another."""
154150
return src_filename < dest_filename

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
2121
from sagemaker.utils import aws_partition
22-
from typing import Optional, Dict, List, Any, Set
23-
from botocore.exceptions import ClientError
22+
from typing import Optional, Dict, List, Any
2423
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
2524
from sagemaker.jumpstart.curated_hub.types import (
2625
CuratedHubUnsupportedFlag,
@@ -29,13 +28,10 @@
2928
)
3029
from sagemaker.jumpstart import constants
3130
from sagemaker.jumpstart import utils
32-
from sagemaker.session import Session
3331
from sagemaker.jumpstart.enums import JumpStartScriptScope
3432
from sagemaker.jumpstart.curated_hub.constants import (
3533
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX,
3634
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX,
37-
TASK_TAG_PREFIX,
38-
FRAMEWORK_TAG_PREFIX,
3935
)
4036
from sagemaker.utils import format_tags, TagsDict
4137

@@ -95,9 +91,7 @@ def construct_hub_arn_from_name(
9591
return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}"
9692

9793

98-
def construct_hub_model_arn_from_inputs(
99-
hub_arn: str, model_name: str, version: str
100-
) -> str:
94+
def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str:
10195
"""Constructs a HubContent model arn from the Hub name, model name, and model version."""
10296

10397
info = get_info_from_hub_resource_arn(hub_arn)
@@ -127,9 +121,7 @@ def generate_hub_arn_for_init_kwargs(
127121
if match:
128122
hub_arn = hub_name
129123
else:
130-
hub_arn = construct_hub_arn_from_name(
131-
hub_name=hub_name, region=region, session=session
132-
)
124+
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
133125
return hub_arn
134126

135127

@@ -185,21 +177,18 @@ def create_hub_bucket_if_it_does_not_exist(
185177
return bucket_name
186178

187179

188-
def tag_hub_content(
189-
hub_content_arn: str, tags: List[TagsDict], session: Session
190-
) -> None:
180+
def tag_hub_content(hub_content_arn: str, tags: List[TagsDict], session: Session) -> None:
191181
session.add_tags(ResourceArn=hub_content_arn, Tags=tags)
192-
JUMPSTART_LOGGER.info(
193-
f"Added tags to HubContentArn %s: %s", hub_content_arn, TagsDict
194-
)
182+
JUMPSTART_LOGGER.info("Added tags to HubContentArn %s: %s", hub_content_arn, TagsDict)
195183

196184

197185
def find_unsupported_flags_for_hub_content_versions(
198186
hub_name: str, hub_content_name: str, region: str, session: Session
199187
) -> List[TagsDict]:
200188
"""Finds the JumpStart public hub model for a HubContent and calculates relevant tags.
201189
202-
Since tags are the same for all versions of a HubContent, these tags will map from the key to a list of versions impacted.
190+
Since tags are the same for all versions of a HubContent,
191+
these tags will map from the key to a list of versions impacted.
203192
For example, if certain public hub model versions are deprecated,
204193
this utility will return a `deprecated` tag mapped to the deprecated versions for the HubContent.
205194
"""
@@ -242,7 +231,8 @@ def find_unsupported_flags_for_model_version(
242231
"""Finds relevant CuratedHubTags for a version of a JumpStart public hub model.
243232
244233
For example, if the public hub model is deprecated, this utility will return a `deprecated` tag.
245-
Since tags are the same for all versions of a HubContent, these tags will map from the key to a list of versions impacted.
234+
Since tags are the same for all versions of a HubContent,
235+
these tags will map from the key to a list of versions impacted.
246236
"""
247237
flags_to_add: List[CuratedHubUnsupportedFlag] = []
248238
jumpstart_model_specs = utils.verify_model_region_and_return_specs(
@@ -293,14 +283,10 @@ def get_jumpstart_model_and_version(
293283
jumpstart_model_version = jumpstart_model_version_tag[
294284
len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX) :
295285
]
296-
return JumpStartModelInfo(
297-
model_id=jumpstart_model_id, version=jumpstart_model_version
298-
)
286+
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)
299287

300288

301-
def summary_from_list_api_response(
302-
hub_content_summary: Dict[str, Any]
303-
) -> HubContentSummary:
289+
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
304290
return HubContentSummary(
305291
hub_content_arn=hub_content_summary.get("HubContentArn"),
306292
hub_content_name=hub_content_summary.get("HubContentName"),

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation, HubContentSummary
2222
from sagemaker.jumpstart.types import JumpStartModelSpecs
2323
from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC
24-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, HubContentType
24+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
2525

2626

2727
REGION = "us-east-1"
@@ -552,4 +552,4 @@ def test_determine_models_to_sync(sagemaker_session):
552552
}
553553
# Old model_one, same model_two
554554
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)
555-
assert res == [model_one]
555+
assert res == [model_one]

0 commit comments

Comments
 (0)