@@ -159,7 +159,12 @@ def test_deploy_accelerator_type(
159
159
accelerator_type = ACCELERATOR_TYPE ,
160
160
)
161
161
162
- create_sagemaker_model .assert_called_with (INSTANCE_TYPE , ACCELERATOR_TYPE , None , None )
162
+ create_sagemaker_model .assert_called_with (
163
+ instance_type = INSTANCE_TYPE ,
164
+ accelerator_type = ACCELERATOR_TYPE ,
165
+ tags = None ,
166
+ serverless_inference_config = None ,
167
+ )
163
168
production_variant .assert_called_with (
164
169
MODEL_NAME ,
165
170
INSTANCE_TYPE ,
@@ -271,7 +276,12 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base,
271
276
tags = [{"Key" : "ModelName" , "Value" : "TestModel" }]
272
277
model .deploy (instance_type = INSTANCE_TYPE , initial_instance_count = INSTANCE_COUNT , tags = tags )
273
278
274
- create_sagemaker_model .assert_called_with (INSTANCE_TYPE , None , tags , None )
279
+ create_sagemaker_model .assert_called_with (
280
+ instance_type = INSTANCE_TYPE ,
281
+ accelerator_type = None ,
282
+ tags = tags ,
283
+ serverless_inference_config = None ,
284
+ )
275
285
sagemaker_session .endpoint_from_production_variants .assert_called_with (
276
286
name = ENDPOINT_NAME ,
277
287
production_variants = [BASE_PRODUCTION_VARIANT ],
@@ -463,7 +473,12 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
463
473
serverless_inference_config = serverless_inference_config ,
464
474
)
465
475
466
- create_sagemaker_model .assert_called_with (None , None , None , serverless_inference_config )
476
+ create_sagemaker_model .assert_called_with (
477
+ instance_type = None ,
478
+ accelerator_type = None ,
479
+ tags = None ,
480
+ serverless_inference_config = serverless_inference_config ,
481
+ )
467
482
production_variant .assert_called_with (
468
483
MODEL_NAME ,
469
484
None ,
0 commit comments