Skip to content

Commit 4a93d66

Browse files
author
Jinpeng Qi
committed
Refactor constructs
1 parent 42fd779 commit 4a93d66

File tree

4 files changed

+46
-55
lines changed

4 files changed

+46
-55
lines changed

src/sagemaker/inference_recommender/__init__.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,4 @@
1212
# language governing permissions and limitations under the License.
1313
"""Classes for using Inference Recommender with Amazon SageMaker."""
1414
from __future__ import absolute_import
15-
16-
17-
class Phase:
18-
"""Used to store phases of a traffic pattern to perform endpoint load testing.
19-
20-
Required for an Advanced Inference Recommendations Job
21-
"""
22-
23-
def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
24-
"""Initialize a `Phase`"""
25-
self.to_json = {
26-
"DurationInSeconds": duration_in_seconds,
27-
"InitialNumberOfUsers": initial_number_of_users,
28-
"SpawnRate": spawn_rate,
29-
}
30-
31-
32-
class ModelLatencyThreshold:
33-
"""Used to store inference request/response latency to perform endpoint load testing.
34-
35-
Required for an Advanced Inference Recommendations Job
36-
"""
37-
38-
def __init__(self, percentile: str, value_in_milliseconds: int):
39-
"""Initialize a `ModelLatencyThreshold`"""
40-
self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
15+
# Re-export so ``from sagemaker.inference_recommender import Phase`` keeps working. The
# path must be absolute: this module enables ``absolute_import``, so the bare module name
# would be looked up as a top-level module rather than as a sibling inside this package.
from sagemaker.inference_recommender.inference_recommender_mixin import (  # noqa: F401
    Phase,
    ModelLatencyThreshold,
)

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from typing import List, Dict, Optional
2020
import sagemaker
21-
from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
2221
from sagemaker.parameter import CategoricalParameter
2322

2423
INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -32,6 +31,32 @@
3231
LOGGER = logging.getLogger("sagemaker")
3332

3433

34+
class Phase:
    """Used to store phases of a traffic pattern to perform endpoint load testing.

    Required for an Advanced Inference Recommendations Job
    """

    def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
        """Initialize a `Phase`.

        Args:
            duration_in_seconds (int): Value stored under the ``DurationInSeconds`` key.
            initial_number_of_users (int): Value stored under the
                ``InitialNumberOfUsers`` key.
            spawn_rate (int): Value stored under the ``SpawnRate`` key.
        """
        # ``to_json`` is a plain dict attribute, not a method: read it without calling it.
        self.to_json = {
            "DurationInSeconds": duration_in_seconds,
            "InitialNumberOfUsers": initial_number_of_users,
            "SpawnRate": spawn_rate,
        }
47+
48+
49+
class ModelLatencyThreshold:
    """Used to store inference request/response latency to perform endpoint load testing.

    Required for an Advanced Inference Recommendations Job
    """

    def __init__(self, percentile: str, value_in_milliseconds: int):
        """Initialize a `ModelLatencyThreshold`.

        Args:
            percentile (str): Value stored under the ``Percentile`` key.
            value_in_milliseconds (int): Value stored under the
                ``ValueInMilliseconds`` key.
        """
        # ``to_json`` is a plain dict attribute, not a method: read it without calling it.
        self.to_json = {
            "Percentile": percentile,
            "ValueInMilliseconds": value_in_milliseconds,
        }
58+
59+
3560
class InferenceRecommenderMixin:
3661
"""A mixin class for SageMaker ``Inference Recommender`` that will be extended by ``Model``"""
3762

@@ -439,36 +464,24 @@ def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: in
439464
"""Bundle right_size() parameters into a resource limit for Advanced job"""
440465
if not max_tests and not max_parallel_tests:
441466
return None
442-
if max_tests and not max_parallel_tests:
443-
return {
444-
"MaxNumberOfTests": max_tests,
445-
}
446-
if not max_tests and max_parallel_tests:
447-
return {
448-
"MaxParallelOfTests": max_parallel_tests,
449-
}
450-
return {
451-
"MaxNumberOfTests": max_tests,
452-
"MaxParallelOfTests": max_parallel_tests,
453-
}
467+
resource_limit = {}
468+
if max_tests:
469+
resource_limit["MaxNumberOfTests"] = max_tests
470+
if max_parallel_tests:
471+
resource_limit["MaxParallelOfTests"] = max_parallel_tests
472+
return resource_limit
454473

455474
def _convert_to_stopping_conditions_json(
456475
self, max_invocations: int, model_latency_thresholds: List[ModelLatencyThreshold]
457476
):
458477
"""Bundle right_size() parameters into stopping conditions for Advanced job"""
459478
if not max_invocations and not model_latency_thresholds:
460479
return None
461-
if max_invocations and not model_latency_thresholds:
462-
return {
463-
"MaxInvocations": max_invocations,
464-
}
465-
if not max_invocations and model_latency_thresholds:
466-
return {
467-
"ModelLatencyThresholds": [
468-
threshold.to_json for threshold in model_latency_thresholds
469-
],
470-
}
471-
return {
472-
"MaxInvocations": max_invocations,
473-
"ModelLatencyThresholds": [threshold.to_json for threshold in model_latency_thresholds],
474-
}
480+
stopping_conditions = {}
481+
if max_invocations:
482+
stopping_conditions["MaxInvocations"] = max_invocations
483+
if model_latency_thresholds:
484+
stopping_conditions["ModelLatencyThresholds"] = [
485+
threshold.to_json for threshold in model_latency_thresholds
486+
]
487+
return stopping_conditions

tests/integ/test_inference_recommender.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tests.integ import DATA_DIR
2222
from tests.integ.timeout import timeout
2323
import pandas as pd
24-
from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
24+
from sagemaker.inference_recommender.inference_recommender_mixin import Phase, ModelLatencyThreshold
2525
from sagemaker.parameter import CategoricalParameter
2626
import logging
2727

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
from sagemaker.model import Model, ModelPackage
66
from sagemaker.parameter import CategoricalParameter
7-
from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
7+
from sagemaker.inference_recommender.inference_recommender_mixin import (
8+
Phase,
9+
ModelLatencyThreshold,
10+
)
811
from sagemaker.async_inference import AsyncInferenceConfig
912
from sagemaker.serverless import ServerlessInferenceConfig
1013

0 commit comments

Comments
 (0)