Skip to content

Commit 9dbe3d6

Browse files
committed
fix: Adding list to scan input
1 parent 7cda304 commit 9dbe3d6

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
389389
return json.loads(response["Body"].read().decode("utf-8"))
390390

391391

392-
def scan_and_tag_models(self) -> None:
392+
def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
393393
"""Scans the Hub for JumpStart models and tags the HubContent.
394394
395395
If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
@@ -402,11 +402,20 @@ def scan_and_tag_models(self) -> None:
402402
For example, if model_a version_a is deprecated and inference is vulnerable, the
403403
HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
404404
{"inference_vulnerable_versions": [version_a]}]
405+
406+
If models are passed in,
405407
"""
406408
JUMPSTART_LOGGER.info(
407409
"Tagging models in hub: %s", self.hub_name
408410
)
409-
js_models_in_hub = [model for model in self.list_models() if get_jumpstart_model_and_version(model) is not None]
411+
models_to_scan = model_list if model_list else self.list_models()
412+
if self._is_invalid_model_list_input(model_list):
413+
raise ValueError(
414+
"Model list should be a list of objects with values 'model_id',",
415+
"and optional 'version'.",
416+
)
417+
418+
js_models_in_hub = [model for model in models_to_scan if get_jumpstart_model_and_version(model) is not None]
410419
tags_added: Dict[str, List[CuratedHubTag]] = {}
411420
for model in js_models_in_hub:
412421
tags_to_add: List[CuratedHubTag] = find_jumpstart_tags_for_hub_content(

0 commit comments

Comments
 (0)