
Commit 42fd779

Author: Jinpeng Qi (committed)

fix: advanced inference recommendation jobs parameters check

1 parent 8d2c16b · commit 42fd779

4 files changed: +47 / -31 lines

src/sagemaker/inference_recommender/__init__.py

Lines changed: 26 additions & 0 deletions
@@ -12,3 +12,29 @@
 # language governing permissions and limitations under the License.
 """Classes for using Inference Recommender with Amazon SageMaker."""
 from __future__ import absolute_import
+
+
+class Phase:
+    """Used to store phases of a traffic pattern to perform endpoint load testing.
+
+    Required for an Advanced Inference Recommendations Job
+    """
+
+    def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
+        """Initialze a `Phase`"""
+        self.to_json = {
+            "DurationInSeconds": duration_in_seconds,
+            "InitialNumberOfUsers": initial_number_of_users,
+            "SpawnRate": spawn_rate,
+        }
+
+
+class ModelLatencyThreshold:
+    """Used to store inference request/response latency to perform endpoint load testing.
+
+    Required for an Advanced Inference Recommendations Job
+    """
+
+    def __init__(self, percentile: str, value_in_milliseconds: int):
+        """Initialze a `ModelLatencyThreshold`"""
+        self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 19 additions & 26 deletions
@@ -18,6 +18,7 @@
 
 from typing import List, Dict, Optional
 import sagemaker
+from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
 from sagemaker.parameter import CategoricalParameter
 
 INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -31,32 +32,6 @@
 LOGGER = logging.getLogger("sagemaker")
 
 
-class Phase:
-    """Used to store phases of a traffic pattern to perform endpoint load testing.
-
-    Required for an Advanced Inference Recommendations Job
-    """
-
-    def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
-        """Initialze a `Phase`"""
-        self.to_json = {
-            "DurationInSeconds": duration_in_seconds,
-            "InitialNumberOfUsers": initial_number_of_users,
-            "SpawnRate": spawn_rate,
-        }
-
-
-class ModelLatencyThreshold:
-    """Used to store inference request/response latency to perform endpoint load testing.
-
-    Required for an Advanced Inference Recommendations Job
-    """
-
-    def __init__(self, percentile: str, value_in_milliseconds: int):
-        """Initialze a `ModelLatencyThreshold`"""
-        self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
-
-
 class InferenceRecommenderMixin:
     """A mixin class for SageMaker ``Inference Recommender`` that will be extended by ``Model``"""
 
@@ -464,6 +439,14 @@ def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: in
         """Bundle right_size() parameters into a resource limit for Advanced job"""
         if not max_tests and not max_parallel_tests:
             return None
+        if max_tests and not max_parallel_tests:
+            return {
+                "MaxNumberOfTests": max_tests,
+            }
+        if not max_tests and max_parallel_tests:
+            return {
+                "MaxParallelOfTests": max_parallel_tests,
+            }
         return {
             "MaxNumberOfTests": max_tests,
             "MaxParallelOfTests": max_parallel_tests,
@@ -475,6 +458,16 @@ def _convert_to_stopping_conditions_json(
         """Bundle right_size() parameters into stopping conditions for Advanced job"""
         if not max_invocations and not model_latency_thresholds:
             return None
+        if max_invocations and not model_latency_thresholds:
+            return {
+                "MaxInvocations": max_invocations,
+            }
+        if not max_invocations and model_latency_thresholds:
+            return {
+                "ModelLatencyThresholds": [
+                    threshold.to_json for threshold in model_latency_thresholds
+                ],
+            }
         return {
             "MaxInvocations": max_invocations,
             "ModelLatencyThresholds": [threshold.to_json for threshold in model_latency_thresholds],

tests/integ/test_inference_recommender.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from tests.integ import DATA_DIR
 from tests.integ.timeout import timeout
 import pandas as pd
-from sagemaker.inference_recommender.inference_recommender_mixin import Phase, ModelLatencyThreshold
+from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
 from sagemaker.parameter import CategoricalParameter
 import logging
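
For orientation, here is a hedged sketch of the kind of advanced right_size() call this integration test exercises. The parameter names and values are assumptions based on the SDK's right_size() API around the time of this commit (the diff itself only confirms max_invocations, model_latency_thresholds, max_tests, and max_parallel_tests), so check them against the current signature:

from sagemaker.inference_recommender import ModelLatencyThreshold, Phase

# Hypothetical advanced job: traffic ramps through two phases, and the job stops
# once the invocation cap or the P95 latency threshold is exceeded.
# `model` is assumed to be an existing sagemaker.model.Model instance.
model.right_size(
    sample_payload_url="s3://my-bucket/payload.tar.gz",  # placeholder S3 URI
    supported_content_types=["text/csv"],
    supported_instance_types=["ml.c5.xlarge"],
    framework="SAGEMAKER-SCIKIT-LEARN",
    job_duration_in_seconds=7200,
    phases=[
        Phase(duration_in_seconds=120, initial_number_of_users=2, spawn_rate=2),
        Phase(duration_in_seconds=120, initial_number_of_users=6, spawn_rate=2),
    ],
    max_invocations=300,
    model_latency_thresholds=[ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100)],
    max_tests=5,
    max_parallel_tests=5,
)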

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 1 addition & 4 deletions
@@ -4,10 +4,7 @@
 
 from sagemaker.model import Model, ModelPackage
 from sagemaker.parameter import CategoricalParameter
-from sagemaker.inference_recommender.inference_recommender_mixin import (
-    Phase,
-    ModelLatencyThreshold,
-)
+from sagemaker.inference_recommender import ModelLatencyThreshold, Phase
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.serverless import ServerlessInferenceConfig
