15
15
from unittest .mock import Mock
16
16
from sagemaker .jumpstart .types import HubArnExtractedInfo
17
17
from sagemaker .jumpstart .constants import JUMPSTART_DEFAULT_REGION_NAME
18
+ from sagemaker .jumpstart .enums import JumpStartScriptScope
18
19
from sagemaker .jumpstart .curated_hub import utils
20
+ from unittest .mock import patch
21
+ from sagemaker .jumpstart .curated_hub .types import (
22
+ Tag ,
23
+ CuratedHubTagName
24
+ )
19
25
20
26
21
27
def test_get_info_from_hub_resource_arn ():
@@ -168,3 +174,110 @@ def test_create_hub_bucket_if_it_does_not_exist():
168
174
169
175
mock_sagemaker_session .boto_session .resource ("s3" ).create_bucketassert_called_once ()
170
176
assert created_hub_bucket_name == bucket_name
177
+
178
+ @patch ("sagemaker.jumpstart.utils.verify_model_region_and_return_specs" )
179
+ def test_find_tags_for_jumpstart_model_version (mock_spec_util ):
180
+ mock_sagemaker_session = Mock ()
181
+ mock_specs = Mock ()
182
+ mock_specs .deprecated = True
183
+ mock_specs .inference_vulnerable = True
184
+ mock_specs .training_vulnerable = True
185
+ mock_spec_util .return_value = mock_specs
186
+
187
+ tags = utils .find_tags_for_jumpstart_model_version (
188
+ model_id = "test" ,
189
+ version = "test" ,
190
+ region = "test" ,
191
+ session = mock_sagemaker_session
192
+ )
193
+
194
+ mock_spec_util .assert_called_once_with (
195
+ model_id = "test" ,
196
+ version = "test" ,
197
+ region = "test" ,
198
+ scope = JumpStartScriptScope .INFERENCE ,
199
+ tolerate_vulnerable_model = True ,
200
+ tolerate_deprecated_model = True ,
201
+ sagemaker_session = mock_sagemaker_session ,
202
+ )
203
+
204
+ assert tags == [CuratedHubTagName .DEPRECATED_VERSIONS_TAG , CuratedHubTagName .INFERENCE_VULNERABLE_VERSIONS_TAG , CuratedHubTagName .TRAINING_VULNERABLE_VERSIONS_TAG ]
205
+
206
+ @patch ("sagemaker.jumpstart.utils.verify_model_region_and_return_specs" )
207
+ def test_find_tags_for_jumpstart_model_version_some_false (mock_spec_util ):
208
+ mock_sagemaker_session = Mock ()
209
+ mock_specs = Mock ()
210
+ mock_specs .deprecated = True
211
+ mock_specs .inference_vulnerable = False
212
+ mock_specs .training_vulnerable = False
213
+ mock_spec_util .return_value = mock_specs
214
+
215
+ tags = utils .find_tags_for_jumpstart_model_version (
216
+ model_id = "test" ,
217
+ version = "test" ,
218
+ region = "test" ,
219
+ session = mock_sagemaker_session
220
+ )
221
+
222
+ mock_spec_util .assert_called_once_with (
223
+ model_id = "test" ,
224
+ version = "test" ,
225
+ region = "test" ,
226
+ scope = JumpStartScriptScope .INFERENCE ,
227
+ tolerate_vulnerable_model = True ,
228
+ tolerate_deprecated_model = True ,
229
+ sagemaker_session = mock_sagemaker_session ,
230
+ )
231
+
232
+ assert tags == [CuratedHubTagName .DEPRECATED_VERSIONS_TAG ]
233
+
234
+ @patch ("sagemaker.jumpstart.utils.verify_model_region_and_return_specs" )
235
+ def test_find_all_tags_for_jumpstart_model (mock_spec_util ):
236
+ mock_sagemaker_session = Mock ()
237
+ mock_sagemaker_session .list_hub_content_versions .return_value = {
238
+ "HubContentSummaries" : [
239
+ {
240
+ "HubContentVersion" : "1.0.0" ,
241
+ "search_keywords" : [
242
+ "@jumpstart-model-id:model-one-pytorch" ,
243
+ "@jumpstart-model-version:1.0.3" ,
244
+ ]
245
+ },
246
+ {
247
+ "HubContentVersion" : "2.0.0" ,
248
+ "search_keywords" : [
249
+ "@jumpstart-model-id:model-four-huggingface" ,
250
+ "@jumpstart-model-version:2.0.2" ,
251
+ ]
252
+ },
253
+ {
254
+ "HubContentVersion" : "3.0.0" ,
255
+ "search_keywords" : []
256
+ }
257
+ ]
258
+ }
259
+
260
+ mock_specs = Mock ()
261
+ mock_specs .deprecated = True
262
+ mock_specs .inference_vulnerable = True
263
+ mock_specs .training_vulnerable = True
264
+ mock_spec_util .return_value = mock_specs
265
+
266
+ tags = utils .find_all_tags_for_jumpstart_model (
267
+ hub_name = "test" ,
268
+ hub_content_name = "test" ,
269
+ region = "test" ,
270
+ session = mock_sagemaker_session
271
+ )
272
+
273
+ mock_sagemaker_session .list_hub_content_versions .assert_called_once_with (
274
+ hub_name = "test" ,
275
+ hub_content_type = 'Model' ,
276
+ hub_content_name = "test" ,
277
+ )
278
+
279
+ assert tags == [
280
+ Tag (key = CuratedHubTagName .DEPRECATED_VERSIONS_TAG , value = str (["1.0.0" , "2.0.0" ])),
281
+ Tag (key = CuratedHubTagName .INFERENCE_VULNERABLE_VERSIONS_TAG , value = str (["1.0.0" , "2.0.0" ])),
282
+ Tag (key = CuratedHubTagName .TRAINING_VULNERABLE_VERSIONS_TAG , value = str (["1.0.0" , "2.0.0" ]))
283
+ ]
0 commit comments