feature: add inference_id to predict #2093

Merged · 13 commits · Feb 27, 2021
Changes from all commits
23 changes: 18 additions & 5 deletions src/sagemaker/local/local_session.py
@@ -426,17 +426,27 @@ def invoke_endpoint(
CustomAttributes=None,
TargetModel=None,
TargetVariant=None,
+ InferenceId=None,
):
"""Invoke the endpoint.

Args:
- Body:
- EndpointName:
- Accept: (Default value = None)
- CustomAttributes: (Default value = None)
+ Body: Input data for which you want the model to provide inference.
+ EndpointName: The name of the endpoint that you specified when you
+ created the endpoint using the CreateEndpoint API.
+ ContentType: The MIME type of the input data in the request body (Default value = None)
+ Accept: The desired MIME type of the inference in the response (Default value = None)
+ CustomAttributes: Provides additional information about a request for an inference
+ submitted to a model hosted at an Amazon SageMaker endpoint (Default value = None)
+ TargetModel: The model to request for inference when invoking a multi-model endpoint
+ (Default value = None)
+ TargetVariant: Specify the production variant to send the inference request to when
+ invoking an endpoint that is running two or more variants (Default value = None)
+ InferenceId: If you provide a value, it is added to the captured data when you enable
+ data capture on the endpoint (Default value = None)
Comment on lines +434 to +446 (Contributor): Thank you so much!

Returns:

object: Inference for the given input.
"""
url = "http://localhost:%s/invocations" % self.serving_port
headers = {}
@@ -456,6 +466,9 @@
if TargetVariant is not None:
headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant

+ if InferenceId is not None:
+     headers["X-Amzn-SageMaker-Inference-Id"] = InferenceId

r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)

return {"Body": r, "ContentType": Accept}
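In local mode, the new InferenceId parameter travels to the serving container as the X-Amzn-SageMaker-Inference-Id HTTP header set above. Here is a minimal sketch of a container-side /invocations handler that reads the header back, assuming a Flask-based serving stack (the app, route, and port are illustrative, not part of this PR):

```python
# Sketch of a local serving container reading the header set by invoke_endpoint.
# The Flask app and port 8080 are illustrative assumptions, not part of this PR.
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route("/invocations", methods=["POST"])
def invocations():
    # Header name matches what the local invoke_endpoint sets above.
    inference_id = request.headers.get("X-Amzn-SageMaker-Inference-Id")
    payload = request.get_data(as_text=True)
    # Echo the ID back with a dummy prediction so the round trip is visible.
    return jsonify({"inference_id": inference_id, "prediction": payload})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)
```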
21 changes: 16 additions & 5 deletions src/sagemaker/predictor.py
@@ -95,7 +95,9 @@ def __init__(
self._model_names = self._get_model_names()
self._context = None

- def predict(self, data, initial_args=None, target_model=None, target_variant=None):
+ def predict(
+     self, data, initial_args=None, target_model=None, target_variant=None, inference_id=None
+ ):
"""Return the inference from the specified endpoint.

Args:
@@ -111,8 +113,10 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
in case of a multi model endpoint. Does not apply to endpoints hosting
single model (Default: None)
target_variant (str): The name of the production variant to run an inference
- request on (Default: None). Note that the ProductionVariant identifies the model
- you want to host and the resources you want to deploy for hosting it.
+ request on (Default: None). Note that the ProductionVariant identifies the
+ model you want to host and the resources you want to deploy for hosting it.
+ inference_id (str): If you provide a value, it is added to the captured data
+ when you enable data capture on the endpoint (Default: None).

Returns:
object: Inference for the given input. If a deserializer was specified when creating
@@ -121,7 +125,9 @@
as is.
"""

- request_args = self._create_request_args(data, initial_args, target_model, target_variant)
+ request_args = self._create_request_args(
+     data, initial_args, target_model, target_variant, inference_id
+ )
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)

@@ -131,7 +137,9 @@ def _handle_response(self, response):
content_type = response.get("ContentType", "application/octet-stream")
return self.deserializer.deserialize(response_body, content_type)

- def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
+ def _create_request_args(
+     self, data, initial_args=None, target_model=None, target_variant=None, inference_id=None
+ ):
"""Placeholder docstring"""
args = dict(initial_args) if initial_args else {}

@@ -150,6 +158,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
if target_variant:
args["TargetVariant"] = target_variant

+ if inference_id:
+     args["InferenceId"] = inference_id

data = self.serializer.serialize(data)

args["Body"] = data
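Taken together, the predictor change lets callers tag each request so it can be correlated with captured data later. A usage sketch (the endpoint name and payload below are placeholders):

```python
# Usage sketch: tag a request with an inference ID.
# "my-endpoint" and the CSV payload are placeholders.
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer

predictor = Predictor(endpoint_name="my-endpoint", serializer=CSVSerializer())

# inference_id is forwarded as InferenceId to invoke_endpoint; when data capture
# is enabled on the endpoint, it is stored with the captured request/response.
result = predictor.predict("42,42,42,42,42,42,42", inference_id="request-0001")
```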
86 changes: 86 additions & 0 deletions tests/integ/test_predict_with_inference_id.py
@@ -0,0 +1,86 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import pytest

import tests.integ
import tests.integ.timeout

from sagemaker import image_uris
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer
from sagemaker.utils import unique_name_from_base

from tests.integ import DATA_DIR


ROLE = "SageMakerRole"
INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.c5.xlarge"
TEST_CSV_DATA = "42,42,42,42,42,42,42"
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")


@pytest.yield_fixture(scope="module")
def endpoint_name(sagemaker_session):
endpoint_name = unique_name_from_base("model-inference-id-integ")
xgb_model_data = sagemaker_session.upload_data(
path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"),
key_prefix="integ-test-data/xgboost/model",
)

xgb_image = image_uris.retrieve(
"xgboost",
sagemaker_session.boto_region_name,
version="1",
image_scope="inference",
)

with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
endpoint_name=endpoint_name, sagemaker_session=sagemaker_session, hours=2
):
xgb_model = Model(
model_data=xgb_model_data,
image_uri=xgb_image,
name=endpoint_name, # model name
role=ROLE,
sagemaker_session=sagemaker_session,
)
xgb_model.deploy(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name)
yield endpoint_name


def test_predict_with_inference_id(sagemaker_session, endpoint_name):
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=CSVSerializer(),
)

# Validate that no exception is raised when inference_id is specified.
response = predictor.predict(TEST_CSV_DATA, inference_id="foo")
assert response


def test_invoke_endpoint_with_inference_id(sagemaker_session, endpoint_name):
response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
EndpointName=endpoint_name,
Body=TEST_CSV_DATA,
ContentType="text/csv",
Accept="text/csv",
InferenceId="foo",
)
assert response
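These integration tests only assert that a tagged request succeeds. To close the loop, the ID can later be looked up in the endpoint's data capture output. A hedged sketch, assuming the JSON Lines capture format with an eventMetadata.inferenceId field (the bucket, prefix, and field names are assumptions, not verified by these tests):

```python
# Hedged sketch: scan data capture output in S3 for a given inference ID.
# The capture layout and the eventMetadata.inferenceId field are assumptions
# based on the documented data capture format; this PR's tests do not verify it.
import json

import boto3


def find_captured_record(bucket, prefix, inference_id):
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            body = s3.get_object(Bucket=bucket, Key=obj["Key"])["Body"].read()
            # Each capture file is JSON Lines: one captured request/response per line.
            for line in body.decode("utf-8").splitlines():
                record = json.loads(line)
                if record.get("eventMetadata", {}).get("inferenceId") == inference_id:
                    return record
    return None
```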
24 changes: 24 additions & 0 deletions tests/unit/test_predictor.py
@@ -31,6 +31,7 @@
RETURN_VALUE = 0
CSV_RETURN_VALUE = "1,2,3\r\n"
PRODUCTION_VARIANT_1 = "PRODUCTION_VARIANT_1"
INFERENCE_ID = "inference-id"

ENDPOINT_DESC = {"EndpointArn": "foo", "EndpointConfigName": ENDPOINT}

@@ -98,6 +99,29 @@ def test_predict_call_with_target_variant():
assert result == RETURN_VALUE


def test_predict_call_with_inference_id():
sagemaker_session = empty_sagemaker_session()
predictor = Predictor(ENDPOINT, sagemaker_session)

data = "untouched"
result = predictor.predict(data, inference_id=INFERENCE_ID)

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called

expected_request_args = {
"Accept": DEFAULT_ACCEPT,
"Body": data,
"ContentType": DEFAULT_CONTENT_TYPE,
"EndpointName": ENDPOINT,
"InferenceId": INFERENCE_ID,
}

call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
assert kwargs == expected_request_args

assert result == RETURN_VALUE


def test_multi_model_predict_call():
sagemaker_session = empty_sagemaker_session()
predictor = Predictor(ENDPOINT, sagemaker_session)