@@ -175,6 +175,134 @@ def default_right_sized_model(model_package):
175
175
)
176
176
177
177
178
+ def test_right_size_default_with_model_name_successful (sagemaker_session , model ):
179
+ inference_recommender_model = model .right_size (
180
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
181
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
182
+ supported_instance_types = [IR_SAMPLE_INSTANCE_TYPE ],
183
+ job_name = IR_JOB_NAME ,
184
+ framework = IR_SAMPLE_FRAMEWORK ,
185
+ )
186
+
187
+ # assert that the create api has been called with default parameters with model name
188
+ assert sagemaker_session .create_inference_recommendations_job .called_with (
189
+ role = IR_ROLE_ARN ,
190
+ job_name = IR_JOB_NAME ,
191
+ job_type = "Default" ,
192
+ job_duration_in_seconds = None ,
193
+ model_name = ANY ,
194
+ model_package_version_arn = None ,
195
+ framework = IR_SAMPLE_FRAMEWORK ,
196
+ framework_version = None ,
197
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
198
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
199
+ supported_instance_types = [IR_SAMPLE_INSTANCE_TYPE ],
200
+ endpoint_configurations = None ,
201
+ traffic_pattern = None ,
202
+ stopping_conditions = None ,
203
+ resource_limit = None ,
204
+ )
205
+
206
+ assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
207
+
208
+ # confirm that the IR instance attributes have been set
209
+ assert (
210
+ inference_recommender_model .inference_recommender_job_results
211
+ == IR_SAMPLE_INFERENCE_RESPONSE
212
+ )
213
+ assert (
214
+ inference_recommender_model .inference_recommendations
215
+ == IR_SAMPLE_INFERENCE_RESPONSE ["InferenceRecommendations" ]
216
+ )
217
+
218
+ # confirm that the returned object of right_size is itself
219
+ assert inference_recommender_model == model
220
+
221
+ def test_right_size_advanced_list_instances_model_name_successful (sagemaker_session , model ):
222
+ inference_recommender_model = model .right_size (
223
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
224
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
225
+ framework = "SAGEMAKER-SCIKIT-LEARN" ,
226
+ job_duration_in_seconds = 7200 ,
227
+ hyperparameter_ranges = IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES ,
228
+ phases = IR_SAMPLE_PHASES ,
229
+ traffic_type = "PHASES" ,
230
+ max_invocations = 100 ,
231
+ model_latency_thresholds = IR_SAMPLE_MODEL_LATENCY_THRESHOLDS ,
232
+ max_tests = 5 ,
233
+ max_parallel_tests = 5 ,
234
+ )
235
+
236
+ # assert that the create api has been called with advanced parameters
237
+ assert sagemaker_session .create_inference_recommendations_job .called_with (
238
+ role = IR_ROLE_ARN ,
239
+ job_name = IR_JOB_NAME ,
240
+ job_type = "Advanced" ,
241
+ job_duration_in_seconds = 7200 ,
242
+ model_name = ANY ,
243
+ model_package_version_arn = None ,
244
+ framework = IR_SAMPLE_FRAMEWORK ,
245
+ framework_version = None ,
246
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
247
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
248
+ supported_instance_types = [IR_SAMPLE_INSTANCE_TYPE ],
249
+ endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
250
+ traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
251
+ stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
252
+ resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
253
+ )
254
+
255
+ assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
256
+
257
+ # confirm that the IR instance attributes have been set
258
+ assert (
259
+ inference_recommender_model .inference_recommender_job_results
260
+ == IR_SAMPLE_INFERENCE_RESPONSE
261
+ )
262
+ assert (
263
+ inference_recommender_model .inference_recommendations
264
+ == IR_SAMPLE_INFERENCE_RESPONSE ["InferenceRecommendations" ]
265
+ )
266
+
267
+ # confirm that the returned object of right_size is itself
268
+ assert inference_recommender_model == model
269
+
270
+ def test_right_size_advanced_single_instances_model_name_successful (sagemaker_session , model ):
271
+ model .right_size (
272
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
273
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
274
+ framework = "SAGEMAKER-SCIKIT-LEARN" ,
275
+ job_duration_in_seconds = 7200 ,
276
+ hyperparameter_ranges = IR_SAMPLE_SINGLE_INSTANCES_HYPERPARAMETER_RANGES ,
277
+ phases = IR_SAMPLE_PHASES ,
278
+ traffic_type = "PHASES" ,
279
+ max_invocations = 100 ,
280
+ model_latency_thresholds = IR_SAMPLE_MODEL_LATENCY_THRESHOLDS ,
281
+ max_tests = 5 ,
282
+ max_parallel_tests = 5 ,
283
+ )
284
+
285
+ # assert that the create api has been called with advanced parameters
286
+ assert sagemaker_session .create_inference_recommendations_job .called_with (
287
+ role = IR_ROLE_ARN ,
288
+ job_name = IR_JOB_NAME ,
289
+ job_type = "Advanced" ,
290
+ job_duration_in_seconds = 7200 ,
291
+ model_name = ANY ,
292
+ model_package_version_arn = None ,
293
+ framework = IR_SAMPLE_FRAMEWORK ,
294
+ framework_version = None ,
295
+ sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
296
+ supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
297
+ supported_instance_types = [IR_SAMPLE_INSTANCE_TYPE ],
298
+ endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
299
+ traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
300
+ stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
301
+ resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
302
+ )
303
+
304
+
305
+
178
306
def test_right_size_default_with_model_package_successful (sagemaker_session , model_package ):
179
307
inference_recommender_model_pkg = model_package .right_size (
180
308
sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
@@ -190,6 +318,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
190
318
job_name = IR_JOB_NAME ,
191
319
job_type = "Default" ,
192
320
job_duration_in_seconds = None ,
321
+ model_name = None ,
193
322
model_package_version_arn = model_package .model_package_arn ,
194
323
framework = IR_SAMPLE_FRAMEWORK ,
195
324
framework_version = None ,
@@ -202,7 +331,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
202
331
resource_limit = None ,
203
332
)
204
333
205
- assert sagemaker_session .wait_for_inference_recomendations_job .called_with (IR_JOB_NAME )
334
+ assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
206
335
207
336
# confirm that the IR instance attributes have been set
208
337
assert (
@@ -216,7 +345,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
216
345
217
346
# confirm that the returned object of right_size is itself
218
347
assert inference_recommender_model_pkg == model_package
219
-
348
+
220
349
221
350
def test_right_size_advanced_list_instances_model_package_successful (
222
351
sagemaker_session , model_package
@@ -253,7 +382,7 @@ def test_right_size_advanced_list_instances_model_package_successful(
253
382
resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
254
383
)
255
384
256
- assert sagemaker_session .wait_for_inference_recomendations_job .called_with (IR_JOB_NAME )
385
+ assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
257
386
258
387
# confirm that the IR instance attributes have been set
259
388
assert (
@@ -359,21 +488,6 @@ def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_packa
359
488
)
360
489
361
490
362
- # TODO -> removed once model registry is decoupled
363
- def test_right_size_missing_model_package_arn (sagemaker_session , model ):
364
- with pytest .raises (
365
- ValueError ,
366
- match = "right_size\\ (\\ ) is currently only supported with a registered model" ,
367
- ):
368
- model .right_size (
369
- sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
370
- supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
371
- supported_instance_types = [IR_SAMPLE_INSTANCE_TYPE ],
372
- job_name = IR_JOB_NAME ,
373
- framework = IR_SAMPLE_FRAMEWORK ,
374
- )
375
-
376
-
377
491
# TODO check our framework mapping when we add in inference_recommendation_id support
378
492
379
493
0 commit comments