17
17
get_prototype_manifest ,
18
18
get_prototype_model_spec ,
19
19
)
20
+ from tests .unit .sagemaker .jumpstart .constants import BASE_PROPRIETARY_MANIFEST
20
21
from sagemaker .jumpstart .enums import JumpStartModelType
21
22
from sagemaker .jumpstart .notebook_utils import (
22
23
_generate_jumpstart_model_versions ,
@@ -40,8 +41,8 @@ def test_list_jumpstart_scripts(
40
41
patched_read_s3_file : Mock ,
41
42
):
42
43
patched_get_model_specs .side_effect = get_prototype_model_spec
43
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
44
- region
44
+ patched_get_manifest .side_effect = (
45
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
45
46
)
46
47
patched_generate_jumpstart_models .side_effect = _generate_jumpstart_model_versions
47
48
patched_read_s3_file .side_effect = lambda * args , ** kwargs : json .dumps (
@@ -63,7 +64,9 @@ def test_list_jumpstart_scripts(
63
64
}
64
65
assert list_jumpstart_scripts (** kwargs ) == sorted (["inference" , "training" ])
65
66
patched_generate_jumpstart_models .assert_called_once_with (
66
- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
67
+ ** kwargs ,
68
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
69
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
67
70
)
68
71
assert patched_get_manifest .call_count == 2
69
72
assert patched_get_model_specs .call_count == 1
@@ -76,12 +79,15 @@ def test_list_jumpstart_scripts(
76
79
"filter" : "training_supported is False" ,
77
80
"region" : "sa-east-1" ,
78
81
}
82
+ num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
79
83
assert list_jumpstart_scripts (** kwargs ) == []
80
84
patched_generate_jumpstart_models .assert_called_once_with (
81
- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
85
+ ** kwargs ,
86
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
87
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
82
88
)
83
89
assert patched_get_manifest .call_count == 2
84
- assert patched_read_s3_file .call_count == 2 * len ( PROTOTYPICAL_MODEL_SPECS_DICT )
90
+ assert patched_read_s3_file .call_count == num_specs
85
91
86
92
87
93
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
@@ -93,8 +99,8 @@ def test_list_jumpstart_tasks(
93
99
patched_get_manifest : Mock ,
94
100
):
95
101
patched_get_model_specs .side_effect = get_prototype_model_spec
96
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
97
- region
102
+ patched_get_manifest .side_effect = (
103
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
98
104
)
99
105
patched_generate_jumpstart_models .side_effect = _generate_jumpstart_model_versions
100
106
@@ -122,7 +128,9 @@ def test_list_jumpstart_tasks(
122
128
}
123
129
assert list_jumpstart_tasks (** kwargs ) == ["ic" ]
124
130
patched_generate_jumpstart_models .assert_called_once_with (
125
- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
131
+ ** kwargs ,
132
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
133
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
126
134
)
127
135
assert patched_get_manifest .call_count == 2
128
136
patched_get_model_specs .assert_not_called ()
@@ -137,8 +145,8 @@ def test_list_jumpstart_frameworks(
137
145
patched_get_manifest : Mock ,
138
146
):
139
147
patched_get_model_specs .side_effect = get_prototype_model_spec
140
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
141
- region
148
+ patched_get_manifest .side_effect = (
149
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
142
150
)
143
151
patched_generate_jumpstart_models .side_effect = _generate_jumpstart_model_versions
144
152
@@ -180,7 +188,9 @@ def test_list_jumpstart_frameworks(
180
188
)
181
189
182
190
patched_generate_jumpstart_models .assert_called_once_with (
183
- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
191
+ ** kwargs ,
192
+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
193
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
184
194
)
185
195
assert patched_get_manifest .call_count == 4
186
196
patched_get_model_specs .assert_not_called ()
@@ -229,8 +239,8 @@ def test_list_jumpstart_models_script_filter(
229
239
patched_read_s3_file .side_effect = lambda * args , ** kwargs : json .dumps (
230
240
get_prototype_model_spec (None , "pytorch-eqa-bert-base-cased" ).to_json ()
231
241
)
232
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
233
- region
242
+ patched_get_manifest .side_effect = (
243
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
234
244
)
235
245
236
246
manifest_length = len (get_prototype_manifest ())
@@ -516,8 +526,8 @@ def test_list_jumpstart_models_vulnerable_models(
516
526
patched_get_manifest : Mock ,
517
527
):
518
528
519
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
520
- region
529
+ patched_get_manifest .side_effect = (
530
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
521
531
)
522
532
523
533
def vulnerable_inference_model_spec (bucket , key , * args , ** kwargs ) -> str :
@@ -533,11 +543,12 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
533
543
patched_read_s3_file .side_effect = vulnerable_inference_model_spec
534
544
535
545
num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
546
+ num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
536
547
assert [] == list_jumpstart_models (
537
548
And ("inference_vulnerable is false" , "training_vulnerable is false" )
538
549
)
539
550
540
- assert patched_read_s3_file .call_count == 2 * num_specs
551
+ assert patched_read_s3_file .call_count == num_specs + num_prop_specs
541
552
assert patched_get_manifest .call_count == 2
542
553
543
554
patched_get_manifest .reset_mock ()
@@ -549,7 +560,7 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
549
560
And ("inference_vulnerable is false" , "training_vulnerable is false" )
550
561
)
551
562
552
- assert patched_read_s3_file .call_count == 2 * num_specs
563
+ assert patched_read_s3_file .call_count == num_specs + num_prop_specs
553
564
assert patched_get_manifest .call_count == 2
554
565
555
566
patched_get_manifest .reset_mock ()
@@ -567,8 +578,8 @@ def test_list_jumpstart_models_deprecated_models(
567
578
patched_get_manifest : Mock ,
568
579
):
569
580
570
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
571
- region
581
+ patched_get_manifest .side_effect = (
582
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
572
583
)
573
584
574
585
def deprecated_model_spec (bucket , key , * args , ** kwargs ) -> str :
@@ -579,9 +590,10 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
579
590
patched_read_s3_file .side_effect = deprecated_model_spec
580
591
581
592
num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
593
+ num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
582
594
assert [] == list_jumpstart_models ("deprecated equals false" )
583
595
584
- assert patched_read_s3_file .call_count == 2 * num_specs
596
+ assert patched_read_s3_file .call_count == num_specs + num_prop_specs
585
597
assert patched_get_manifest .call_count == 2
586
598
587
599
patched_get_manifest .reset_mock ()
@@ -666,8 +678,8 @@ def test_list_jumpstart_models_complex_queries(
666
678
patched_read_s3_file .side_effect = lambda * args , ** kwargs : json .dumps (
667
679
get_prototype_model_spec (None , "pytorch-eqa-bert-base-cased" ).to_json ()
668
680
)
669
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
670
- region
681
+ patched_get_manifest .side_effect = (
682
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
671
683
)
672
684
673
685
assert list_jumpstart_models (
@@ -711,8 +723,8 @@ def test_list_jumpstart_models_multiple_level_index(
711
723
patched_get_manifest : Mock ,
712
724
):
713
725
patched_get_model_specs .side_effect = get_prototype_model_spec
714
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
715
- region
726
+ patched_get_manifest .side_effect = (
727
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
716
728
)
717
729
718
730
with pytest .raises (NotImplementedError ):
@@ -730,8 +742,8 @@ def test_get_model_url(
730
742
731
743
patched_get_model_specs .side_effect = get_prototype_model_spec
732
744
patched_validate_model_id_and_get_type .return_value = JumpStartModelType .OPEN_WEIGHTS
733
- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
734
- region
745
+ patched_get_manifest .side_effect = (
746
+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
735
747
)
736
748
737
749
model_id , version = "xgboost-classification-model" , "1.0.0"
0 commit comments