Skip to content

fix: JumpStart list models flaky tests #4525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/sagemaker/jumpstart/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
)
tasks: Set[str] = set()
for model_id, _ in _generate_jumpstart_model_versions(
filter=filter, region=region, sagemaker_session=sagemaker_session
filter=filter,
region=region,
sagemaker_session=sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
):
_, task, _ = extract_framework_task_model(model_id)
tasks.add(task)
Expand Down Expand Up @@ -209,7 +212,10 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
)
frameworks: Set[str] = set()
for model_id, _ in _generate_jumpstart_model_versions(
filter=filter, region=region, sagemaker_session=sagemaker_session
filter=filter,
region=region,
sagemaker_session=sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
):
framework, _, _ = extract_framework_task_model(model_id)
frameworks.add(framework)
Expand Down Expand Up @@ -244,7 +250,10 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin

scripts: Set[str] = set()
for model_id, version in _generate_jumpstart_model_versions(
filter=filter, region=region, sagemaker_session=sagemaker_session
filter=filter,
region=region,
sagemaker_session=sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
):
scripts.add(JumpStartScriptScope.INFERENCE)
model_specs = verify_model_region_and_return_specs(
Expand Down Expand Up @@ -337,6 +346,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
region: Optional[str] = None,
list_incomplete_models: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: Optional[JumpStartModelType] = None,
) -> Generator:
"""Generate models for JumpStart, and optionally apply filters to result.

Expand Down Expand Up @@ -370,12 +380,22 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
s3_client=sagemaker_session.s3_client,
model_type=JumpStartModelType.OPEN_WEIGHTS,
)
models_manifest_list = open_weight_manifest_list + prop_models_manifest_list
models_manifest_list = (
open_weight_manifest_list
if model_type == JumpStartModelType.OPEN_WEIGHTS
else (
prop_models_manifest_list
if model_type == JumpStartModelType.PROPRIETARY
else open_weight_manifest_list + prop_models_manifest_list
)
)

if isinstance(filter, str):
filter = Identity(filter)

manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__)
manifest_keys = set(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a bit concerning that this does not cause any unit tests to fail

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because these keys are exactly the same. I was thinking to add a hypothetical new key in a test, but then I feel it should be done when we actually change the slots in JumpStartModelHeader. Let me know if you feel strongly to add it now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you're say it happens to work now, cause they're the same. i guess that's fine. what about the spec keys though?

open_weight_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__
)

all_keys: Set[str] = set()

Expand Down
64 changes: 38 additions & 26 deletions tests/unit/sagemaker/jumpstart/test_notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_prototype_manifest,
get_prototype_model_spec,
)
from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.notebook_utils import (
_generate_jumpstart_model_versions,
Expand All @@ -40,8 +41,8 @@ def test_list_jumpstart_scripts(
patched_read_s3_file: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
Expand All @@ -63,7 +64,9 @@ def test_list_jumpstart_scripts(
}
assert list_jumpstart_scripts(**kwargs) == sorted(["inference", "training"])
patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 2
assert patched_get_model_specs.call_count == 1
Expand All @@ -76,12 +79,15 @@ def test_list_jumpstart_scripts(
"filter": "training_supported is False",
"region": "sa-east-1",
}
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
assert list_jumpstart_scripts(**kwargs) == []
patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 2
assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT)
assert patched_read_s3_file.call_count == num_specs


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
Expand All @@ -93,8 +99,8 @@ def test_list_jumpstart_tasks(
patched_get_manifest: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions

Expand Down Expand Up @@ -122,7 +128,9 @@ def test_list_jumpstart_tasks(
}
assert list_jumpstart_tasks(**kwargs) == ["ic"]
patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 2
patched_get_model_specs.assert_not_called()
Expand All @@ -137,8 +145,8 @@ def test_list_jumpstart_frameworks(
patched_get_manifest: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions

Expand Down Expand Up @@ -180,7 +188,9 @@ def test_list_jumpstart_frameworks(
)

patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 4
patched_get_model_specs.assert_not_called()
Expand Down Expand Up @@ -229,8 +239,8 @@ def test_list_jumpstart_models_script_filter(
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
)
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

manifest_length = len(get_prototype_manifest())
Expand Down Expand Up @@ -516,8 +526,8 @@ def test_list_jumpstart_models_vulnerable_models(
patched_get_manifest: Mock,
):

patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

def vulnerable_inference_model_spec(bucket, key, *args, **kwargs) -> str:
Expand All @@ -533,11 +543,12 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
patched_read_s3_file.side_effect = vulnerable_inference_model_spec

num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models(
And("inference_vulnerable is false", "training_vulnerable is false")
)

assert patched_read_s3_file.call_count == 2 * num_specs
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand All @@ -549,7 +560,7 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
And("inference_vulnerable is false", "training_vulnerable is false")
)

assert patched_read_s3_file.call_count == 2 * num_specs
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand All @@ -567,8 +578,8 @@ def test_list_jumpstart_models_deprecated_models(
patched_get_manifest: Mock,
):

patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
Expand All @@ -579,9 +590,10 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
patched_read_s3_file.side_effect = deprecated_model_spec

num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models("deprecated equals false")

assert patched_read_s3_file.call_count == 2 * num_specs
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand Down Expand Up @@ -666,8 +678,8 @@ def test_list_jumpstart_models_complex_queries(
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
)
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

assert list_jumpstart_models(
Expand Down Expand Up @@ -711,8 +723,8 @@ def test_list_jumpstart_models_multiple_level_index(
patched_get_manifest: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

with pytest.raises(NotImplementedError):
Expand All @@ -730,8 +742,8 @@ def test_get_model_url(

patched_get_model_specs.side_effect = get_prototype_model_spec
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

model_id, version = "xgboost-classification-model", "1.0.0"
Expand Down