Skip to content

Commit 28f41b0

Browse files
committed
Add LambdaModel and LambdaPredictor
1 parent 95ef81f commit 28f41b0

File tree

3 files changed

+186
-0
lines changed

3 files changed

+186
-0
lines changed

src/sagemaker/serverless/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for performing machine learning on serverless compute."""
14+
from sagemaker.serverless.model import LambdaModel # noqa: F401
15+
from sagemaker.serverless.predictor import LambdaPredictor # noqa: F401

src/sagemaker/serverless/model.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Models that can be deployed to serverless compute."""
14+
from __future__ import absolute_import
15+
16+
import time
17+
from typing import Optional
18+
19+
import boto3
20+
import botocore
21+
22+
from sagemaker.model import ModelBase
23+
24+
from .predictor import LambdaPredictor
25+
26+
27+
class LambdaModel(ModelBase):
    """A model that can be deployed to Lambda."""

    def __init__(
        self, image_uri: str, role: str, client: Optional[botocore.client.BaseClient] = None
    ) -> None:
        """Initialize instance attributes.

        Arguments:
            image_uri: URI of a container image in the Amazon ECR registry. The image
                should contain a handler that performs inference.
            role: The Amazon Resource Name (ARN) of the IAM role that Lambda will assume
                when it performs inference.
            client: The Lambda client used to interact with Lambda. If omitted, a
                client is created with ``boto3.client("lambda")``.
        """
        self._client = client or boto3.client("lambda")
        self._image_uri = image_uri
        self._role = role

    def deploy(
        self, function_name: str, timeout: int, memory_size: int, wait: bool = False
    ) -> LambdaPredictor:
        """Create a Lambda function using the image specified in the constructor.

        Arguments:
            function_name: The name of the function.
            timeout: The number of seconds that the function can run for before being terminated.
            memory_size: The amount of memory in MB that the function has access to.
            wait: If true, wait until the deployment completes (default: False).

        Returns:
            A LambdaPredictor instance that performs inference using the specified image.

        Raises:
            RuntimeError: If ``wait`` is true and the function leaves the ``Pending``
                state without reaching the ``Active`` state.
        """
        response = self._client.create_function(
            FunctionName=function_name,
            PackageType="Image",
            Role=self._role,
            Code={
                "ImageUri": self._image_uri,
            },
            Timeout=timeout,
            MemorySize=memory_size,
        )

        if not wait:
            # Return immediately; the function may still be in the "Pending" state.
            return LambdaPredictor(function_name, client=self._client)

        # Poll function state. NOTE: there is no upper bound on the number of
        # polls; the loop only ends once the state is no longer "Pending".
        polling_interval = 5
        while response["State"] == "Pending":
            time.sleep(polling_interval)
            response = self._client.get_function_configuration(FunctionName=function_name)

        if response["State"] != "Active":
            # "StateReason" may be absent from the response; fall back to the
            # observed state so a KeyError does not mask the deployment failure.
            reason = response.get("StateReason") or "function state is %s" % response["State"]
            raise RuntimeError("Failed to deploy model to Lambda: %s" % reason)

        return LambdaPredictor(function_name, client=self._client)

    def destroy(self) -> None:
        """Destroy resources associated with this model.

        This method does not delete the image specified in the constructor. As
        a result, this method is a no-op.
        """

src/sagemaker/serverless/predictor.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Predictors that are hosted on serverless compute."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional, Tuple
17+
18+
import boto3
19+
import botocore
20+
21+
from sagemaker import deserializers, serializers
22+
from sagemaker.predictor import PredictorBase
23+
24+
25+
class LambdaPredictor(PredictorBase):
    """A predictor backed by a model deployed as a Lambda function."""

    def __init__(
        self, function_name: str, client: Optional[botocore.client.BaseClient] = None
    ) -> None:
        """Set up the Lambda client, function name and JSON (de)serializers.

        Arguments:
            function_name: The name of the Lambda function to invoke.
            client: The Lambda client used to interact with Lambda. When it is
                not supplied, one is created with ``boto3.client("lambda")``.
        """
        self._client = client if client else boto3.client("lambda")
        self._function_name = function_name
        # Requests and responses are both exchanged as JSON.
        self._serializer = serializers.JSONSerializer()
        self._deserializer = deserializers.JSONDeserializer()

    def predict(self, data: dict) -> dict:
        """Invoke the Lambda function and return its deserialized response.

        The invocation uses the ``RequestResponse`` invocation type, so this
        call only returns once the function has produced a result.

        Arguments:
            data: The input to serialize and send to the Lambda function.

        Returns:
            The deserialized payload returned by the Lambda function.
        """
        request_body = self._serializer.serialize(data)
        response = self._client.invoke(
            FunctionName=self._function_name,
            InvocationType="RequestResponse",
            Payload=request_body,
        )
        # Read the payload before the header, then hand both to the
        # deserializer together with the response's own content type.
        response_body = response["Payload"]
        response_content_type = response["ResponseMetadata"]["HTTPHeaders"]["content-type"]
        return self._deserializer.deserialize(response_body, response_content_type)

    def destroy(self) -> None:
        """Delete the Lambda function that this predictor invokes."""
        self._client.delete_function(FunctionName=self._function_name)

    @property
    def content_type(self) -> str:
        """The MIME type of the request data, as declared by the serializer."""
        serializer = self._serializer
        return serializer.CONTENT_TYPE

    @property
    def accept(self) -> Tuple[str]:
        """The response content type(s), as declared by the deserializer."""
        deserializer = self._deserializer
        return deserializer.ACCEPT

    @property
    def function_name(self) -> str:
        """The name of the Lambda function that this predictor invokes."""
        return self._function_name

0 commit comments

Comments
 (0)