@@ -136,33 +136,19 @@ def test_jumpstart_cache_get_header():
136
136
137
137
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
138
138
139
- assert (
140
- JumpStartModelHeader (
141
- {
142
- "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
143
- "version" : "2.0.0" ,
144
- "min_version" : "2.49.0" ,
145
- "spec_key" : "community_models_specs/tensorflow-ic"
146
- "-imagenet-inception-v3-classification-4/specs_v2.0.0.json" ,
147
- }
148
- )
149
- == cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
139
+ assert JumpStartModelHeader (
140
+ {
141
+ "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
142
+ "version" : "2.0.0" ,
143
+ "min_version" : "2.49.0" ,
144
+ "spec_key" : "community_models_specs/tensorflow-ic"
145
+ "-imagenet-inception-v3-classification-4/specs_v2.0.0.json" ,
146
+ }
147
+ ) == cache .get_header (
148
+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
150
149
)
151
150
152
151
# See if we can make the same query 2 times consecutively
153
- assert (
154
- JumpStartModelHeader (
155
- {
156
- "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
157
- "version" : "2.0.0" ,
158
- "min_version" : "2.49.0" ,
159
- "spec_key" : "community_models_specs/tensorflow-ic"
160
- "-imagenet-inception-v3-classification-4/specs_v2.0.0.json" ,
161
- }
162
- )
163
- == cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
164
- )
165
-
166
152
assert JumpStartModelHeader (
167
153
{
168
154
"model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
@@ -278,6 +264,7 @@ def test_jumpstart_cache_get_header():
278
264
with pytest .raises (KeyError ):
279
265
cache .get_header (
280
266
model_id = "tensorflow-ic-imagenet-inception-v3-classification-4-bak" ,
267
+ semantic_version_str = "*" ,
281
268
)
282
269
283
270
@@ -340,21 +327,30 @@ def test_jumpstart_cache_handles_boto3_client_errors():
340
327
stubbed_s3_client .add_client_error ("get_object" , http_status_code = 404 )
341
328
stubbed_s3_client .activate ()
342
329
with pytest .raises (botocore .exceptions .ClientError ):
343
- cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
330
+ cache .get_header (
331
+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
332
+ semantic_version_str = "*" ,
333
+ )
344
334
345
335
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
346
336
stubbed_s3_client = Stubber (cache ._s3_client )
347
337
stubbed_s3_client .add_client_error ("get_object" , service_error_code = "AccessDenied" )
348
338
stubbed_s3_client .activate ()
349
339
with pytest .raises (botocore .exceptions .ClientError ):
350
- cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
340
+ cache .get_header (
341
+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
342
+ semantic_version_str = "*" ,
343
+ )
351
344
352
345
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
353
346
stubbed_s3_client = Stubber (cache ._s3_client )
354
347
stubbed_s3_client .add_client_error ("get_object" , service_error_code = "EndpointConnectionError" )
355
348
stubbed_s3_client .activate ()
356
349
with pytest .raises (botocore .exceptions .ClientError ):
357
- cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
350
+ cache .get_header (
351
+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
352
+ semantic_version_str = "*" ,
353
+ )
358
354
359
355
# Testing head_object:
360
356
mock_now = datetime .datetime .fromtimestamp (1636730651.079551 )
@@ -388,13 +384,18 @@ def test_jumpstart_cache_handles_boto3_client_errors():
388
384
389
385
stubbed_s3_client1 .add_response ("get_object" , copy .deepcopy (get_object_mocked_response ))
390
386
stubbed_s3_client1 .activate ()
391
- cache1 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
387
+ cache1 .get_header (
388
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
389
+ )
392
390
393
391
mock_datetime .now .return_value += datetime .timedelta (weeks = 1 )
394
392
395
393
stubbed_s3_client1 .add_client_error ("head_object" , http_status_code = 404 )
396
394
with pytest .raises (botocore .exceptions .ClientError ):
397
- cache1 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
395
+ cache1 .get_header (
396
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" ,
397
+ semantic_version_str = "*" ,
398
+ )
398
399
399
400
cache2 = JumpStartModelsCache (
400
401
s3_bucket_name = "some_bucket" , s3_cache_expiration_horizon = datetime .timedelta (hours = 1 )
@@ -403,13 +404,18 @@ def test_jumpstart_cache_handles_boto3_client_errors():
403
404
404
405
stubbed_s3_client2 .add_response ("get_object" , copy .deepcopy (get_object_mocked_response ))
405
406
stubbed_s3_client2 .activate ()
406
- cache2 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
407
+ cache2 .get_header (
408
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
409
+ )
407
410
408
411
mock_datetime .now .return_value += datetime .timedelta (weeks = 1 )
409
412
410
413
stubbed_s3_client2 .add_client_error ("head_object" , service_error_code = "AccessDenied" )
411
414
with pytest .raises (botocore .exceptions .ClientError ):
412
- cache2 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
415
+ cache2 .get_header (
416
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" ,
417
+ semantic_version_str = "*" ,
418
+ )
413
419
414
420
cache3 = JumpStartModelsCache (
415
421
s3_bucket_name = "some_bucket" , s3_cache_expiration_horizon = datetime .timedelta (hours = 1 )
@@ -418,15 +424,20 @@ def test_jumpstart_cache_handles_boto3_client_errors():
418
424
419
425
stubbed_s3_client3 .add_response ("get_object" , copy .deepcopy (get_object_mocked_response ))
420
426
stubbed_s3_client3 .activate ()
421
- cache3 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
427
+ cache3 .get_header (
428
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
429
+ )
422
430
423
431
mock_datetime .now .return_value += datetime .timedelta (weeks = 1 )
424
432
425
433
stubbed_s3_client3 .add_client_error (
426
434
"head_object" , service_error_code = "EndpointConnectionError"
427
435
)
428
436
with pytest .raises (botocore .exceptions .ClientError ):
429
- cache3 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
437
+ cache3 .get_header (
438
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" ,
439
+ semantic_version_str = "*" ,
440
+ )
430
441
431
442
432
443
def test_jumpstart_cache_accepts_input_parameters ():
@@ -497,7 +508,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
497
508
}
498
509
mock_boto3_client .return_value .head_object .return_value = {"ETag" : "hash1" }
499
510
500
- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
511
+ cache .get_header (
512
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
513
+ )
501
514
502
515
# first time accessing cache should just involve get_object
503
516
mock_boto3_client .return_value .get_object .assert_called_with (
@@ -520,7 +533,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
520
533
# invalidate cache
521
534
mock_datetime .now .return_value += datetime .timedelta (hours = 2 )
522
535
523
- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
536
+ cache .get_header (
537
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
538
+ )
524
539
525
540
mock_boto3_client .return_value .head_object .assert_called_with (
526
541
Bucket = bucket_name , Key = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -542,7 +557,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
542
557
# invalidate cache
543
558
mock_datetime .now .return_value += datetime .timedelta (hours = 2 )
544
559
545
- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
560
+ cache .get_header (
561
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
562
+ )
546
563
547
564
mock_boto3_client .return_value .get_object .assert_called_with (
548
565
Bucket = bucket_name , Key = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -581,7 +598,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
581
598
cache = JumpStartModelsCache (
582
599
s3_bucket_name = bucket_name , s3_client_config = client_config , region = "my_region"
583
600
)
584
- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
601
+ cache .get_header (
602
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
603
+ )
585
604
586
605
mock_boto3_client .return_value .get_object .assert_called_with (
587
606
Bucket = bucket_name , Key = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -601,7 +620,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
601
620
),
602
621
"ETag" : "etag" ,
603
622
}
604
- cache .get_specs (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
623
+ cache .get_specs (
624
+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
625
+ )
605
626
606
627
mock_boto3_client .return_value .get_object .assert_called_with (
607
628
Bucket = bucket_name ,
@@ -633,7 +654,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
633
654
"imagenet-inception-v3-classification-4/specs_v1.0.0.json" ,
634
655
}
635
656
) == cache .get_header (
636
- model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
657
+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
637
658
)
638
659
cache .clear .assert_called_once ()
639
660
cache .clear .reset_mock ()
@@ -649,6 +670,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
649
670
with pytest .raises (KeyError ):
650
671
cache .get_header (
651
672
model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
673
+ semantic_version_str = "*" ,
652
674
)
653
675
cache .clear .assert_called_once ()
654
676
@@ -683,9 +705,7 @@ def test_jumpstart_cache_get_specs():
683
705
)
684
706
685
707
with pytest .raises (KeyError ):
686
- cache .get_specs (
687
- model_id = model_id + "bak" ,
688
- )
708
+ cache .get_specs (model_id = model_id + "bak" , semantic_version_str = "*" )
689
709
690
710
with pytest .raises (KeyError ):
691
711
cache .get_specs (model_id = model_id , semantic_version_str = "9.*" )
0 commit comments