
fix: jumpstart async inference config predictor support #3970


Merged
src/sagemaker/jumpstart/estimator.py (2 changes: 1 addition & 1 deletion)
@@ -973,7 +973,7 @@ def deploy(
         )
akrishna1995 (Contributor) commented on Jun 29, 2023:

Please remove the link in the description to our internal corp ticket; link this bug instead: #3969.


         # If no predictor class was passed, add defaults to predictor
-        if self.orig_predictor_cls is None:
+        if self.orig_predictor_cls is None and async_inference_config is None:
             return get_default_predictor(
                 predictor=predictor,
                 model_id=self.model_id,
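
The changed condition is the whole fix: when an async_inference_config is supplied, the base deploy() call is expected to hand back an AsyncPredictor, and replacing it with the synchronous Predictor from get_default_predictor would strip the async entry points. A standalone sketch of the guard logic; the names mirror the diff, but the function itself is illustrative rather than SDK code:

```python
from typing import Any, Optional


def should_attach_default_predictor(
    orig_predictor_cls: Optional[type],
    async_inference_config: Optional[Any],
) -> bool:
    """Illustrative mirror of the guard in JumpStartEstimator.deploy.

    JumpStart only swaps in its preset default Predictor when the caller
    neither supplied a custom predictor class nor requested async inference.
    """
    return orig_predictor_cls is None and async_inference_config is None
```
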
src/sagemaker/jumpstart/model.py (2 changes: 1 addition & 1 deletion)
@@ -432,7 +432,7 @@ def deploy(
         predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())

         # If no predictor class was passed, add defaults to predictor
-        if self.orig_predictor_cls is None:
+        if self.orig_predictor_cls is None and async_inference_config is None:
akrishna1995 (Contributor) commented:

Don't have enough context here, but can you confirm that we won't see similar issues for other types of configs, e.g. serverless_inference_config or data_capture_config? That is, if one of those is passed, do we need to skip the default Predictor as well? (A hypothetical widened guard is sketched at the end of this page.)

Member Author replied:

To be honest, I am really not sure. I wanted to resolve the customer ticket ASAP, but I will do a deep dive into other kinds of issues with the predictor being set under the hood.

Contributor replied:

Thanks for the callout @akrishna1995.

IIUC, data capture wouldn't impact the predictor, but please let us know if you disagree.

As of today, JumpStart does not claim to support serverless inference for any model. It might work for some of our smaller models, but it wouldn't for foundation models: they all run on GPUs. We will, however, look into it for our older models. Can this be done in a separate PR?

Contributor replied:

Sure, thanks for the confirmation. The rest LGTM; I will approve and merge the PR.

             return get_default_predictor(
                 predictor=predictor,
                 model_id=self.model_id,
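
For reference, a minimal usage sketch of the fixed behavior; the model id is illustrative and not taken from this PR:

```python
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.jumpstart.model import JumpStartModel

# Illustrative model id; any deployable JumpStart model behaves the same way.
model = JumpStartModel(model_id="huggingface-text2text-flan-t5-small")

# With this fix, get_default_predictor() is skipped, and the predictor
# returned by the base Model.deploy() is handed back to the caller as-is.
predictor = model.deploy(async_inference_config=AsyncInferenceConfig())
```
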
tests/unit/sagemaker/jumpstart/estimator/test_estimator.py (51 changes: 51 additions & 0 deletions)
@@ -18,6 +18,7 @@
 from inspect import signature

 import pytest
+from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig

 from sagemaker.debugger.profiler_config import ProfilerConfig
 from sagemaker.estimator import Estimator
@@ -640,6 +641,56 @@ def test_no_predictor_returns_default_predictor(
         self.assertEqual(type(predictor), Predictor)
         self.assertEqual(predictor, default_predictor_with_presets)

+    @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
+    @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
+    @mock.patch("sagemaker.jumpstart.factory.model.Session")
+    @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+    @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
+    @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+    @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
+    @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+    @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+    def test_no_predictor_yes_async_inference_config(
+        self,
+        mock_estimator_deploy: mock.Mock,
+        mock_estimator_fit: mock.Mock,
+        mock_estimator_init: mock.Mock,
+        mock_get_model_specs: mock.Mock,
+        mock_session_estimator: mock.Mock,
+        mock_session_model: mock.Mock,
+        mock_is_valid_model_id: mock.Mock,
+        mock_get_default_predictor: mock.Mock,
+    ):
+        mock_estimator_deploy.return_value = default_predictor
+
+        mock_get_default_predictor.return_value = default_predictor_with_presets
+
+        mock_is_valid_model_id.return_value = True
+
+        model_id, _ = "js-trainable-model-prepacked", "*"
+
+        mock_get_model_specs.side_effect = get_special_model_spec
+
+        mock_session_estimator.return_value = sagemaker_session
+        mock_session_model.return_value = sagemaker_session
+
+        estimator = JumpStartEstimator(
+            model_id=model_id,
+        )
+
+        channels = {
+            "training": f"s3://{get_jumpstart_content_bucket(region)}/"
+            f"some-training-dataset-doesn't-matter",
+        }
+
+        estimator.fit(channels)
+
+        predictor = estimator.deploy(async_inference_config=AsyncInferenceConfig())
+
+        mock_get_default_predictor.assert_not_called()
+        self.assertEqual(type(predictor), Predictor)
+
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
tests/unit/sagemaker/jumpstart/model/test_model.py (37 changes: 37 additions & 0 deletions)
@@ -16,6 +16,7 @@
 from unittest import mock
 import unittest
 import pytest
+from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
 from sagemaker.jumpstart.enums import JumpStartScriptScope

 from sagemaker.jumpstart.model import JumpStartModel
@@ -417,6 +418,42 @@ def test_no_predictor_returns_default_predictor(
         self.assertEqual(type(predictor), Predictor)
         self.assertEqual(predictor, default_predictor_with_presets)

+    @mock.patch("sagemaker.jumpstart.model.get_default_predictor")
+    @mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
+    @mock.patch("sagemaker.jumpstart.factory.model.Session")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+    @mock.patch("sagemaker.jumpstart.model.Model.__init__")
+    @mock.patch("sagemaker.jumpstart.model.Model.deploy")
+    @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+    def test_no_predictor_yes_async_inference_config(
+        self,
+        mock_model_deploy: mock.Mock,
+        mock_model_init: mock.Mock,
+        mock_get_model_specs: mock.Mock,
+        mock_session: mock.Mock,
+        mock_is_valid_model_id: mock.Mock,
+        mock_get_default_predictor: mock.Mock,
+    ):
+        mock_get_default_predictor.return_value = default_predictor_with_presets
+
+        mock_model_deploy.return_value = default_predictor
+
+        mock_is_valid_model_id.return_value = True
+
+        model_id, _ = "js-model-class-model-prepacked", "*"
+
+        mock_get_model_specs.side_effect = get_special_model_spec
+
+        mock_session.return_value = sagemaker_session
+
+        model = JumpStartModel(
+            model_id=model_id,
+        )
+
+        model.deploy(async_inference_config=AsyncInferenceConfig())
+
+        mock_get_default_predictor.assert_not_called()
+
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
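
The review thread above left open whether serverless_inference_config needs the same treatment. If that symptom were confirmed for serverless endpoints, the guard could be widened the same way; the sketch below extends the illustrative helper shown after the estimator diff, and the serverless parameter is hypothetical, not part of this PR:

```python
from typing import Any, Optional


def should_attach_default_predictor(
    orig_predictor_cls: Optional[type],
    async_inference_config: Optional[Any],
    serverless_inference_config: Optional[Any] = None,  # hypothetical, not in this PR
) -> bool:
    """Hypothetically widened guard: skip the default Predictor whenever a
    config would change the predictor type returned by deploy()."""
    return (
        orig_predictor_cls is None
        and async_inference_config is None
        and serverless_inference_config is None
    )
```
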