@@ -3194,6 +3194,7 @@ def test_batch_get_record(sagemaker_session):
3194
3194
IR_MODEL_PACKAGE_VERSION_ARN = (
3195
3195
"arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
3196
3196
)
3197
+ IR_MODEL_NAME = "MODEL_NAME"
3197
3198
IR_NEAREST_MODEL_NAME = "xgboost"
3198
3199
IR_SUPPORTED_INSTANCE_TYPES = ["ml.c5.xlarge" , "ml.c5.2xlarge" ]
3199
3200
IR_FRAMEWORK = "XGBOOST"
@@ -3243,6 +3244,29 @@ def create_inference_recommendations_job_default_happy_response():
3243
3244
"JobDescription" : "#python-sdk-create" ,
3244
3245
}
3245
3246
3247
+ def create_inference_recommendations_job_default_model_name_happy_response ():
3248
+ return {
3249
+ "JobName" : IR_USER_JOB_NAME ,
3250
+ "JobType" : "Default" ,
3251
+ "RoleArn" : IR_ROLE_ARN ,
3252
+ "InputConfig" : {
3253
+ "ContainerConfig" : {
3254
+ "Domain" : "MACHINE_LEARNING" ,
3255
+ "Task" : "OTHER" ,
3256
+ "Framework" : IR_FRAMEWORK ,
3257
+ "PayloadConfig" : {
3258
+ "SamplePayloadUrl" : IR_SAMPLE_PAYLOAD_URL ,
3259
+ "SupportedContentTypes" : IR_SUPPORTED_CONTENT_TYPES ,
3260
+ },
3261
+ "FrameworkVersion" : IR_FRAMEWORK_VERSION ,
3262
+ "NearestModelName" : IR_NEAREST_MODEL_NAME ,
3263
+ "SupportedInstanceTypes" : IR_SUPPORTED_INSTANCE_TYPES ,
3264
+ },
3265
+ "ModelName" : IR_MODEL_NAME ,
3266
+ },
3267
+ "JobDescription" : "#python-sdk-create" ,
3268
+ }
3269
+
3246
3270
3247
3271
def create_inference_recommendations_job_advanced_happy_response ():
3248
3272
base_advanced_job_response = create_inference_recommendations_job_default_happy_response ()
@@ -3258,6 +3282,20 @@ def create_inference_recommendations_job_advanced_happy_response():
3258
3282
return base_advanced_job_response
3259
3283
3260
3284
3285
+ def create_inference_recommendations_job_advanced_model_name_happy_response ():
3286
+ base_advanced_job_response = create_inference_recommendations_job_default_model_name_happy_response ()
3287
+
3288
+ base_advanced_job_response ["JobName" ] = IR_JOB_NAME
3289
+ base_advanced_job_response ["JobType" ] = IR_ADVANCED_JOB
3290
+ base_advanced_job_response ["StoppingConditions" ] = IR_STOPPING_CONDITIONS
3291
+ base_advanced_job_response ["InputConfig" ]["JobDurationInSeconds" ] = IR_JOB_DURATION_IN_SECONDS
3292
+ base_advanced_job_response ["InputConfig" ]["EndpointConfigurations" ] = IR_ENDPOINT_CONFIGURATIONS
3293
+ base_advanced_job_response ["InputConfig" ]["TrafficPattern" ] = IR_TRAFFIC_PATTERN
3294
+ base_advanced_job_response ["InputConfig" ]["ResourceLimit" ] = IR_RESOURCE_LIMIT
3295
+
3296
+ return base_advanced_job_response
3297
+
3298
+
3261
3299
def test_create_inference_recommendations_job_default_happy (sagemaker_session ):
3262
3300
job_name = sagemaker_session .create_inference_recommendations_job (
3263
3301
role = IR_ROLE_ARN ,
@@ -3304,6 +3342,89 @@ def test_create_inference_recommendations_job_advanced_happy(sagemaker_session):
3304
3342
assert IR_JOB_NAME == job_name
3305
3343
3306
3344
3345
+ def test_create_inference_recommendations_job_default_model_name_happy (sagemaker_session ):
3346
+ job_name = sagemaker_session .create_inference_recommendations_job (
3347
+ role = IR_ROLE_ARN ,
3348
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
3349
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
3350
+ model_name = IR_MODEL_NAME ,
3351
+ model_package_version_arn = None ,
3352
+ framework = IR_FRAMEWORK ,
3353
+ framework_version = IR_FRAMEWORK_VERSION ,
3354
+ nearest_model_name = IR_NEAREST_MODEL_NAME ,
3355
+ supported_instance_types = IR_SUPPORTED_INSTANCE_TYPES ,
3356
+ job_name = IR_USER_JOB_NAME ,
3357
+ )
3358
+
3359
+ sagemaker_session .sagemaker_client .create_inference_recommendations_job .assert_called_with (
3360
+ ** create_inference_recommendations_job_default_model_name_happy_response ()
3361
+ )
3362
+
3363
+ assert IR_USER_JOB_NAME == job_name
3364
+
3365
+ @patch ("uuid.uuid4" , MagicMock (return_value = "sample-unique-uuid" ))
3366
+ def test_create_inference_recommendations_job_advanced_model_name_happy (sagemaker_session ):
3367
+ job_name = sagemaker_session .create_inference_recommendations_job (
3368
+ role = IR_ROLE_ARN ,
3369
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
3370
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
3371
+ model_name = IR_MODEL_NAME ,
3372
+ model_package_version_arn = None ,
3373
+ framework = IR_FRAMEWORK ,
3374
+ framework_version = IR_FRAMEWORK_VERSION ,
3375
+ nearest_model_name = IR_NEAREST_MODEL_NAME ,
3376
+ supported_instance_types = IR_SUPPORTED_INSTANCE_TYPES ,
3377
+ endpoint_configurations = IR_ENDPOINT_CONFIGURATIONS ,
3378
+ traffic_pattern = IR_TRAFFIC_PATTERN ,
3379
+ stopping_conditions = IR_STOPPING_CONDITIONS ,
3380
+ resource_limit = IR_RESOURCE_LIMIT ,
3381
+ job_type = IR_ADVANCED_JOB ,
3382
+ job_duration_in_seconds = IR_JOB_DURATION_IN_SECONDS ,
3383
+ )
3384
+
3385
+ sagemaker_session .sagemaker_client .create_inference_recommendations_job .assert_called_with (
3386
+ ** create_inference_recommendations_job_advanced_model_name_happy_response ()
3387
+ )
3388
+
3389
+ assert IR_JOB_NAME == job_name
3390
+
3391
+ def test_create_inference_recommendations_job_missing_model_name_and_pkg (sagemaker_session ):
3392
+ with pytest .raises (
3393
+ ValueError ,
3394
+ match = "Missing model_name and model_package_version_arn, please provide one of them."
3395
+ ):
3396
+ sagemaker_session .create_inference_recommendations_job (
3397
+ role = IR_ROLE_ARN ,
3398
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
3399
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
3400
+ model_name = None ,
3401
+ model_package_version_arn = None ,
3402
+ framework = IR_FRAMEWORK ,
3403
+ framework_version = IR_FRAMEWORK_VERSION ,
3404
+ nearest_model_name = IR_NEAREST_MODEL_NAME ,
3405
+ supported_instance_types = IR_SUPPORTED_INSTANCE_TYPES ,
3406
+ job_name = IR_USER_JOB_NAME ,
3407
+ )
3408
+
3409
+ def test_create_inference_recommendations_job_provided_model_name_and_pkg (sagemaker_session ):
3410
+ with pytest .raises (
3411
+ ValueError ,
3412
+ match = "Please provide either model_name or model_package_version_arn should be provided, not both."
3413
+ ):
3414
+ sagemaker_session .create_inference_recommendations_job (
3415
+ role = IR_ROLE_ARN ,
3416
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
3417
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
3418
+ model_name = IR_MODEL_NAME ,
3419
+ model_package_version_arn = IR_MODEL_PACKAGE_VERSION_ARN ,
3420
+ framework = IR_FRAMEWORK ,
3421
+ framework_version = IR_FRAMEWORK_VERSION ,
3422
+ nearest_model_name = IR_NEAREST_MODEL_NAME ,
3423
+ supported_instance_types = IR_SUPPORTED_INSTANCE_TYPES ,
3424
+ job_name = IR_USER_JOB_NAME ,
3425
+ )
3426
+
3427
+
3307
3428
def test_create_inference_recommendations_job_propogate_validation_exception (sagemaker_session ):
3308
3429
validation_exception_message = (
3309
3430
"Failed to describe model due to validation failure with following error: test_error"
0 commit comments