Skip to content

Commit 0bcf25e

Browse files
authored
fix: improve notebook utils logging and add model-specific info messages
2 parents 3528e48 + b1789fb commit 0bcf25e

File tree

9 files changed: +205 additions, -111 deletions

src/sagemaker/jumpstart/constants.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from typing import Dict, Set, Type
1718
import boto3
1819
from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer
@@ -33,6 +34,8 @@
3334
from sagemaker.session import Session
3435

3536

37+
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING"
38+
3639
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set(
3740
[
3841
JumpStartLaunchedRegionInfo(
@@ -209,6 +212,19 @@
209212

210213
JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart")
211214

215+
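# Attach a stream handler that drops every record while the DISABLE_JUMPSTART_LOGGING env var is set to a non-empty value (checked on each emit)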
216+
JUMPSTART_LOGGER.addHandler(
217+
type(
218+
"",
219+
(logging.StreamHandler,),
220+
{
221+
"emit": lambda self, *args, **kwargs: logging.StreamHandler.emit(self, *args, **kwargs)
222+
if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING)
223+
else None
224+
},
225+
)()
226+
)
227+
212228
try:
213229
DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session(
214230
boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)

src/sagemaker/jumpstart/filters.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ class SpecialSupportedFilterKeys(str, Enum):
4545

4646
TASK = "task"
4747
FRAMEWORK = "framework"
48-
SUPPORTED_MODEL = "supported_model"
4948

5049

5150
FILTER_OPERATOR_STRING_MAPPINGS = {
@@ -74,7 +73,6 @@ class SpecialSupportedFilterKeys(str, Enum):
7473
[
7574
SpecialSupportedFilterKeys.TASK,
7675
SpecialSupportedFilterKeys.FRAMEWORK,
77-
SpecialSupportedFilterKeys.SUPPORTED_MODEL,
7876
]
7977
)
8078

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 127 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
import copy
1616

1717
from functools import cmp_to_key
18+
import os
1819
from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
1920
from packaging.version import Version
2021
from sagemaker.jumpstart import accessors
21-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
22+
from sagemaker.jumpstart.constants import (
23+
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING,
24+
JUMPSTART_DEFAULT_REGION_NAME,
25+
)
2226
from sagemaker.jumpstart.enums import JumpStartScriptScope
2327
from sagemaker.jumpstart.filters import (
2428
SPECIAL_SUPPORTED_FILTER_KEYS,
@@ -281,126 +285,160 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
281285
results. (Default: False).
282286
"""
283287

284-
if isinstance(filter, str):
285-
filter = Identity(filter)
288+
class _ModelSearchContext:
289+
"""Context manager for conducting model searches."""
286290

287-
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
288-
manifest_keys = set(models_manifest_list[0].__slots__)
291+
def __init__(self):
292+
"""Initialize context manager."""
289293

290-
all_keys: Set[str] = set()
294+
self.old_disable_js_logging_env_var_value = os.environ.get(
295+
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING
296+
)
291297

292-
model_filters: Set[ModelFilter] = set()
298+
def __enter__(self, *args, **kwargs):
299+
"""Enter context.
293300
294-
for operator in _model_filter_in_operator_generator(filter):
295-
model_filter = operator.unresolved_value
296-
key = model_filter.key
297-
all_keys.add(key)
298-
model_filters.add(model_filter)
301+
JumpStart logs get disabled to avoid excessive logging.
302+
"""
299303

300-
for key in all_keys:
301-
if "." in key:
302-
raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').")
304+
os.environ[ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING] = "true"
303305

304-
metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS
306+
def __exit__(self, *args, **kwargs):
307+
"""Exit context.
305308
306-
required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
307-
possible_spec_keys = metadata_filter_keys - manifest_keys
309+
Restore the previous JumpStart logging setting, and reset the cache so
310+
that new logs appear for models that were previously searched.
311+
"""
308312

309-
unrecognized_keys: Set[str] = set()
313+
if self.old_disable_js_logging_env_var_value:
314+
os.environ[
315+
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING
316+
] = self.old_disable_js_logging_env_var_value
317+
else:
318+
os.environ.pop(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, None)
319+
accessors.JumpStartModelsAccessor.reset_cache()
310320

311-
is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
312-
is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
313-
is_supported_model_filter = SpecialSupportedFilterKeys.SUPPORTED_MODEL in all_keys
321+
with _ModelSearchContext():
314322

315-
for model_manifest in models_manifest_list:
323+
if isinstance(filter, str):
324+
filter = Identity(filter)
316325

317-
copied_filter = copy.deepcopy(filter)
326+
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
327+
manifest_keys = set(models_manifest_list[0].__slots__)
318328

319-
manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}
329+
all_keys: Set[str] = set()
320330

321-
model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}
331+
model_filters: Set[ModelFilter] = set()
322332

323-
for val in required_manifest_keys:
324-
manifest_specs_cached_values[val] = getattr(model_manifest, val)
333+
for operator in _model_filter_in_operator_generator(filter):
334+
model_filter = operator.unresolved_value
335+
key = model_filter.key
336+
all_keys.add(key)
337+
model_filters.add(model_filter)
325338

326-
if is_task_filter:
327-
manifest_specs_cached_values[
328-
SpecialSupportedFilterKeys.TASK
329-
] = extract_framework_task_model(model_manifest.model_id)[1]
339+
for key in all_keys:
340+
if "." in key:
341+
raise NotImplementedError(
342+
f"No support for multiple level metadata indexing ('{key}')."
343+
)
330344

331-
if is_framework_filter:
332-
manifest_specs_cached_values[
333-
SpecialSupportedFilterKeys.FRAMEWORK
334-
] = extract_framework_task_model(model_manifest.model_id)[0]
345+
metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS
335346

336-
if is_supported_model_filter:
337-
manifest_specs_cached_values[SpecialSupportedFilterKeys.SUPPORTED_MODEL] = Version(
338-
model_manifest.min_version
339-
) <= Version(get_sagemaker_version())
347+
required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
348+
possible_spec_keys = metadata_filter_keys - manifest_keys
340349

341-
_populate_model_filters_to_resolved_values(
342-
manifest_specs_cached_values,
343-
model_filters_to_resolved_values,
344-
model_filters,
345-
)
350+
unrecognized_keys: Set[str] = set()
346351

347-
_put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)
352+
is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
353+
is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
348354

349-
copied_filter.eval()
355+
for model_manifest in models_manifest_list:
350356

351-
if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
352-
if copied_filter.resolved_value == BooleanValues.TRUE:
353-
yield (model_manifest.model_id, model_manifest.version)
354-
continue
357+
copied_filter = copy.deepcopy(filter)
355358

356-
if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
357-
raise RuntimeError(
358-
"Filter expression in unevaluated state after using values from model manifest. "
359-
"Model ID and version that is failing: "
360-
f"{(model_manifest.model_id, model_manifest.version)}."
359+
manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}
360+
361+
model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}
362+
363+
for val in required_manifest_keys:
364+
manifest_specs_cached_values[val] = getattr(model_manifest, val)
365+
366+
if is_task_filter:
367+
manifest_specs_cached_values[
368+
SpecialSupportedFilterKeys.TASK
369+
] = extract_framework_task_model(model_manifest.model_id)[1]
370+
371+
if is_framework_filter:
372+
manifest_specs_cached_values[
373+
SpecialSupportedFilterKeys.FRAMEWORK
374+
] = extract_framework_task_model(model_manifest.model_id)[0]
375+
376+
if Version(model_manifest.min_version) > Version(get_sagemaker_version()):
377+
continue
378+
379+
_populate_model_filters_to_resolved_values(
380+
manifest_specs_cached_values,
381+
model_filters_to_resolved_values,
382+
model_filters,
361383
)
362-
copied_filter_2 = copy.deepcopy(filter)
363384

364-
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
365-
region=region,
366-
model_id=model_manifest.model_id,
367-
version=model_manifest.version,
368-
)
385+
_put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)
369386

370-
model_specs_keys = set(model_specs.__slots__)
387+
copied_filter.eval()
371388

372-
unrecognized_keys -= model_specs_keys
373-
unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys
374-
unrecognized_keys.update(unrecognized_keys_for_single_spec)
389+
if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
390+
if copied_filter.resolved_value == BooleanValues.TRUE:
391+
yield (model_manifest.model_id, model_manifest.version)
392+
continue
375393

376-
for val in possible_spec_keys:
377-
if hasattr(model_specs, val):
378-
manifest_specs_cached_values[val] = getattr(model_specs, val)
394+
if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
395+
raise RuntimeError(
396+
"Filter expression in unevaluated state after using "
397+
"values from model manifest. Model ID and version that "
398+
f"is failing: {(model_manifest.model_id, model_manifest.version)}."
399+
)
400+
copied_filter_2 = copy.deepcopy(filter)
379401

380-
_populate_model_filters_to_resolved_values(
381-
manifest_specs_cached_values,
382-
model_filters_to_resolved_values,
383-
model_filters,
384-
)
385-
_put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)
402+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
403+
region=region,
404+
model_id=model_manifest.model_id,
405+
version=model_manifest.version,
406+
)
386407

387-
copied_filter_2.eval()
408+
model_specs_keys = set(model_specs.__slots__)
388409

389-
if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED:
390-
if copied_filter_2.resolved_value == BooleanValues.TRUE or (
391-
BooleanValues.UNKNOWN and list_incomplete_models
392-
):
393-
yield (model_manifest.model_id, model_manifest.version)
394-
continue
410+
unrecognized_keys -= model_specs_keys
411+
unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys
412+
unrecognized_keys.update(unrecognized_keys_for_single_spec)
395413

396-
raise RuntimeError(
397-
"Filter expression in unevaluated state after using values from model specs. "
398-
"Model ID and version that is failing: "
399-
f"{(model_manifest.model_id, model_manifest.version)}."
400-
)
414+
for val in possible_spec_keys:
415+
if hasattr(model_specs, val):
416+
manifest_specs_cached_values[val] = getattr(model_specs, val)
417+
418+
_populate_model_filters_to_resolved_values(
419+
manifest_specs_cached_values,
420+
model_filters_to_resolved_values,
421+
model_filters,
422+
)
423+
_put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)
424+
425+
copied_filter_2.eval()
426+
427+
if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED:
428+
if copied_filter_2.resolved_value == BooleanValues.TRUE or (
429+
BooleanValues.UNKNOWN and list_incomplete_models
430+
):
431+
yield (model_manifest.model_id, model_manifest.version)
432+
continue
433+
434+
raise RuntimeError(
435+
"Filter expression in unevaluated state after using values from model specs. "
436+
"Model ID and version that is failing: "
437+
f"{(model_manifest.model_id, model_manifest.version)}."
438+
)
401439

402-
if len(unrecognized_keys) > 0:
403-
raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}")
440+
if len(unrecognized_keys) > 0:
441+
raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}")
404442

405443

406444
def get_model_url(

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
730730
"training_dependencies",
731731
"training_vulnerabilities",
732732
"deprecated",
733+
"info_message",
733734
"deprecated_message",
734735
"deprecate_warn_message",
735736
"default_inference_instance_type",
@@ -801,6 +802,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
801802
self.deprecated: bool = bool(json_obj["deprecated"])
802803
self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
803804
self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
805+
self.info_message: Optional[str] = json_obj.get("info_message")
804806
self.default_inference_instance_type: Optional[str] = json_obj.get(
805807
"default_inference_instance_type"
806808
)

src/sagemaker/jumpstart/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,9 @@ def emit_logs_based_on_model_specs(
511511
if model_specs.deprecate_warn_message:
512512
constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message)
513513

514+
if model_specs.info_message:
515+
constants.JUMPSTART_LOGGER.info(model_specs.info_message)
516+
514517
if model_specs.inference_vulnerable or model_specs.training_vulnerable:
515518
constants.JUMPSTART_LOGGER.warning(
516519
"Using vulnerable JumpStart model '%s' and version '%s'.",

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6221,6 +6221,7 @@
62216221
"ml.c5.2xlarge",
62226222
],
62236223
"hosting_use_script_uri": True,
6224+
"info_message": None,
62246225
"metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
62256226
"model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"},
62266227
"deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"},

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -383,34 +383,27 @@ def test_list_jumpstart_models_region(
383383
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
384384
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
385385
@patch("sagemaker.jumpstart.notebook_utils.get_sagemaker_version")
386-
def test_list_jumpstart_models_unsupported_models(
386+
@patch("sagemaker.jumpstart.notebook_utils.accessors.JumpStartModelsAccessor.reset_cache")
387+
@patch.dict("os.environ", {})
388+
@patch("logging.StreamHandler.emit")
389+
@patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False)
390+
def test_list_jumpstart_models_disables_logging_resets_cache(
387391
self,
392+
patched_emit: Mock,
393+
patched_reset_cache: Mock,
388394
patched_get_sagemaker_version: Mock,
389395
patched_get_model_specs: Mock,
390396
patched_get_manifest: Mock,
391397
):
392398
patched_get_model_specs.side_effect = get_prototype_model_spec
393399
patched_get_manifest.side_effect = get_prototype_manifest
394400

395-
patched_get_sagemaker_version.return_value = "0.0.0"
401+
patched_get_sagemaker_version.return_value = "3.0.0"
396402

397-
assert [] == list_jumpstart_models("supported_model == True")
398-
patched_get_model_specs.assert_not_called()
399-
assert [] == list_jumpstart_models(
400-
And("supported_model == True", "training_supported in [False, True]")
401-
)
402-
patched_get_model_specs.assert_not_called()
403-
404-
assert [] != list_jumpstart_models("supported_model == False")
405-
406-
patched_get_sagemaker_version.return_value = "999999.0.0"
407-
408-
assert [] != list_jumpstart_models("supported_model == True")
409-
410-
patched_get_model_specs.reset_mock()
403+
list_jumpstart_models("deprecate_warn_message is blah")
411404

412-
assert [] != list_jumpstart_models("training_supported in [False, True]")
413-
patched_get_model_specs.assert_called()
405+
patched_emit.assert_not_called()
406+
patched_reset_cache.assert_called_once()
414407

415408
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
416409
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")

0 commit comments

Comments
 (0)