141
141
"MaxParallelOfTests" : 5 ,
142
142
}
143
143
144
# Expected ``primary_container`` argument of ``create_model`` in the
# right-size tests.
IR_SAMPLE_PRIMARY_CONTAINER = {
    "Image": "model-image-for-ir",
    "Environment": {},
    "ModelDataUrl": "s3://bucket/model.tar.gz",
}

# Expected ``production_variants`` payload when ``deploy()`` is called without
# instance overrides.
IR_PRODUCTION_VARIANTS = [
    {
        "ModelName": "model-name-for-ir",
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
        "InitialInstanceCount": 1,
        "InstanceType": "ml.m5.xlarge",
    }
]

# Expected payload when the caller overrides both the instance type and the
# initial instance count.
IR_OVERRIDDEN_PRODUCTION_VARIANTS = [
    {
        "ModelName": "model-name-for-ir",
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
        "InitialInstanceCount": 5,
        "InstanceType": "ml.c5.2xlarge",
    }
]

# Expected payload for a serverless deployment: a ``ServerlessConfig`` entry
# takes the place of the instance type / instance count keys.
IR_SERVERLESS_PRODUCTION_VARIANTS = [
    {
        "ModelName": "model-name-for-ir",
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
        "ServerlessConfig": {"MemorySizeInMB": 2048, "MaxConcurrency": 5},
    }
]
178
+
144
179
145
180
@pytest .fixture ()
146
181
def sagemaker_session ():
@@ -185,17 +220,17 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
185
220
framework = IR_SAMPLE_FRAMEWORK ,
186
221
)
187
222
188
- assert sagemaker_session .create_model .called_with (
223
+ sagemaker_session .create_model .assert_called_with (
189
224
name = ANY ,
190
225
role = IR_ROLE_ARN ,
191
226
container_defs = None ,
192
- primary_container = {} ,
227
+ primary_container = IR_SAMPLE_PRIMARY_CONTAINER ,
193
228
vpc_config = None ,
194
229
enable_network_isolation = False ,
195
230
)
196
231
197
232
# assert that the create api has been called with default parameters with model name
198
- assert sagemaker_session .create_inference_recommendations_job .called_with (
233
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
199
234
role = IR_ROLE_ARN ,
200
235
job_name = IR_JOB_NAME ,
201
236
job_type = "Default" ,
@@ -213,7 +248,9 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
213
248
resource_limit = None ,
214
249
)
215
250
216
- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
251
+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
252
+ IR_JOB_NAME , log_level = "Verbose"
253
+ )
217
254
218
255
# confirm that the IR instance attributes have been set
219
256
assert (
@@ -232,6 +269,7 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
232
269
@patch ("uuid.uuid4" , MagicMock (return_value = "sample-unique-uuid" ))
233
270
def test_right_size_advanced_list_instances_model_name_successful (sagemaker_session , model ):
234
271
inference_recommender_model = model .right_size (
272
+ job_name = IR_JOB_NAME ,
235
273
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
236
274
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
237
275
framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -246,7 +284,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
246
284
)
247
285
248
286
# assert that the create api has been called with advanced parameters
249
- assert sagemaker_session .create_inference_recommendations_job .called_with (
287
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
250
288
role = IR_ROLE_ARN ,
251
289
job_name = IR_JOB_NAME ,
252
290
job_type = "Advanced" ,
@@ -256,15 +294,17 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
256
294
framework = IR_SAMPLE_FRAMEWORK ,
257
295
framework_version = None ,
258
296
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
259
- supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
260
- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
297
+ supported_content_types = [ "text/csv" ] ,
298
+ supported_instance_types = None ,
261
299
endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
262
300
traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
263
301
stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
264
302
resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
265
303
)
266
304
267
- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
305
+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
306
+ IR_JOB_NAME , log_level = "Verbose"
307
+ )
268
308
269
309
# confirm that the IR instance attributes have been set
270
310
assert (
@@ -283,6 +323,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
283
323
@patch ("uuid.uuid4" , MagicMock (return_value = "sample-unique-uuid" ))
284
324
def test_right_size_advanced_single_instances_model_name_successful (sagemaker_session , model ):
285
325
model .right_size (
326
+ job_name = IR_JOB_NAME ,
286
327
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
287
328
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
288
329
framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -297,7 +338,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
297
338
)
298
339
299
340
# assert that the create api has been called with advanced parameters
300
- assert sagemaker_session .create_inference_recommendations_job .called_with (
341
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
301
342
role = IR_ROLE_ARN ,
302
343
job_name = IR_JOB_NAME ,
303
344
job_type = "Advanced" ,
@@ -308,7 +349,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
308
349
framework_version = None ,
309
350
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
310
351
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
311
- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
352
+ supported_instance_types = None ,
312
353
endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
313
354
traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
314
355
stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
@@ -326,7 +367,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
326
367
)
327
368
328
369
# assert that the create api has been called with default parameters
329
- assert sagemaker_session .create_inference_recommendations_job .called_with (
370
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
330
371
role = IR_ROLE_ARN ,
331
372
job_name = IR_JOB_NAME ,
332
373
job_type = "Default" ,
@@ -344,7 +385,9 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
344
385
resource_limit = None ,
345
386
)
346
387
347
- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
388
+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
389
+ IR_JOB_NAME , log_level = "Verbose"
390
+ )
348
391
349
392
# confirm that the IR instance attributes have been set
350
393
assert (
@@ -364,6 +407,7 @@ def test_right_size_advanced_list_instances_model_package_successful(
364
407
sagemaker_session , model_package
365
408
):
366
409
inference_recommender_model_pkg = model_package .right_size (
410
+ job_name = IR_JOB_NAME ,
367
411
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
368
412
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
369
413
framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -378,24 +422,27 @@ def test_right_size_advanced_list_instances_model_package_successful(
378
422
)
379
423
380
424
# assert that the create api has been called with advanced parameters
381
- assert sagemaker_session .create_inference_recommendations_job .called_with (
425
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
382
426
role = IR_ROLE_ARN ,
383
427
job_name = IR_JOB_NAME ,
384
428
job_type = "Advanced" ,
385
429
job_duration_in_seconds = 7200 ,
430
+ model_name = None ,
386
431
model_package_version_arn = model_package .model_package_arn ,
387
432
framework = IR_SAMPLE_FRAMEWORK ,
388
433
framework_version = None ,
389
434
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
390
435
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
391
- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
436
+ supported_instance_types = None ,
392
437
endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
393
438
traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
394
439
stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
395
440
resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
396
441
)
397
442
398
- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
443
+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
444
+ IR_JOB_NAME , log_level = "Verbose"
445
+ )
399
446
400
447
# confirm that the IR instance attributes have been set
401
448
assert (
@@ -415,6 +462,7 @@ def test_right_size_advanced_single_instances_model_package_successful(
415
462
sagemaker_session , model_package
416
463
):
417
464
model_package .right_size (
465
+ job_name = IR_JOB_NAME ,
418
466
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
419
467
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
420
468
framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -429,17 +477,18 @@ def test_right_size_advanced_single_instances_model_package_successful(
429
477
)
430
478
431
479
# assert that the create api has been called with advanced parameters
432
- assert sagemaker_session .create_inference_recommendations_job .called_with (
480
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
433
481
role = IR_ROLE_ARN ,
434
482
job_name = IR_JOB_NAME ,
435
483
job_type = "Advanced" ,
436
484
job_duration_in_seconds = 7200 ,
485
+ model_name = None ,
437
486
model_package_version_arn = model_package .model_package_arn ,
438
487
framework = IR_SAMPLE_FRAMEWORK ,
439
488
framework_version = None ,
440
489
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
441
490
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
442
- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
491
+ supported_instance_types = None ,
443
492
endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
444
493
traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
445
494
stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
@@ -451,6 +500,7 @@ def test_right_size_advanced_model_package_partial_params_successful(
451
500
sagemaker_session , model_package
452
501
):
453
502
model_package .right_size (
503
+ job_name = IR_JOB_NAME ,
454
504
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
455
505
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
456
506
framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -463,17 +513,18 @@ def test_right_size_advanced_model_package_partial_params_successful(
463
513
)
464
514
465
515
# assert that the create api has been called with advanced parameters
466
- assert sagemaker_session .create_inference_recommendations_job .called_with (
516
+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
467
517
role = IR_ROLE_ARN ,
468
518
job_name = IR_JOB_NAME ,
469
519
job_type = "Advanced" ,
470
520
job_duration_in_seconds = 7200 ,
521
+ model_name = None ,
471
522
model_package_version_arn = model_package .model_package_arn ,
472
523
framework = IR_SAMPLE_FRAMEWORK ,
473
524
framework_version = None ,
474
525
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
475
526
supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
476
- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
527
+ supported_instance_types = None ,
477
528
endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
478
529
traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
479
530
stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
@@ -504,45 +555,42 @@ def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_packa
504
555
# TODO: check our framework mapping once inference_recommendation_id support is added
505
556
506
557
507
def test_deploy_right_size_with_model_package_succeeds(
    sagemaker_session, default_right_sized_model
):
    """Deploy the right-sized model with only an endpoint name.

    Checks that ``endpoint_from_production_variants`` is called with
    ``IR_PRODUCTION_VARIANTS`` and the endpoint name that was passed in.
    """
    # Pin the model name; presumably this keeps "ModelName" in the expected
    # production variant deterministic -- confirm against the fixture.
    default_right_sized_model.name = MODEL_NAME
    default_right_sized_model.deploy(endpoint_name=IR_DEPLOY_ENDPOINT_NAME)

    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        async_inference_config_dict=None,
        data_capture_config_dict=None,
        kms_key=None,
        # Compare against the constant handed to deploy() rather than a
        # duplicated string literal, so the two cannot drift apart.
        name=IR_DEPLOY_ENDPOINT_NAME,
        production_variants=IR_PRODUCTION_VARIANTS,
        tags=None,
        wait=True,
    )
524
574
525
575
526
def test_deploy_right_size_with_both_overrides_succeeds(
    sagemaker_session, default_right_sized_model
):
    """Deploy the right-sized model with instance type and count overridden.

    Checks that ``endpoint_from_production_variants`` is called with
    ``IR_OVERRIDDEN_PRODUCTION_VARIANTS``, whose instance type and count match
    the values passed to ``deploy()`` here.
    """
    # Pin the model name; presumably this keeps "ModelName" in the expected
    # production variant deterministic -- confirm against the fixture.
    default_right_sized_model.name = MODEL_NAME
    default_right_sized_model.deploy(
        instance_type="ml.c5.2xlarge",
        initial_instance_count=5,
        endpoint_name=IR_DEPLOY_ENDPOINT_NAME,
    )

    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        async_inference_config_dict=None,
        data_capture_config_dict=None,
        kms_key=None,
        # Compare against the constant handed to deploy() rather than a
        # duplicated string literal, so the two cannot drift apart.
        name=IR_DEPLOY_ENDPOINT_NAME,
        production_variants=IR_OVERRIDDEN_PRODUCTION_VARIANTS,
        tags=None,
        wait=True,
    )
547
595
548
596
@@ -576,41 +624,41 @@ def test_deploy_right_size_accelerator_type_fails(default_right_sized_model):
576
624
default_right_sized_model .deploy (accelerator_type = "ml.eia.medium" )
577
625
578
626
579
@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
def test_deploy_right_size_serverless_override(sagemaker_session, default_right_sized_model):
    """Deploy the right-sized model with a serverless inference config.

    Checks that ``endpoint_from_production_variants`` is called with
    ``IR_SERVERLESS_PRODUCTION_VARIANTS``.
    """
    default_right_sized_model.name = MODEL_NAME
    serverless_inference_config = ServerlessInferenceConfig()
    default_right_sized_model.deploy(serverless_inference_config=serverless_inference_config)

    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        # No endpoint_name is passed to deploy(); the expected name is the
        # value the patched name_from_base returns (presumably deploy() uses
        # it to generate the endpoint name -- confirm against Model.deploy).
        name=MODEL_NAME,
        production_variants=IR_SERVERLESS_PRODUCTION_VARIANTS,
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config_dict=None,
        async_inference_config_dict=None,
    )
595
642
596
643
597
@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
def test_deploy_right_size_async_override(sagemaker_session, default_right_sized_model):
    """Deploy the right-sized model with an async inference config.

    Checks that ``endpoint_from_production_variants`` receives the async
    config as a request dict carrying the configured S3 output path.
    """
    default_right_sized_model.name = MODEL_NAME
    async_inference_config = AsyncInferenceConfig(output_path="s3://some-path")
    default_right_sized_model.deploy(
        instance_type="ml.c5.2xlarge",
        initial_instance_count=1,
        async_inference_config=async_inference_config,
    )

    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        # No endpoint_name is passed to deploy(); the expected name is the
        # value the patched name_from_base returns (presumably deploy() uses
        # it to generate the endpoint name -- confirm against Model.deploy).
        name=MODEL_NAME,
        # The variant contents are not under test here, only that exactly one
        # variant is passed.
        production_variants=[ANY],
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config_dict=None,
        async_inference_config_dict={"OutputConfig": {"S3OutputPath": "s3://some-path"}},
    )
615
663
616
664
0 commit comments