Skip to content

Commit d21fd41

Browse files
committed
fix: more linting :/
1 parent 0deb20e commit d21fd41

File tree

9 files changed

+40
-73
lines changed

9 files changed

+40
-73
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def _retrieve_default_environment_variables(
133133
)
134134

135135
if gated_model_env_var is None and model_specs.is_gated_model():
136-
137136
possible_env_vars: Set[str] = {
138137
retrieve_gated_env_var_for_instance_type(instance_type)
139138
for instance_type in model_specs.supported_training_instance_types

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def _retrieve_model_package_arn(
8484
)
8585

8686
if scope == JumpStartScriptScope.INFERENCE:
87-
8887
instance_specific_arn: Optional[str] = (
8988
model_specs.hosting_instance_type_variants.get_model_package_arn(
9089
region=region, instance_type=instance_type
@@ -155,7 +154,6 @@ def _retrieve_model_package_model_artifact_s3_uri(
155154
"""
156155

157156
if scope == JumpStartScriptScope.TRAINING:
158-
159157
if region is None:
160158
region = JUMPSTART_DEFAULT_REGION_NAME
161159

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def _retrieve_model_uri(
149149
model_artifact_key: str
150150

151151
if model_scope == JumpStartScriptScope.INFERENCE:
152-
153152
is_prepacked = not model_specs.use_inference_script_uri()
154153

155154
model_artifact_key = (
@@ -159,7 +158,6 @@ def _retrieve_model_uri(
159158
)
160159

161160
elif model_scope == JumpStartScriptScope.TRAINING:
162-
163161
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
164162

165163
default_jumpstart_bucket: str = (

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def _retrieve_default_resources(
134134
}
135135

136136
if is_dynamic_container_deployment_supported:
137-
138137
all_resource_requirement_kwargs = {}
139138

140139
for (

src/sagemaker/jumpstart/cache.py

Lines changed: 38 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3535
MODEL_TYPE_TO_MANIFEST_MAP,
3636
MODEL_TYPE_TO_SPECS_MAP,
37-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3837
)
3938
from sagemaker.jumpstart.exceptions import (
4039
get_wildcard_model_version_msg,
@@ -178,9 +177,7 @@ def set_manifest_file_s3_key(
178177
}
179178
property_name = file_mapping.get(file_type)
180179
if not property_name:
181-
raise ValueError(
182-
self._file_type_error_msg(file_type, manifest_only=True)
183-
)
180+
raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
184181
if key != property_name:
185182
setattr(self, property_name, key)
186183
self.clear()
@@ -193,9 +190,7 @@ def get_manifest_file_s3_key(
193190
return self._manifest_file_s3_key
194191
if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST:
195192
return self._proprietary_manifest_s3_key
196-
raise ValueError(
197-
self._file_type_error_msg(file_type, manifest_only=True)
198-
)
193+
raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
199194

200195
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
201196
"""Set s3 bucket used for cache."""
@@ -248,7 +243,8 @@ def _model_id_retrieval_function(
248243
sm_version = utils.get_sagemaker_version()
249244
manifest = self._content_cache.get(
250245
JumpStartCachedContentKey(
251-
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
246+
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
247+
)
252248
)[0].formatted_content
253249

254250
versions_compatible_with_sagemaker = [
@@ -265,7 +261,8 @@ def _model_id_retrieval_function(
265261
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
266262

267263
versions_incompatible_with_sagemaker = [
268-
Version(header.version) for header in manifest.values() # type: ignore
264+
Version(header.version)
265+
for header in manifest.values() # type: ignore
269266
if header.model_id == model_id
270267
]
271268
sm_incompatible_model_version = self._select_version(
@@ -295,9 +292,7 @@ def _model_id_retrieval_function(
295292
raise KeyError(error_msg)
296293

297294
error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
298-
error_msg += (
299-
f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
300-
)
295+
error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
301296

302297
other_model_id_version = None
303298
if model_type == JumpStartModelType.OPEN_WEIGHTS:
@@ -306,19 +301,17 @@ def _model_id_retrieval_function(
306301
) # all versions here are incompatible with sagemaker
307302
elif model_type == JumpStartModelType.PROPRIETARY:
308303
all_possible_model_id_version = [
309-
header.version for header in manifest.values() # type: ignore
304+
header.version
305+
for header in manifest.values() # type: ignore
310306
if header.model_id == model_id
311307
]
312308
other_model_id_version = (
313-
None
314-
if not all_possible_model_id_version
315-
else all_possible_model_id_version[0]
309+
None if not all_possible_model_id_version else all_possible_model_id_version[0]
316310
)
317311

318312
if other_model_id_version is not None:
319313
error_msg += (
320-
f"Consider using model ID '{model_id}' with version "
321-
f"'{other_model_id_version}'."
314+
f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'."
322315
)
323316
else:
324317
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
@@ -360,15 +353,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],
360353

361354
def _is_local_metadata_mode(self) -> bool:
362355
"""Returns True if the cache should use local metadata mode, based off env variables."""
363-
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
364-
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
365-
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
366-
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))
356+
return (
357+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
358+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
359+
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
360+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])
361+
)
367362

368363
def _get_json_file(
369-
self,
370-
key: str,
371-
filetype: JumpStartS3FileType
364+
self, key: str, filetype: JumpStartS3FileType
372365
) -> Tuple[Union[dict, list], Optional[str]]:
373366
"""Returns json file either from s3 or local file system.
374367
@@ -392,21 +385,19 @@ def _get_json_md5_hash(self, key: str):
392385
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]
393386

394387
def _get_json_file_from_local_override(
395-
self,
396-
key: str,
397-
filetype: JumpStartS3FileType
388+
self, key: str, filetype: JumpStartS3FileType
398389
) -> Union[dict, list]:
399390
"""Reads json file from local filesystem and returns data."""
400391
if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
401-
metadata_local_root = (
402-
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
403-
)
392+
metadata_local_root = os.environ[
393+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
394+
]
404395
elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
405396
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
406397
else:
407398
raise ValueError(f"Unsupported file type for local override: {filetype}")
408399
file_path = os.path.join(metadata_local_root, key)
409-
with open(file_path, 'r') as f:
400+
with open(file_path, "r") as f:
410401
data = json.load(f)
411402
return data
412403

@@ -443,17 +434,15 @@ def _retrieval_function(
443434
formatted_content=utils.get_formatted_manifest(formatted_body),
444435
md5_hash=etag,
445436
)
446-
437+
447438
if data_type in {
448439
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
449440
JumpStartS3FileType.PROPRIETARY_SPECS,
450441
}:
451442
formatted_body, _ = self._get_json_file(id_info, data_type)
452443
model_specs = JumpStartModelSpecs(formatted_body)
453444
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
454-
return JumpStartCachedContentValue(
455-
formatted_content=model_specs
456-
)
445+
return JumpStartCachedContentValue(formatted_content=model_specs)
457446

458447
if data_type == HubContentType.MODEL:
459448
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
@@ -463,21 +452,15 @@ def _retrieval_function(
463452
hub_name=hub_name,
464453
hub_content_name=model_name,
465454
hub_content_version=model_version,
466-
hub_content_type=data_type
455+
hub_content_type=data_type,
467456
)
468457

469458
model_specs = JumpStartModelSpecs(
470459
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
471460
)
472461

473-
utils.emit_logs_based_on_model_specs(
474-
model_specs,
475-
self.get_region(),
476-
self._s3_client
477-
)
478-
return JumpStartCachedContentValue(
479-
formatted_content=model_specs
480-
)
462+
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
463+
return JumpStartCachedContentValue(formatted_content=model_specs)
481464

482465
if data_type == HubType.HUB:
483466
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
@@ -487,9 +470,7 @@ def _retrieval_function(
487470
formatted_content=DescribeHubResponse(hub_description)
488471
)
489472

490-
raise ValueError(
491-
self._file_type_error_msg(data_type)
492-
)
473+
raise ValueError(self._file_type_error_msg(data_type))
493474

494475
def get_manifest(
495476
self,
@@ -498,7 +479,8 @@ def get_manifest(
498479
"""Return entire JumpStart models manifest."""
499480
manifest_dict = self._content_cache.get(
500481
JumpStartCachedContentKey(
501-
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
482+
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
483+
)
502484
)[0].formatted_content
503485
manifest = list(manifest_dict.values()) # type: ignore
504486
return manifest
@@ -555,16 +537,14 @@ def _select_version(
555537
except InvalidSpecifier:
556538
raise KeyError(f"Bad semantic version: {version_str}")
557539
available_versions_filtered = list(spec.filter(available_versions))
558-
return (
559-
str(max(available_versions_filtered)) if available_versions_filtered != [] else None
560-
)
540+
return str(max(available_versions_filtered)) if available_versions_filtered != [] else None
561541

562542
def _get_header_impl(
563543
self,
564544
model_id: str,
565545
semantic_version_str: str,
566546
attempt: int = 0,
567-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS
547+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
568548
) -> JumpStartModelHeader:
569549
"""Lower-level function to return header.
570550
@@ -587,7 +567,8 @@ def _get_header_impl(
587567

588568
manifest = self._content_cache.get(
589569
JumpStartCachedContentKey(
590-
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
570+
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
571+
)
591572
)[0].formatted_content
592573

593574
try:
@@ -603,7 +584,7 @@ def get_specs(
603584
self,
604585
model_id: str,
605586
version_str: str,
606-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS
587+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
607588
) -> JumpStartModelSpecs:
608589
"""Return specs for a given JumpStart model ID and semantic version.
609590
@@ -616,16 +597,12 @@ def get_specs(
616597
header = self.get_header(model_id, version_str, model_type)
617598
spec_key = header.spec_key
618599
specs, cache_hit = self._content_cache.get(
619-
JumpStartCachedContentKey(
620-
MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key
621-
)
600+
JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
622601
)
623602

624603
if not cache_hit and "*" in version_str:
625604
JUMPSTART_LOGGER.warning(
626-
get_wildcard_model_version_msg(
627-
header.model_id, version_str, header.version
628-
)
605+
get_wildcard_model_version_msg(header.model_id, version_str, header.version)
629606
)
630607
return specs.formatted_content
631608

src/sagemaker/jumpstart/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,6 @@ def attach(
734734
"""
735735

736736
if model_id is None:
737-
738737
model_id, model_version = get_model_id_version_from_training_job(
739738
training_job_name=training_job_name, sagemaker_session=sagemaker_session
740739
)

src/sagemaker/jumpstart/factory/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,6 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
330330
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
331331
sagemaker_session=kwargs.sagemaker_session,
332332
):
333-
334333
entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME
335334

336335
kwargs.entry_point = entry_point

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
379379
is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys
380380

381381
def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]:
382-
383382
copied_filter = copy.deepcopy(filter)
384383

385384
manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,6 @@ def resolve_estimator_sagemaker_config_field(
743743
# JumpStart Estimators have certain default field values. We want
744744
# sagemaker config values to take priority over the model-specific defaults.
745745
if field_name == "enable_network_isolation":
746-
747746
resolved_val = resolve_value_from_config(
748747
direct_input=None,
749748
config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
@@ -754,7 +753,6 @@ def resolve_estimator_sagemaker_config_field(
754753
return resolved_val if resolved_val is not None else field_val
755754

756755
if field_name == "encrypt_inter_container_traffic":
757-
758756
resolved_val = resolve_value_from_config(
759757
direct_input=None,
760758
config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
@@ -868,12 +866,13 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
868866
"""Returns the Studio Spec file prefix given a model ID and version."""
869867
return f"studio_models/{model_id}/studio_specs_v{model_version}.json"
870868

869+
871870
def extract_info_from_hub_content_arn(
872871
arn: str,
873872
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
874873
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""
875874

876-
match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
875+
match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
877876
if match:
878877
hub_name = match.group(4)
879878
hub_region = match.group(2)

0 commit comments

Comments
 (0)