@@ -389,7 +389,7 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
389
389
return json .loads (response ["Body" ].read ().decode ("utf-8" ))
390
390
391
391
392
- def scan_and_tag_models (self ) -> None :
392
+ def scan_and_tag_models (self , model_list : List [ Dict [ str , str ]] = None ) -> None :
393
393
"""Scans the Hub for JumpStart models and tags the HubContent.
394
394
395
395
If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
@@ -402,11 +402,20 @@ def scan_and_tag_models(self) -> None:
402
402
For example, if model_a version_a is deprecated and inference is vulnerable, the
403
403
HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
404
404
{"inference_vulnerable_versions": [version_a]}]
405
+
406
+ If models are passed in,
405
407
"""
406
408
JUMPSTART_LOGGER .info (
407
409
"Tagging models in hub: %s" , self .hub_name
408
410
)
409
- js_models_in_hub = [model for model in self .list_models () if get_jumpstart_model_and_version (model ) is not None ]
411
+ models_to_scan = model_list if model_list else self .list_models ()
412
+ if self ._is_invalid_model_list_input (model_list ):
413
+ raise ValueError (
414
+ "Model list should be a list of objects with values 'model_id'," ,
415
+ "and optional 'version'." ,
416
+ )
417
+
418
+ js_models_in_hub = [model for model in models_to_scan if get_jumpstart_model_and_version (model ) is not None ]
410
419
tags_added : Dict [str , List [CuratedHubTag ]] = {}
411
420
for model in js_models_in_hub :
412
421
tags_to_add : List [CuratedHubTag ] = find_jumpstart_tags_for_hub_content (
0 commit comments