18
18
from sagemaker .jumpstart .enums import JumpStartScriptScope
19
19
from sagemaker .jumpstart .curated_hub import utils
20
20
from unittest .mock import patch
21
- from sagemaker .jumpstart .curated_hub .types import (
22
- CuratedHubUnsupportedFlag ,
23
- HubContentSummary
24
- )
21
+ from sagemaker .jumpstart .curated_hub .types import CuratedHubUnsupportedFlag , HubContentSummary
25
22
from sagemaker .jumpstart .types import HubContentType
26
23
27
24
@@ -211,10 +208,7 @@ def test_find_tags_for_jumpstart_model_version(mock_spec_util):
211
208
mock_spec_util .return_value = mock_specs
212
209
213
210
tags = utils .find_unsupported_flags_for_model_version (
214
- model_id = "test" ,
215
- version = "test" ,
216
- region = "test" ,
217
- session = mock_sagemaker_session
211
+ model_id = "test" , version = "test" , region = "test" , session = mock_sagemaker_session
218
212
)
219
213
220
214
mock_spec_util .assert_called_once_with (
@@ -230,7 +224,7 @@ def test_find_tags_for_jumpstart_model_version(mock_spec_util):
230
224
assert tags == [
231
225
CuratedHubUnsupportedFlag .DEPRECATED_VERSIONS ,
232
226
CuratedHubUnsupportedFlag .INFERENCE_VULNERABLE_VERSIONS ,
233
- CuratedHubUnsupportedFlag .TRAINING_VULNERABLE_VERSIONS
227
+ CuratedHubUnsupportedFlag .TRAINING_VULNERABLE_VERSIONS ,
234
228
]
235
229
236
230
@@ -244,10 +238,7 @@ def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
244
238
mock_spec_util .return_value = mock_specs
245
239
246
240
tags = utils .find_unsupported_flags_for_model_version (
247
- model_id = "test" ,
248
- version = "test" ,
249
- region = "test" ,
250
- session = mock_sagemaker_session
241
+ model_id = "test" , version = "test" , region = "test" , session = mock_sagemaker_session
251
242
)
252
243
253
244
mock_spec_util .assert_called_once_with (
@@ -273,10 +264,7 @@ def test_find_tags_for_jumpstart_model_version_all_false(mock_spec_util):
273
264
mock_spec_util .return_value = mock_specs
274
265
275
266
tags = utils .find_unsupported_flags_for_model_version (
276
- model_id = "test" ,
277
- version = "test" ,
278
- region = "test" ,
279
- session = mock_sagemaker_session
267
+ model_id = "test" , version = "test" , region = "test" , session = mock_sagemaker_session
280
268
)
281
269
282
270
mock_spec_util .assert_called_once_with (
@@ -302,19 +290,16 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
302
290
"HubContentSearchKeywords" : [
303
291
"@jumpstart-model-id:model-one-pytorch" ,
304
292
"@jumpstart-model-version:1.0.3" ,
305
- ]
293
+ ],
306
294
},
307
295
{
308
296
"HubContentVersion" : "2.0.0" ,
309
297
"HubContentSearchKeywords" : [
310
298
"@jumpstart-model-id:model-four-huggingface" ,
311
299
"@jumpstart-model-version:2.0.2" ,
312
- ]
300
+ ],
313
301
},
314
- {
315
- "HubContentVersion" : "3.0.0" ,
316
- "HubContentSearchKeywords" : []
317
- }
302
+ {"HubContentVersion" : "3.0.0" , "HubContentSearchKeywords" : []},
318
303
]
319
304
}
320
305
@@ -325,22 +310,28 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
325
310
mock_spec_util .return_value = mock_specs
326
311
327
312
tags = utils .find_deprecated_vulnerable_flags_for_hub_content (
328
- hub_name = "test" ,
329
- hub_content_name = "test" ,
330
- region = "test" ,
331
- session = mock_sagemaker_session
313
+ hub_name = "test" , hub_content_name = "test" , region = "test" , session = mock_sagemaker_session
332
314
)
333
315
334
316
mock_sagemaker_session .list_hub_content_versions .assert_called_once_with (
335
317
hub_name = "test" ,
336
- hub_content_type = ' Model' ,
318
+ hub_content_type = " Model" ,
337
319
hub_content_name = "test" ,
338
320
)
339
321
340
322
assert tags == [
341
- {"Key" : CuratedHubUnsupportedFlag .DEPRECATED_VERSIONS .value , "Value" : str (["1.0.0" , "2.0.0" ])},
342
- {"Key" : CuratedHubUnsupportedFlag .INFERENCE_VULNERABLE_VERSIONS .value , "Value" : str (["1.0.0" , "2.0.0" ])},
343
- {"Key" : CuratedHubUnsupportedFlag .TRAINING_VULNERABLE_VERSIONS .value , "Value" : str (["1.0.0" , "2.0.0" ])}
323
+ {
324
+ "Key" : CuratedHubUnsupportedFlag .DEPRECATED_VERSIONS .value ,
325
+ "Value" : str (["1.0.0" , "2.0.0" ]),
326
+ },
327
+ {
328
+ "Key" : CuratedHubUnsupportedFlag .INFERENCE_VULNERABLE_VERSIONS .value ,
329
+ "Value" : str (["1.0.0" , "2.0.0" ]),
330
+ },
331
+ {
332
+ "Key" : CuratedHubUnsupportedFlag .TRAINING_VULNERABLE_VERSIONS .value ,
333
+ "Value" : str (["1.0.0" , "2.0.0" ]),
334
+ },
344
335
]
345
336
346
337
@@ -356,7 +347,7 @@ def test_summary_from_list_api_response(mock_spec_util):
356
347
"HubContentStatus" : "test" ,
357
348
"HubContentDescription" : "test_description" ,
358
349
"HubContentSearchKeywords" : ["test" ],
359
- "CreationTime" : "test_creation"
350
+ "CreationTime" : "test_creation" ,
360
351
}
361
352
)
362
353
@@ -375,32 +366,33 @@ def test_summary_from_list_api_response(mock_spec_util):
375
366
376
367
@patch ("sagemaker.jumpstart.utils.verify_model_region_and_return_specs" )
377
368
def test_summaries_from_list_api_response (mock_spec_util ):
378
- test = utils .summary_list_from_list_api_response ({
379
- "HubContentSummaries" : [
380
- {
381
- "HubContentArn" : "test" ,
382
- "HubContentName" : "test" ,
383
- "HubContentVersion" : "test" ,
384
- "HubContentType" : "Model" ,
385
- "DocumentSchemaVersion" : "test" ,
386
- "HubContentStatus" : "test" ,
387
- "HubContentDescription" : "test" ,
388
- "HubContentSearchKeywords" : ["test" , "test_2" ],
389
- "CreationTime" : "test"
390
- },
391
- {
392
- "HubContentArn" : "test_2" ,
393
- "HubContentName" : "test_2" ,
394
- "HubContentVersion" : "test_2" ,
395
- "HubContentType" : "Model" ,
396
- "DocumentSchemaVersion" : "test_2" ,
397
- "HubContentStatus" : "test_2" ,
398
- "HubContentDescription" : "test_2" ,
399
- "HubContentSearchKeywords" : ["test_2" , "test_2_2" ],
400
- "CreationTime" : "test_2"
401
- }
402
- ]
403
- }
369
+ test = utils .summary_list_from_list_api_response (
370
+ {
371
+ "HubContentSummaries" : [
372
+ {
373
+ "HubContentArn" : "test" ,
374
+ "HubContentName" : "test" ,
375
+ "HubContentVersion" : "test" ,
376
+ "HubContentType" : "Model" ,
377
+ "DocumentSchemaVersion" : "test" ,
378
+ "HubContentStatus" : "test" ,
379
+ "HubContentDescription" : "test" ,
380
+ "HubContentSearchKeywords" : ["test" , "test_2" ],
381
+ "CreationTime" : "test" ,
382
+ },
383
+ {
384
+ "HubContentArn" : "test_2" ,
385
+ "HubContentName" : "test_2" ,
386
+ "HubContentVersion" : "test_2" ,
387
+ "HubContentType" : "Model" ,
388
+ "DocumentSchemaVersion" : "test_2" ,
389
+ "HubContentStatus" : "test_2" ,
390
+ "HubContentDescription" : "test_2" ,
391
+ "HubContentSearchKeywords" : ["test_2" , "test_2_2" ],
392
+ "CreationTime" : "test_2" ,
393
+ },
394
+ ]
395
+ }
404
396
)
405
397
406
398
assert test == [
@@ -425,5 +417,5 @@ def test_summaries_from_list_api_response(mock_spec_util):
425
417
hub_content_status = "test_2" ,
426
418
creation_time = "test_2" ,
427
419
hub_content_search_keywords = ["test_2" , "test_2_2" ],
428
- )
420
+ ),
429
421
]
0 commit comments