Skip to content

Commit a73b5b2

Browse files
committed
fix: linters
1 parent 42efbaa commit a73b5b2

File tree

6 files changed

+69
-57
lines changed

6 files changed

+69
-57
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,15 @@
4949
create_hub_bucket_if_it_does_not_exist,
5050
generate_default_hub_bucket_name,
5151
create_s3_object_reference_from_uri,
52-
tag_hub_content,
5352
get_jumpstart_model_and_version,
54-
find_unsupported_flags_for_hub_content_versions,
55-
summary_list_from_list_api_response,
53+
find_deprecated_vulnerable_flags_for_hub_content,
5654
)
5755
from sagemaker.jumpstart.curated_hub.types import (
5856
HubContentDocument_v2,
5957
JumpStartModelInfo,
6058
S3ObjectLocation,
6159
HubContentSummary,
60+
summary_list_from_list_api_response,
6261
)
6362
from sagemaker.utils import TagsDict
6463

@@ -201,6 +200,8 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool
201200
202201
`model_list` objects must have `model_id` (str) and optional `version` (str).
203202
"""
203+
if model_list is None:
204+
return True
204205
for obj in model_list:
205206
if not isinstance(obj.get("model_id"), str):
206207
return True
@@ -409,43 +410,54 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
409410
)
410411
return json.loads(response["Body"].read().decode("utf-8"))
411412

412-
def scan_and_tag_models(self, model_ids: List[str] = None) -> None:
    """Scans the Hub for JumpStart models and tags the HubContent.

    If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
    The tags that will be added are based off the specifications in the JumpStart public hub:
    1. "deprecated_versions" -> If the public hub model is deprecated
    2. "inference_vulnerable_versions" -> If the inference script has vulnerabilities
    3. "training_vulnerable_versions" -> If the training script has vulnerabilities

    The tag value will be a list of versions in the Curated Hub that fall under those keys.
    For example, if model_a version_a is deprecated and inference is vulnerable, the
    HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
    {"inference_vulnerable_versions": [version_a]}]

    If models are passed in, this will only scan those models if they exist in the Curated Hub.

    Args:
        model_ids (List[str]): Optional HubContent names to restrict the scan to. ``None``
            or an empty list scans every model in the hub.

    Raises:
        ValueError: If ``model_ids`` is not a list of strings.
    """
    JUMPSTART_LOGGER.info("Tagging models in hub: %s", self.hub_name)
    model_ids = model_ids if model_ids is not None else []

    # Validate inline: `_is_invalid_model_list_input` expects dict entries and calls
    # `.get("model_id")` on each one, which fails on the plain strings accepted here.
    # A bare string is rejected explicitly because iterating it yields one-character
    # strings that would otherwise pass the per-element check.
    if isinstance(model_ids, str) or not all(
        isinstance(model_id, str) for model_id in model_ids
    ):
        raise ValueError("Model ids should be a list of HubContent names (str).")

    models_in_hub = summary_list_from_list_api_response(self.list_models(clear_cache=False))

    model_summaries_to_scan = models_in_hub
    if model_ids:
        # Build the lookup once so each membership test is O(1).
        requested_names = set(model_ids)
        model_summaries_to_scan = [
            model_summary
            for model_summary in models_in_hub
            if model_summary.hub_content_name in requested_names
        ]

    # Only HubContent that maps back to a JumpStart model and version can be checked.
    js_models_in_hub = [
        model
        for model in model_summaries_to_scan
        if get_jumpstart_model_and_version(model) is not None
    ]
    for model in js_models_in_hub:
        tags_to_add: List[TagsDict] = find_deprecated_vulnerable_flags_for_hub_content(
            hub_name=self.hub_name,
            hub_content_name=model.hub_content_name,
            region=self.region,
            session=self._sagemaker_session,
        )
        self._sagemaker_session.add_tags(ResourceArn=model.hub_content_arn, Tags=tags_to_add)
        JUMPSTART_LOGGER.info(
            "Added tags to HubContentArn %s: %s", model.hub_content_arn, tags_to_add
        )
    JUMPSTART_LOGGER.info("Tagging complete!")

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,40 @@ class HubContentSummary:
4242
hub_content_type: HubContentType
4343
document_schema_version: str
4444
hub_content_status: str
45-
creation_time: str
45+
creation_time: datetime
4646
hub_content_display_name: str = None
4747
hub_content_description: str = None
4848
hub_content_search_keywords: List[str] = None
4949

5050

51+
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
    """Builds one HubContentSummary from a single summary entry of a HubService List API.

    Each dataclass field is read from its PascalCase key in ``hub_content_summary``.
    Keys absent from the entry resolve to ``None`` because they are read with ``.get``.
    """
    # Dataclass field name -> key used in the List API summary entry.
    field_to_api_key = {
        "hub_content_arn": "HubContentArn",
        "hub_content_name": "HubContentName",
        "hub_content_version": "HubContentVersion",
        "hub_content_type": "HubContentType",
        "document_schema_version": "DocumentSchemaVersion",
        "hub_content_status": "HubContentStatus",
        "hub_content_display_name": "HubContentDisplayName",
        "hub_content_description": "HubContentDescription",
        "hub_content_search_keywords": "HubContentSearchKeywords",
        "creation_time": "CreationTime",
    }
    summary_kwargs = {
        field_name: hub_content_summary.get(api_key)
        for field_name, api_key in field_to_api_key.items()
    }
    return HubContentSummary(**summary_kwargs)
65+
66+
67+
def summary_list_from_list_api_response(
    list_hub_contents_response: Dict[str, Any]
) -> List[HubContentSummary]:
    """Converts a ListHubContent or ListHubContentVersions response into HubContentSummary objects.

    The entries are read from the response's ``"HubContentSummaries"`` key, and each one
    is converted with ``summary_from_list_api_response``.
    """
    raw_summaries = list_hub_contents_response["HubContentSummaries"]
    return [summary_from_list_api_response(raw_summary) for raw_summary in raw_summaries]
77+
78+
5179
@dataclass
5280
class S3ObjectLocation:
5381
"""Helper class for S3 object references"""

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import re
1616
from typing import Optional, Dict, List, Any
1717
from sagemaker.jumpstart.curated_hub.types import S3ObjectLocation
18-
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
1918
from sagemaker.s3_utils import parse_s3_url
2019
from sagemaker.session import Session
2120
from sagemaker.utils import aws_partition
@@ -24,6 +23,7 @@
2423
CuratedHubUnsupportedFlag,
2524
HubContentSummary,
2625
JumpStartModelInfo,
26+
summary_list_from_list_api_response,
2727
)
2828
from sagemaker.jumpstart import constants
2929
from sagemaker.jumpstart import utils
@@ -176,12 +176,7 @@ def create_hub_bucket_if_it_does_not_exist(
176176
return bucket_name
177177

178178

179-
def tag_hub_content(hub_content_arn: str, tags: List[TagsDict], session: Session) -> None:
180-
session.add_tags(ResourceArn=hub_content_arn, Tags=tags)
181-
JUMPSTART_LOGGER.info("Added tags to HubContentArn %s: %s", hub_content_arn, TagsDict)
182-
183-
184-
def find_unsupported_flags_for_hub_content_versions(
179+
def find_deprecated_vulnerable_flags_for_hub_content(
185180
hub_name: str, hub_content_name: str, region: str, session: Session
186181
) -> List[TagsDict]:
187182
"""Finds the JumpStart public hub model for a HubContent and calculates relevant tags.
@@ -304,31 +299,3 @@ def get_jumpstart_model_and_version(
304299
len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX) :
305300
]
306301
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)
307-
308-
309-
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
310-
"""Creates a single HubContentSummary from a HubContentSummary from the HubService List APIs."""
311-
return HubContentSummary(
312-
hub_content_arn=hub_content_summary.get("HubContentArn"),
313-
hub_content_name=hub_content_summary.get("HubContentName"),
314-
hub_content_version=hub_content_summary.get("HubContentVersion"),
315-
hub_content_type=hub_content_summary.get("HubContentType"),
316-
document_schema_version=hub_content_summary.get("DocumentSchemaVersion"),
317-
hub_content_status=hub_content_summary.get("HubContentStatus"),
318-
hub_content_display_name=hub_content_summary.get("HubContentDisplayName"),
319-
hub_content_description=hub_content_summary.get("HubContentDescription"),
320-
hub_content_search_keywords=hub_content_summary.get("HubContentSearchKeywords"),
321-
creation_time=hub_content_summary.get("CreationTime"),
322-
)
323-
324-
325-
def summary_list_from_list_api_response(
326-
list_hub_contents_response: Dict[str, Any]
327-
) -> List[HubContentSummary]:
328-
"""Creates a HubContentSummary list from either the ListHubContent or ListHubContentVersions API response."""
329-
return list(
330-
map(
331-
summary_from_list_api_response,
332-
list_hub_contents_response["HubContentSummaries"],
333-
)
334-
)

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ def main(sys_args=None):
6565
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
6666

6767
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
68-
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
69-
client_sagemaker_pysdk_version
70-
)
7168

7269
user = getpass.getuser()
7370
if user != "root":

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
import pytest
1919
from mock import Mock
2020
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
21-
from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation, HubContentSummary
21+
from sagemaker.jumpstart.curated_hub.types import (
22+
JumpStartModelInfo,
23+
S3ObjectLocation,
24+
HubContentSummary
25+
)
2226
from sagemaker.jumpstart.types import JumpStartModelSpecs
2327
from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC
2428
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
@@ -202,7 +206,11 @@ def test_sync_filters_models_that_exist_in_hub(
202206
"@jumpstart-model-version:1.0.2",
203207
],
204208
},
205-
{"HubContentName": "mock-model-three-nonsense", "HubContentVersion": "1.0.2", "HubContentSearchKeywords": []},
209+
{
210+
"HubContentName": "mock-model-three-nonsense",
211+
"HubContentVersion": "1.0.2",
212+
"HubContentSearchKeywords": []
213+
},
206214
{
207215
"HubContentName": "mock-model-four-huggingface",
208216
"HubContentVersion": "2.0.2",

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
324324
mock_specs.training_vulnerable = True
325325
mock_spec_util.return_value = mock_specs
326326

327-
tags = utils.find_unsupported_flags_for_hub_content_versions(
327+
tags = utils.find_deprecated_vulnerable_flags_for_hub_content(
328328
hub_name="test",
329329
hub_content_name="test",
330330
region="test",

0 commit comments

Comments
 (0)