18
18
19
19
from typing import List , Dict , Optional
20
20
import sagemaker
21
+ from sagemaker .inference_recommender import ModelLatencyThreshold , Phase
21
22
from sagemaker .parameter import CategoricalParameter
22
23
23
24
INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
31
32
LOGGER = logging .getLogger ("sagemaker" )
32
33
33
34
34
- class Phase :
35
- """Used to store phases of a traffic pattern to perform endpoint load testing.
36
-
37
- Required for an Advanced Inference Recommendations Job
38
- """
39
-
40
- def __init__ (self , duration_in_seconds : int , initial_number_of_users : int , spawn_rate : int ):
41
- """Initialize a `Phase`"""
42
- self .to_json = {
43
- "DurationInSeconds" : duration_in_seconds ,
44
- "InitialNumberOfUsers" : initial_number_of_users ,
45
- "SpawnRate" : spawn_rate ,
46
- }
47
-
48
-
49
- class ModelLatencyThreshold :
50
- """Used to store inference request/response latency to perform endpoint load testing.
51
-
52
- Required for an Advanced Inference Recommendations Job
53
- """
54
-
55
- def __init__ (self , percentile : str , value_in_milliseconds : int ):
56
- """Initialize a `ModelLatencyThreshold`"""
57
- self .to_json = {"Percentile" : percentile , "ValueInMilliseconds" : value_in_milliseconds }
58
-
59
-
60
35
class InferenceRecommenderMixin :
61
36
"""A mixin class for SageMaker ``Inference Recommender`` that will be extended by ``Model``"""
62
37
@@ -464,6 +439,14 @@ def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: in
464
439
"""Bundle right_size() parameters into a resource limit for Advanced job"""
465
440
if not max_tests and not max_parallel_tests :
466
441
return None
442
+ if max_tests and not max_parallel_tests :
443
+ return {
444
+ "MaxNumberOfTests" : max_tests ,
445
+ }
446
+ if not max_tests and max_parallel_tests :
447
+ return {
448
+ "MaxParallelOfTests" : max_parallel_tests ,
449
+ }
467
450
return {
468
451
"MaxNumberOfTests" : max_tests ,
469
452
"MaxParallelOfTests" : max_parallel_tests ,
@@ -475,6 +458,16 @@ def _convert_to_stopping_conditions_json(
475
458
"""Bundle right_size() parameters into stopping conditions for Advanced job"""
476
459
if not max_invocations and not model_latency_thresholds :
477
460
return None
461
+ if max_invocations and not model_latency_thresholds :
462
+ return {
463
+ "MaxInvocations" : max_invocations ,
464
+ }
465
+ if not max_invocations and model_latency_thresholds :
466
+ return {
467
+ "ModelLatencyThresholds" : [
468
+ threshold .to_json for threshold in model_latency_thresholds
469
+ ],
470
+ }
478
471
return {
479
472
"MaxInvocations" : max_invocations ,
480
473
"ModelLatencyThresholds" : [threshold .to_json for threshold in model_latency_thresholds ],
0 commit comments