Skip to content

Commit cc9b286

Browse files
jinpengqi (Jinpeng Qi) and a co-author authored
fix: advanced inference recommendation jobs parameters check (#3644)
Co-authored-by: Jinpeng Qi <[email protected]>
1 parent b731021 commit cc9b286

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

src/sagemaker/inference_recommender/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Classes for using Inference Recommender with Amazon SageMaker."""
1414
from __future__ import absolute_import
15+
from sagemaker.inference_recommender.inference_recommender_mixin import ( # noqa: F401
16+
Phase,
17+
ModelLatencyThreshold,
18+
)

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,18 +464,24 @@ def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: in
464464
"""Bundle right_size() parameters into a resource limit for Advanced job"""
465465
if not max_tests and not max_parallel_tests:
466466
return None
467-
return {
468-
"MaxNumberOfTests": max_tests,
469-
"MaxParallelOfTests": max_parallel_tests,
470-
}
467+
resource_limit = {}
468+
if max_tests:
469+
resource_limit["MaxNumberOfTests"] = max_tests
470+
if max_parallel_tests:
471+
resource_limit["MaxParallelOfTests"] = max_parallel_tests
472+
return resource_limit
471473

472474
def _convert_to_stopping_conditions_json(
    self, max_invocations: int, model_latency_thresholds: List[ModelLatencyThreshold]
):
    """Bundle right_size() parameters into stopping conditions for an Advanced job.

    Returns ``None`` when neither parameter was supplied; otherwise returns a
    dict containing only the keys whose corresponding argument is truthy, so
    absent parameters are omitted rather than sent as null/empty values.
    """
    # Nothing to bundle at all -> no StoppingConditions payload.
    if not (max_invocations or model_latency_thresholds):
        return None

    conditions = {}
    if max_invocations:
        conditions["MaxInvocations"] = max_invocations
    if model_latency_thresholds:
        # threshold.to_json is expected to yield the API-shaped dict for each
        # ModelLatencyThreshold entry.
        thresholds_json = [entry.to_json for entry in model_latency_thresholds]
        conditions["ModelLatencyThresholds"] = thresholds_json
    return conditions

0 commit comments

Comments (0)