Skip to content

Commit b32ac3c

Browse files
committed
fix list models flaky tests
1 parent 327638e commit b32ac3c

File tree

2 files changed

+61
-31
lines changed

2 files changed

+61
-31
lines changed

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
178178
)
179179
tasks: Set[str] = set()
180180
for model_id, _ in _generate_jumpstart_model_versions(
181-
filter=filter, region=region, sagemaker_session=sagemaker_session
181+
filter=filter,
182+
region=region,
183+
sagemaker_session=sagemaker_session,
184+
model_type=JumpStartModelType.OPEN_WEIGHTS,
182185
):
183186
_, task, _ = extract_framework_task_model(model_id)
184187
tasks.add(task)
@@ -209,7 +212,10 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
209212
)
210213
frameworks: Set[str] = set()
211214
for model_id, _ in _generate_jumpstart_model_versions(
212-
filter=filter, region=region, sagemaker_session=sagemaker_session
215+
filter=filter,
216+
region=region,
217+
sagemaker_session=sagemaker_session,
218+
model_type=JumpStartModelType.OPEN_WEIGHTS,
213219
):
214220
framework, _, _ = extract_framework_task_model(model_id)
215221
frameworks.add(framework)
@@ -244,7 +250,10 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
244250

245251
scripts: Set[str] = set()
246252
for model_id, version in _generate_jumpstart_model_versions(
247-
filter=filter, region=region, sagemaker_session=sagemaker_session
253+
filter=filter,
254+
region=region,
255+
sagemaker_session=sagemaker_session,
256+
model_type=JumpStartModelType.OPEN_WEIGHTS,
248257
):
249258
scripts.add(JumpStartScriptScope.INFERENCE)
250259
model_specs = verify_model_region_and_return_specs(
@@ -337,6 +346,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
337346
region: Optional[str] = None,
338347
list_incomplete_models: bool = False,
339348
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
349+
model_type: Optional[JumpStartModelType] = None,
340350
) -> Generator:
341351
"""Generate models for JumpStart, and optionally apply filters to result.
342352
@@ -370,12 +380,20 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
370380
s3_client=sagemaker_session.s3_client,
371381
model_type=JumpStartModelType.OPEN_WEIGHTS,
372382
)
373-
models_manifest_list = open_weight_manifest_list + prop_models_manifest_list
383+
models_manifest_list = (
384+
open_weight_manifest_list
385+
if model_type == JumpStartModelType.OPEN_WEIGHTS
386+
else (
387+
prop_models_manifest_list
388+
if model_type == JumpStartModelType.PROPRIETARY
389+
else open_weight_manifest_list + prop_models_manifest_list
390+
)
391+
)
374392

375393
if isinstance(filter, str):
376394
filter = Identity(filter)
377395

378-
manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__)
396+
manifest_keys = set(models_manifest_list[0].__slots__)
379397

380398
all_keys: Set[str] = set()
381399

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_prototype_manifest,
1818
get_prototype_model_spec,
1919
)
20+
from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
2021
from sagemaker.jumpstart.enums import JumpStartModelType
2122
from sagemaker.jumpstart.notebook_utils import (
2223
_generate_jumpstart_model_versions,
@@ -40,8 +41,8 @@ def test_list_jumpstart_scripts(
4041
patched_read_s3_file: Mock,
4142
):
4243
patched_get_model_specs.side_effect = get_prototype_model_spec
43-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
44-
region
44+
patched_get_manifest.side_effect = (
45+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
4546
)
4647
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions
4748
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
@@ -63,7 +64,9 @@ def test_list_jumpstart_scripts(
6364
}
6465
assert list_jumpstart_scripts(**kwargs) == sorted(["inference", "training"])
6566
patched_generate_jumpstart_models.assert_called_once_with(
66-
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
67+
**kwargs,
68+
model_type=JumpStartModelType.OPEN_WEIGHTS,
69+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6770
)
6871
assert patched_get_manifest.call_count == 2
6972
assert patched_get_model_specs.call_count == 1
@@ -76,12 +79,15 @@ def test_list_jumpstart_scripts(
7679
"filter": "training_supported is False",
7780
"region": "sa-east-1",
7881
}
82+
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
7983
assert list_jumpstart_scripts(**kwargs) == []
8084
patched_generate_jumpstart_models.assert_called_once_with(
81-
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
85+
**kwargs,
86+
model_type=JumpStartModelType.OPEN_WEIGHTS,
87+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8288
)
8389
assert patched_get_manifest.call_count == 2
84-
assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT)
90+
assert patched_read_s3_file.call_count == num_specs
8591

8692

8793
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@@ -93,8 +99,8 @@ def test_list_jumpstart_tasks(
9399
patched_get_manifest: Mock,
94100
):
95101
patched_get_model_specs.side_effect = get_prototype_model_spec
96-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
97-
region
102+
patched_get_manifest.side_effect = (
103+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
98104
)
99105
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions
100106

@@ -122,7 +128,9 @@ def test_list_jumpstart_tasks(
122128
}
123129
assert list_jumpstart_tasks(**kwargs) == ["ic"]
124130
patched_generate_jumpstart_models.assert_called_once_with(
125-
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
131+
**kwargs,
132+
model_type=JumpStartModelType.OPEN_WEIGHTS,
133+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126134
)
127135
assert patched_get_manifest.call_count == 2
128136
patched_get_model_specs.assert_not_called()
@@ -137,8 +145,8 @@ def test_list_jumpstart_frameworks(
137145
patched_get_manifest: Mock,
138146
):
139147
patched_get_model_specs.side_effect = get_prototype_model_spec
140-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
141-
region
148+
patched_get_manifest.side_effect = (
149+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
142150
)
143151
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions
144152

@@ -180,7 +188,9 @@ def test_list_jumpstart_frameworks(
180188
)
181189

182190
patched_generate_jumpstart_models.assert_called_once_with(
183-
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
191+
**kwargs,
192+
model_type=JumpStartModelType.OPEN_WEIGHTS,
193+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
184194
)
185195
assert patched_get_manifest.call_count == 4
186196
patched_get_model_specs.assert_not_called()
@@ -229,8 +239,8 @@ def test_list_jumpstart_models_script_filter(
229239
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
230240
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
231241
)
232-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
233-
region
242+
patched_get_manifest.side_effect = (
243+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
234244
)
235245

236246
manifest_length = len(get_prototype_manifest())
@@ -516,8 +526,8 @@ def test_list_jumpstart_models_vulnerable_models(
516526
patched_get_manifest: Mock,
517527
):
518528

519-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
520-
region
529+
patched_get_manifest.side_effect = (
530+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
521531
)
522532

523533
def vulnerable_inference_model_spec(bucket, key, *args, **kwargs) -> str:
@@ -533,11 +543,12 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
533543
patched_read_s3_file.side_effect = vulnerable_inference_model_spec
534544

535545
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
546+
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
536547
assert [] == list_jumpstart_models(
537548
And("inference_vulnerable is false", "training_vulnerable is false")
538549
)
539550

540-
assert patched_read_s3_file.call_count == 2 * num_specs
551+
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
541552
assert patched_get_manifest.call_count == 2
542553

543554
patched_get_manifest.reset_mock()
@@ -549,7 +560,7 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
549560
And("inference_vulnerable is false", "training_vulnerable is false")
550561
)
551562

552-
assert patched_read_s3_file.call_count == 2 * num_specs
563+
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
553564
assert patched_get_manifest.call_count == 2
554565

555566
patched_get_manifest.reset_mock()
@@ -567,8 +578,8 @@ def test_list_jumpstart_models_deprecated_models(
567578
patched_get_manifest: Mock,
568579
):
569580

570-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
571-
region
581+
patched_get_manifest.side_effect = (
582+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
572583
)
573584

574585
def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
@@ -579,9 +590,10 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
579590
patched_read_s3_file.side_effect = deprecated_model_spec
580591

581592
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
593+
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
582594
assert [] == list_jumpstart_models("deprecated equals false")
583595

584-
assert patched_read_s3_file.call_count == 2 * num_specs
596+
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
585597
assert patched_get_manifest.call_count == 2
586598

587599
patched_get_manifest.reset_mock()
@@ -666,8 +678,8 @@ def test_list_jumpstart_models_complex_queries(
666678
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
667679
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
668680
)
669-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
670-
region
681+
patched_get_manifest.side_effect = (
682+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
671683
)
672684

673685
assert list_jumpstart_models(
@@ -711,8 +723,8 @@ def test_list_jumpstart_models_multiple_level_index(
711723
patched_get_manifest: Mock,
712724
):
713725
patched_get_model_specs.side_effect = get_prototype_model_spec
714-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
715-
region
726+
patched_get_manifest.side_effect = (
727+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
716728
)
717729

718730
with pytest.raises(NotImplementedError):
@@ -730,8 +742,8 @@ def test_get_model_url(
730742

731743
patched_get_model_specs.side_effect = get_prototype_model_spec
732744
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
733-
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
734-
region
745+
patched_get_manifest.side_effect = (
746+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
735747
)
736748

737749
model_id, version = "xgboost-classification-model", "1.0.0"

0 commit comments

Comments
 (0)