Skip to content

Commit 7ed3e2d

Browse files
committed
fix: jumpstart async inference config predictor support
1 parent 0967a93 commit 7ed3e2d

File tree

4 files changed

+90
-2
lines changed

4 files changed

+90
-2
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ def deploy(
973973
)
974974

975975
# If no predictor class was passed, add defaults to predictor
976-
if self.orig_predictor_cls is None:
976+
if self.orig_predictor_cls is None and async_inference_config is None:
977977
return get_default_predictor(
978978
predictor=predictor,
979979
model_id=self.model_id,

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def deploy(
432432
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
433433

434434
# If no predictor class was passed, add defaults to predictor
435-
if self.orig_predictor_cls is None:
435+
if self.orig_predictor_cls is None and async_inference_config is None:
436436
return get_default_predictor(
437437
predictor=predictor,
438438
model_id=self.model_id,

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from inspect import signature
1919

2020
import pytest
21+
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2122

2223
from sagemaker.debugger.profiler_config import ProfilerConfig
2324
from sagemaker.estimator import Estimator
@@ -640,6 +641,56 @@ def test_no_predictor_returns_default_predictor(
640641
self.assertEqual(type(predictor), Predictor)
641642
self.assertEqual(predictor, default_predictor_with_presets)
642643

644+
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
645+
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
646+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
647+
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
648+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
649+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
650+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
651+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
652+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
653+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
654+
def test_no_predictor_yes_async_inference_config(
655+
self,
656+
mock_estimator_deploy: mock.Mock,
657+
mock_estimator_fit: mock.Mock,
658+
mock_estimator_init: mock.Mock,
659+
mock_get_model_specs: mock.Mock,
660+
mock_session_estimator: mock.Mock,
661+
mock_session_model: mock.Mock,
662+
mock_is_valid_model_id: mock.Mock,
663+
mock_get_default_predictor: mock.Mock,
664+
):
665+
mock_estimator_deploy.return_value = default_predictor
666+
667+
mock_get_default_predictor.return_value = default_predictor_with_presets
668+
669+
mock_is_valid_model_id.return_value = True
670+
671+
model_id, _ = "js-trainable-model-prepacked", "*"
672+
673+
mock_get_model_specs.side_effect = get_special_model_spec
674+
675+
mock_session_estimator.return_value = sagemaker_session
676+
mock_session_model.return_value = sagemaker_session
677+
678+
estimator = JumpStartEstimator(
679+
model_id=model_id,
680+
)
681+
682+
channels = {
683+
"training": f"s3://{get_jumpstart_content_bucket(region)}/"
684+
f"some-training-dataset-doesn't-matter",
685+
}
686+
687+
estimator.fit(channels)
688+
689+
predictor = estimator.deploy(async_inference_config=AsyncInferenceConfig())
690+
691+
mock_get_default_predictor.assert_not_called()
692+
self.assertEqual(type(predictor), Predictor)
693+
643694
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
644695
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
645696
@mock.patch("sagemaker.jumpstart.factory.model.Session")

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
import unittest
1818
import pytest
19+
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
1920
from sagemaker.jumpstart.enums import JumpStartScriptScope
2021

2122
from sagemaker.jumpstart.model import JumpStartModel
@@ -417,6 +418,42 @@ def test_no_predictor_returns_default_predictor(
417418
self.assertEqual(type(predictor), Predictor)
418419
self.assertEqual(predictor, default_predictor_with_presets)
419420

421+
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
422+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
423+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
424+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
425+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
426+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
427+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
428+
def test_no_predictor_yes_async_inference_config(
429+
self,
430+
mock_model_deploy: mock.Mock,
431+
mock_model_init: mock.Mock,
432+
mock_get_model_specs: mock.Mock,
433+
mock_session: mock.Mock,
434+
mock_is_valid_model_id: mock.Mock,
435+
mock_get_default_predictor: mock.Mock,
436+
):
437+
mock_get_default_predictor.return_value = default_predictor_with_presets
438+
439+
mock_model_deploy.return_value = default_predictor
440+
441+
mock_is_valid_model_id.return_value = True
442+
443+
model_id, _ = "js-model-class-model-prepacked", "*"
444+
445+
mock_get_model_specs.side_effect = get_special_model_spec
446+
447+
mock_session.return_value = sagemaker_session
448+
449+
model = JumpStartModel(
450+
model_id=model_id,
451+
)
452+
453+
model.deploy(async_inference_config=AsyncInferenceConfig())
454+
455+
mock_get_default_predictor.assert_not_called()
456+
420457
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
421458
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
422459
@mock.patch("sagemaker.jumpstart.factory.model.Session")

0 commit comments

Comments
 (0)