@@ -18,7 +18,6 @@
 
 from typing import List, Dict, Optional
 import sagemaker
-from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
 from sagemaker.parameter import CategoricalParameter
 
 INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -32,6 +31,32 @@
 LOGGER = logging.getLogger("sagemaker")
 
 
+class Phase:
+    """Used to store phases of a traffic pattern to perform endpoint load testing.
+
+    Required for an Advanced Inference Recommendations Job
+    """
+
+    def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
+        """Initialize a `Phase`"""
+        self.to_json = {
+            "DurationInSeconds": duration_in_seconds,
+            "InitialNumberOfUsers": initial_number_of_users,
+            "SpawnRate": spawn_rate,
+        }
+
+
+class ModelLatencyThreshold:
+    """Used to store inference request/response latency to perform endpoint load testing.
+
+    Required for an Advanced Inference Recommendations Job
+    """
+
+    def __init__(self, percentile: str, value_in_milliseconds: int):
+        """Initialize a `ModelLatencyThreshold`"""
+        self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
+
+
 class InferenceRecommenderMixin:
     """A mixin class for SageMaker ``Inference Recommender`` that will be extended by ``Model``"""
 
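For reference, both new classes simply precompute the request JSON they contribute. A minimal sketch of how they serialize, assuming the `Phase` and `ModelLatencyThreshold` classes defined above are in scope; the argument values and the "P95" percentile string are illustrative, not prescribed by this change:

```python
# Illustrative values only; Phase and ModelLatencyThreshold are the classes
# introduced in this diff.
phase = Phase(duration_in_seconds=120, initial_number_of_users=1, spawn_rate=1)
print(phase.to_json)
# {'DurationInSeconds': 120, 'InitialNumberOfUsers': 1, 'SpawnRate': 1}

threshold = ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100)
print(threshold.to_json)
# {'Percentile': 'P95', 'ValueInMilliseconds': 100}
```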
@@ -439,36 +464,24 @@ def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: int):
         """Bundle right_size() parameters into a resource limit for Advanced job"""
         if not max_tests and not max_parallel_tests:
             return None
-        if max_tests and not max_parallel_tests:
-            return {
-                "MaxNumberOfTests": max_tests,
-            }
-        if not max_tests and max_parallel_tests:
-            return {
-                "MaxParallelOfTests": max_parallel_tests,
-            }
-        return {
-            "MaxNumberOfTests": max_tests,
-            "MaxParallelOfTests": max_parallel_tests,
-        }
+        resource_limit = {}
+        if max_tests:
+            resource_limit["MaxNumberOfTests"] = max_tests
+        if max_parallel_tests:
+            resource_limit["MaxParallelOfTests"] = max_parallel_tests
+        return resource_limit
 
     def _convert_to_stopping_conditions_json(
         self, max_invocations: int, model_latency_thresholds: List[ModelLatencyThreshold]
     ):
         """Bundle right_size() parameters into stopping conditions for Advanced job"""
         if not max_invocations and not model_latency_thresholds:
             return None
-        if max_invocations and not model_latency_thresholds:
-            return {
-                "MaxInvocations": max_invocations,
-            }
-        if not max_invocations and model_latency_thresholds:
-            return {
-                "ModelLatencyThresholds": [
-                    threshold.to_json for threshold in model_latency_thresholds
-                ],
-            }
-        return {
-            "MaxInvocations": max_invocations,
-            "ModelLatencyThresholds": [threshold.to_json for threshold in model_latency_thresholds],
-        }
+        stopping_conditions = {}
+        if max_invocations:
+            stopping_conditions["MaxInvocations"] = max_invocations
+        if model_latency_thresholds:
+            stopping_conditions["ModelLatencyThresholds"] = [
+                threshold.to_json for threshold in model_latency_thresholds
+            ]
+        return stopping_conditions
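Both helpers now build their result incrementally instead of enumerating every argument combination; because `if max_tests:` treats `None` and `0` as falsy exactly like the old branches did, the returned shapes are unchanged. A quick behavioral sketch, assuming `InferenceRecommenderMixin` can be instantiated bare (it defines no `__init__` in this diff):

```python
# Assumption: both classes are in scope; the bare instantiation is only for
# illustration, since the mixin is normally mixed into Model.
mixin = InferenceRecommenderMixin()

assert mixin._convert_to_resource_limit_json(None, None) is None
assert mixin._convert_to_resource_limit_json(10, None) == {"MaxNumberOfTests": 10}
assert mixin._convert_to_resource_limit_json(10, 3) == {
    "MaxNumberOfTests": 10,
    "MaxParallelOfTests": 3,  # field name spelled as in the helper above
}

thresholds = [ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100)]
assert mixin._convert_to_stopping_conditions_json(300, thresholds) == {
    "MaxInvocations": 300,
    "ModelLatencyThresholds": [{"Percentile": "P95", "ValueInMilliseconds": 100}],
}
```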
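The docstrings describe both bundles as `right_size()` parameters for an Advanced job, so this is roughly where they surface for users. A hypothetical wiring, using only parameter names this diff mentions; the `phases` keyword and the rest of the `right_size()` signature are assumptions to verify against the SDK:

```python
# Hypothetical sketch: `model` is assumed to be a sagemaker.model.Model, which
# extends InferenceRecommenderMixin per the docstring above. Payload and
# instance-type arguments that right_size() also needs are elided.
model.right_size(
    phases=[  # `phases` keyword assumed from the Phase docstring
        Phase(duration_in_seconds=120, initial_number_of_users=1, spawn_rate=1),
    ],
    model_latency_thresholds=[ModelLatencyThreshold("P95", 100)],
    max_invocations=300,  # -> StoppingConditions["MaxInvocations"]
    max_tests=10,  # -> ResourceLimit["MaxNumberOfTests"]
    max_parallel_tests=3,  # -> ResourceLimit["MaxParallelOfTests"]
)
```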