Skip to content

Commit a776dc6

Browse files
SSRraymond and Raymond Liu authored
fix: Avoid double encoding to JSON in InferenceRecommenderMixin (#3697)
Co-authored-by: Raymond Liu <[email protected]>
1 parent 5fecbcc commit a776dc6

File tree

2 files changed: +110 additions, -64 deletions

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -463,11 +463,9 @@ def _convert_to_endpoint_configurations_json(
463463
parameter_range.pop("instance_types")
464464

465465
for instance_type in instance_types:
466-
parameter_ranges = []
467-
for name, param in parameter_range.items():
468-
as_json = param.as_json_range(name)
469-
as_json["Value"] = as_json.pop("Values")
470-
parameter_ranges.append(as_json)
466+
parameter_ranges = [
467+
{"Name": name, "Value": param.values} for name, param in parameter_range.items()
468+
]
471469
endpoint_configurations_to_json.append(
472470
{
473471
"EnvironmentParameterRanges": {

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 107 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -141,6 +141,41 @@
141141
"MaxParallelOfTests": 5,
142142
}
143143

144+
IR_SAMPLE_PRIMARY_CONTAINER = {
145+
"Image": "model-image-for-ir",
146+
"Environment": {},
147+
"ModelDataUrl": "s3://bucket/model.tar.gz",
148+
}
149+
150+
IR_PRODUCTION_VARIANTS = [
151+
{
152+
"ModelName": "model-name-for-ir",
153+
"VariantName": "AllTraffic",
154+
"InitialVariantWeight": 1,
155+
"InitialInstanceCount": 1,
156+
"InstanceType": "ml.m5.xlarge",
157+
}
158+
]
159+
160+
IR_OVERRIDDEN_PRODUCTION_VARIANTS = [
161+
{
162+
"ModelName": "model-name-for-ir",
163+
"VariantName": "AllTraffic",
164+
"InitialVariantWeight": 1,
165+
"InitialInstanceCount": 5,
166+
"InstanceType": "ml.c5.2xlarge",
167+
}
168+
]
169+
170+
IR_SERVERLESS_PRODUCTION_VARIANTS = [
171+
{
172+
"ModelName": "model-name-for-ir",
173+
"VariantName": "AllTraffic",
174+
"InitialVariantWeight": 1,
175+
"ServerlessConfig": {"MemorySizeInMB": 2048, "MaxConcurrency": 5},
176+
}
177+
]
178+
144179

145180
@pytest.fixture()
146181
def sagemaker_session():
@@ -185,17 +220,17 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
185220
framework=IR_SAMPLE_FRAMEWORK,
186221
)
187222

188-
assert sagemaker_session.create_model.called_with(
223+
sagemaker_session.create_model.assert_called_with(
189224
name=ANY,
190225
role=IR_ROLE_ARN,
191226
container_defs=None,
192-
primary_container={},
227+
primary_container=IR_SAMPLE_PRIMARY_CONTAINER,
193228
vpc_config=None,
194229
enable_network_isolation=False,
195230
)
196231

197232
# assert that the create api has been called with default parameters with model name
198-
assert sagemaker_session.create_inference_recommendations_job.called_with(
233+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
199234
role=IR_ROLE_ARN,
200235
job_name=IR_JOB_NAME,
201236
job_type="Default",
@@ -213,7 +248,9 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
213248
resource_limit=None,
214249
)
215250

216-
assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
251+
sagemaker_session.wait_for_inference_recommendations_job.assert_called_with(
252+
IR_JOB_NAME, log_level="Verbose"
253+
)
217254

218255
# confirm that the IR instance attributes have been set
219256
assert (
@@ -232,6 +269,7 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
232269
@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
233270
def test_right_size_advanced_list_instances_model_name_successful(sagemaker_session, model):
234271
inference_recommender_model = model.right_size(
272+
job_name=IR_JOB_NAME,
235273
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
236274
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
237275
framework="SAGEMAKER-SCIKIT-LEARN",
@@ -246,7 +284,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
246284
)
247285

248286
# assert that the create api has been called with advanced parameters
249-
assert sagemaker_session.create_inference_recommendations_job.called_with(
287+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
250288
role=IR_ROLE_ARN,
251289
job_name=IR_JOB_NAME,
252290
job_type="Advanced",
@@ -256,15 +294,17 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
256294
framework=IR_SAMPLE_FRAMEWORK,
257295
framework_version=None,
258296
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
259-
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
260-
supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
297+
supported_content_types=["text/csv"],
298+
supported_instance_types=None,
261299
endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
262300
traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
263301
stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
264302
resource_limit=IR_SAMPLE_RESOURCE_LIMIT,
265303
)
266304

267-
assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
305+
sagemaker_session.wait_for_inference_recommendations_job.assert_called_with(
306+
IR_JOB_NAME, log_level="Verbose"
307+
)
268308

269309
# confirm that the IR instance attributes have been set
270310
assert (
@@ -283,6 +323,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
283323
@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
284324
def test_right_size_advanced_single_instances_model_name_successful(sagemaker_session, model):
285325
model.right_size(
326+
job_name=IR_JOB_NAME,
286327
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
287328
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
288329
framework="SAGEMAKER-SCIKIT-LEARN",
@@ -297,7 +338,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
297338
)
298339

299340
# assert that the create api has been called with advanced parameters
300-
assert sagemaker_session.create_inference_recommendations_job.called_with(
341+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
301342
role=IR_ROLE_ARN,
302343
job_name=IR_JOB_NAME,
303344
job_type="Advanced",
@@ -308,7 +349,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
308349
framework_version=None,
309350
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
310351
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
311-
supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
352+
supported_instance_types=None,
312353
endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
313354
traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
314355
stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
@@ -326,7 +367,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
326367
)
327368

328369
# assert that the create api has been called with default parameters
329-
assert sagemaker_session.create_inference_recommendations_job.called_with(
370+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
330371
role=IR_ROLE_ARN,
331372
job_name=IR_JOB_NAME,
332373
job_type="Default",
@@ -344,7 +385,9 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
344385
resource_limit=None,
345386
)
346387

347-
assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
388+
sagemaker_session.wait_for_inference_recommendations_job.assert_called_with(
389+
IR_JOB_NAME, log_level="Verbose"
390+
)
348391

349392
# confirm that the IR instance attributes have been set
350393
assert (
@@ -364,6 +407,7 @@ def test_right_size_advanced_list_instances_model_package_successful(
364407
sagemaker_session, model_package
365408
):
366409
inference_recommender_model_pkg = model_package.right_size(
410+
job_name=IR_JOB_NAME,
367411
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
368412
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
369413
framework="SAGEMAKER-SCIKIT-LEARN",
@@ -378,24 +422,27 @@ def test_right_size_advanced_list_instances_model_package_successful(
378422
)
379423

380424
# assert that the create api has been called with advanced parameters
381-
assert sagemaker_session.create_inference_recommendations_job.called_with(
425+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
382426
role=IR_ROLE_ARN,
383427
job_name=IR_JOB_NAME,
384428
job_type="Advanced",
385429
job_duration_in_seconds=7200,
430+
model_name=None,
386431
model_package_version_arn=model_package.model_package_arn,
387432
framework=IR_SAMPLE_FRAMEWORK,
388433
framework_version=None,
389434
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
390435
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
391-
supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
436+
supported_instance_types=None,
392437
endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
393438
traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
394439
stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
395440
resource_limit=IR_SAMPLE_RESOURCE_LIMIT,
396441
)
397442

398-
assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
443+
sagemaker_session.wait_for_inference_recommendations_job.assert_called_with(
444+
IR_JOB_NAME, log_level="Verbose"
445+
)
399446

400447
# confirm that the IR instance attributes have been set
401448
assert (
@@ -415,6 +462,7 @@ def test_right_size_advanced_single_instances_model_package_successful(
415462
sagemaker_session, model_package
416463
):
417464
model_package.right_size(
465+
job_name=IR_JOB_NAME,
418466
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
419467
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
420468
framework="SAGEMAKER-SCIKIT-LEARN",
@@ -429,17 +477,18 @@ def test_right_size_advanced_single_instances_model_package_successful(
429477
)
430478

431479
# assert that the create api has been called with advanced parameters
432-
assert sagemaker_session.create_inference_recommendations_job.called_with(
480+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
433481
role=IR_ROLE_ARN,
434482
job_name=IR_JOB_NAME,
435483
job_type="Advanced",
436484
job_duration_in_seconds=7200,
485+
model_name=None,
437486
model_package_version_arn=model_package.model_package_arn,
438487
framework=IR_SAMPLE_FRAMEWORK,
439488
framework_version=None,
440489
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
441490
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
442-
supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
491+
supported_instance_types=None,
443492
endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
444493
traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
445494
stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
@@ -451,6 +500,7 @@ def test_right_size_advanced_model_package_partial_params_successful(
451500
sagemaker_session, model_package
452501
):
453502
model_package.right_size(
503+
job_name=IR_JOB_NAME,
454504
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
455505
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
456506
framework="SAGEMAKER-SCIKIT-LEARN",
@@ -463,17 +513,18 @@ def test_right_size_advanced_model_package_partial_params_successful(
463513
)
464514

465515
# assert that the create api has been called with advanced parameters
466-
assert sagemaker_session.create_inference_recommendations_job.called_with(
516+
sagemaker_session.create_inference_recommendations_job.assert_called_with(
467517
role=IR_ROLE_ARN,
468518
job_name=IR_JOB_NAME,
469519
job_type="Advanced",
470520
job_duration_in_seconds=7200,
521+
model_name=None,
471522
model_package_version_arn=model_package.model_package_arn,
472523
framework=IR_SAMPLE_FRAMEWORK,
473524
framework_version=None,
474525
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
475526
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
476-
supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
527+
supported_instance_types=None,
477528
endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
478529
traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
479530
stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
@@ -504,45 +555,42 @@ def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_packa
504555
# TODO check our framework mapping when we add in inference_recommendation_id support
505556

506557

507-
@patch("sagemaker.production_variant")
508-
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
509558
def test_deploy_right_size_with_model_package_succeeds(
510-
production_variant, default_right_sized_model
559+
sagemaker_session, default_right_sized_model
511560
):
561+
562+
default_right_sized_model.name = MODEL_NAME
512563
default_right_sized_model.deploy(endpoint_name=IR_DEPLOY_ENDPOINT_NAME)
513564

514-
assert production_variant.called_with(
515-
model_name=MODEL_NAME,
516-
instance_type=IR_RIGHT_SIZE_INSTANCE_TYPE,
517-
initial_instance_count=IR_RIGHT_SIZE_INITIAL_INSTANCE_COUNT,
518-
accelerator_type=None,
519-
serverless_inference_config=None,
520-
volume_size=None,
521-
model_data_download_timeout=None,
522-
container_startup_health_check_timeout=None,
565+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
566+
async_inference_config_dict=None,
567+
data_capture_config_dict=None,
568+
kms_key=None,
569+
name="ir-endpoint-test",
570+
production_variants=IR_PRODUCTION_VARIANTS,
571+
tags=None,
572+
wait=True,
523573
)
524574

525575

526-
@patch("sagemaker.production_variant")
527-
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
528576
def test_deploy_right_size_with_both_overrides_succeeds(
529-
production_variant, default_right_sized_model
577+
sagemaker_session, default_right_sized_model
530578
):
579+
default_right_sized_model.name = MODEL_NAME
531580
default_right_sized_model.deploy(
532581
instance_type="ml.c5.2xlarge",
533582
initial_instance_count=5,
534583
endpoint_name=IR_DEPLOY_ENDPOINT_NAME,
535584
)
536585

537-
assert production_variant.called_with(
538-
model_name=MODEL_NAME,
539-
instance_type="ml.c5.2xlarge",
540-
initial_instance_count=5,
541-
accelerator_type=None,
542-
serverless_inference_config=None,
543-
volume_size=None,
544-
model_data_download_timeout=None,
545-
container_startup_health_check_timeout=None,
586+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
587+
async_inference_config_dict=None,
588+
data_capture_config_dict=None,
589+
kms_key=None,
590+
name="ir-endpoint-test",
591+
production_variants=IR_OVERRIDDEN_PRODUCTION_VARIANTS,
592+
tags=None,
593+
wait=True,
546594
)
547595

548596

@@ -576,41 +624,41 @@ def test_deploy_right_size_accelerator_type_fails(default_right_sized_model):
576624
default_right_sized_model.deploy(accelerator_type="ml.eia.medium")
577625

578626

579-
@patch("sagemaker.production_variant")
580-
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
581-
def test_deploy_right_size_serverless_override(production_variant, default_right_sized_model):
627+
@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
628+
def test_deploy_right_size_serverless_override(sagemaker_session, default_right_sized_model):
629+
default_right_sized_model.name = MODEL_NAME
582630
serverless_inference_config = ServerlessInferenceConfig()
583631
default_right_sized_model.deploy(serverless_inference_config=serverless_inference_config)
584632

585-
assert production_variant.called_with(
586-
model_name=MODEL_NAME,
587-
instance_type=None,
588-
initial_instance_count=None,
589-
accelerator_type=None,
590-
serverless_inference_config=serverless_inference_config._to_request_dict,
591-
volume_size=None,
592-
model_data_download_timeout=None,
593-
container_startup_health_check_timeout=None,
633+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
634+
name=MODEL_NAME,
635+
production_variants=IR_SERVERLESS_PRODUCTION_VARIANTS,
636+
tags=None,
637+
kms_key=None,
638+
wait=True,
639+
data_capture_config_dict=None,
640+
async_inference_config_dict=None,
594641
)
595642

596643

597-
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
644+
@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
598645
def test_deploy_right_size_async_override(sagemaker_session, default_right_sized_model):
646+
default_right_sized_model.name = MODEL_NAME
599647
async_inference_config = AsyncInferenceConfig(output_path="s3://some-path")
600648
default_right_sized_model.deploy(
601649
instance_type="ml.c5.2xlarge",
602650
initial_instance_count=1,
603651
async_inference_config=async_inference_config,
604652
)
605653

606-
assert sagemaker_session.endpoint_from_production_variants.called_with(
654+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
607655
name=MODEL_NAME,
608656
production_variants=[ANY],
609657
tags=None,
610658
kms_key=None,
611-
wait=None,
659+
wait=True,
612660
data_capture_config_dict=None,
613-
async_inference_config_dict=async_inference_config._to_request_dict,
661+
async_inference_config_dict={"OutputConfig": {"S3OutputPath": "s3://some-path"}},
614662
)
615663

616664

0 commit comments

Comments
 (0)