Skip to content

fix: allow predictor to be returned from AutoML.deploy() #1220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/sagemaker/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def deploy(
vpc_config=None,
enable_network_isolation=False,
model_kms_key=None,
predictor_cls=None,
):
"""Deploy a candidate to a SageMaker Inference Pipeline and return a Predictor

Expand Down Expand Up @@ -237,10 +238,15 @@ def deploy(
training cluster for distributed training. Default: False
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
predictor_cls (callable[string, sagemaker.session.Session]): A
function to call to create a predictor (default: None). If
specified, ``deploy()`` returns the result of invoking this
function on the created endpoint name.

Returns:
callable[string, sagemaker.session.Session]: Invocation of
``self.predictor_cls`` on the created endpoint name.
callable[string, sagemaker.session.Session] or ``None``:
If ``predictor_cls`` is specified, the invocation of ``predictor_cls`` on
the created endpoint name. Otherwise, ``None``.
"""
if candidate is None:
candidate_dict = self.best_candidate()
Expand All @@ -264,6 +270,7 @@ def deploy(
vpc_config=vpc_config,
enable_network_isolation=enable_network_isolation,
model_kms_key=model_kms_key,
predictor_cls=predictor_cls,
)

def _check_problem_type_and_job_objective(self, problem_type, job_objective):
Expand Down Expand Up @@ -299,6 +306,7 @@ def _deploy_inference_pipeline(
vpc_config=None,
enable_network_isolation=False,
model_kms_key=None,
predictor_cls=None,
):
"""Deploy a SageMaker Inference Pipeline.

Expand Down Expand Up @@ -329,6 +337,10 @@ def _deploy_inference_pipeline(
contains "SecurityGroupIds", "Subnets"
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
predictor_cls (callable[string, sagemaker.session.Session]): A
function to call to create a predictor (default: None). If
specified, ``deploy()`` returns the result of invoking this
function on the created endpoint name.
"""
# construct Model objects
models = []
Expand All @@ -352,6 +364,7 @@ def _deploy_inference_pipeline(
pipeline = PipelineModel(
models=models,
role=self.role,
predictor_cls=predictor_cls,
name=name,
vpc_config=vpc_config,
sagemaker_session=sagemaker_session or self.sagemaker_session,
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/sagemaker/automl/test_auto_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest
from mock import Mock, patch
from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator
from sagemaker.predictor import RealTimePredictor

MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
Expand Down Expand Up @@ -472,6 +473,46 @@ def test_deploy(sagemaker_session, candidate_mock):
vpc_config=None,
enable_network_isolation=False,
model_kms_key=None,
predictor_cls=None,
)


def test_deploy_optional_args(sagemaker_session, candidate_mock):
auto_ml = AutoML(
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
)
auto_ml.best_candidate = Mock(name="best_candidate", return_value=CANDIDATE_DICT)
auto_ml._deploy_inference_pipeline = Mock("_deploy_inference_pipeline", return_value=None)

auto_ml.deploy(
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
sagemaker_session=sagemaker_session,
name=JOB_NAME,
endpoint_name=JOB_NAME,
tags=TAGS,
wait=False,
update_endpoint=True,
vpc_config=VPC_CONFIG,
enable_network_isolation=True,
model_kms_key=OUTPUT_KMS_KEY,
predictor_cls=RealTimePredictor,
)
auto_ml._deploy_inference_pipeline.assert_called_once()
auto_ml._deploy_inference_pipeline.assert_called_with(
candidate_mock.containers,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
name=JOB_NAME,
sagemaker_session=sagemaker_session,
endpoint_name=JOB_NAME,
tags=TAGS,
wait=False,
update_endpoint=True,
vpc_config=VPC_CONFIG,
enable_network_isolation=True,
model_kms_key=OUTPUT_KMS_KEY,
predictor_cls=RealTimePredictor,
)


Expand Down